Created using Colaboratory

This commit is contained in:
udlbook
2023-11-23 11:08:41 +00:00
parent b47c4a210a
commit 5c10091b48

View File

@@ -4,7 +4,7 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyOoGS+lY+EhGthebSO4smpj", "authorship_tag": "ABX9TyOZaNcBrdZ9yCHhjLOwSi69",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -205,7 +205,8 @@
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n", " self.linear3 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n", " self.linear4 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n", " self.linear5 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear6 = nn.Linear(hidden_size, output_size)\n", " self.linear6 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear7 = nn.Linear(hidden_size, output_size)\n",
"\n", "\n",
" def count_params(self):\n", " def count_params(self):\n",
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n", " return sum([p.view(-1).shape[0] for p in self.parameters()])\n",
@@ -220,11 +221,11 @@
" print_variance(\"After second residual connection\",res2)\n", " print_variance(\"After second residual connection\",res2)\n",
" res3 = res2 + self.linear4(res2.relu())\n", " res3 = res2 + self.linear4(res2.relu())\n",
" print_variance(\"After third residual connection\",res3)\n", " print_variance(\"After third residual connection\",res3)\n",
" res4 = res3 + self.linear4(res3.relu())\n", " res4 = res3 + self.linear5(res3.relu())\n",
" print_variance(\"After fourth residual connection\",res4)\n", " print_variance(\"After fourth residual connection\",res4)\n",
" res5 = res4 + self.linear4(res4.relu())\n", " res5 = res4 + self.linear6(res4.relu())\n",
" print_variance(\"After fifth residual connection\",res5)\n", " print_variance(\"After fifth residual connection\",res5)\n",
" return self.linear6(res5)" " return self.linear7(res5)"
], ],
"metadata": { "metadata": {
"id": "FslroPJJffrh" "id": "FslroPJJffrh"
@@ -266,13 +267,14 @@
"# Use the torch function nn.BatchNorm1d\n", "# Use the torch function nn.BatchNorm1d\n",
"class ResidualNetworkWithBatchNorm(torch.nn.Module):\n", "class ResidualNetworkWithBatchNorm(torch.nn.Module):\n",
" def __init__(self, input_size, output_size, hidden_size=100):\n", " def __init__(self, input_size, output_size, hidden_size=100):\n",
" super(ResidualNetworkWithBatchNorm, self).__init__()\n", " super(ResidualNetwork, self).__init__()\n",
" self.linear1 = nn.Linear(input_size, hidden_size)\n", " self.linear1 = nn.Linear(input_size, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n", " self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n", " self.linear3 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n", " self.linear4 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n", " self.linear5 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear6 = nn.Linear(hidden_size, output_size)\n", " self.linear6 = nn.Linear(hidden_size, hidden_size)\n",
" self.linear7 = nn.Linear(hidden_size, output_size)\n",
"\n", "\n",
" def count_params(self):\n", " def count_params(self):\n",
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n", " return sum([p.view(-1).shape[0] for p in self.parameters()])\n",
@@ -287,11 +289,11 @@
" print_variance(\"After second residual connection\",res2)\n", " print_variance(\"After second residual connection\",res2)\n",
" res3 = res2 + self.linear4(res2.relu())\n", " res3 = res2 + self.linear4(res2.relu())\n",
" print_variance(\"After third residual connection\",res3)\n", " print_variance(\"After third residual connection\",res3)\n",
" res4 = res3 + self.linear4(res3.relu())\n", " res4 = res3 + self.linear5(res3.relu())\n",
" print_variance(\"After fourth residual connection\",res4)\n", " print_variance(\"After fourth residual connection\",res4)\n",
" res5 = res4 + self.linear4(res4.relu())\n", " res5 = res4 + self.linear6(res4.relu())\n",
" print_variance(\"After fifth residual connection\",res5)\n", " print_variance(\"After fifth residual connection\",res5)\n",
" return self.linear6(res5)" " return self.linear7(res5)"
], ],
"metadata": { "metadata": {
"id": "5JvMmaRITKGd" "id": "5JvMmaRITKGd"