diff --git a/Notebooks/Chap11/11_3_Batch_Normalization.ipynb b/Notebooks/Chap11/11_3_Batch_Normalization.ipynb index 625f8fa..a085ecf 100644 --- a/Notebooks/Chap11/11_3_Batch_Normalization.ipynb +++ b/Notebooks/Chap11/11_3_Batch_Normalization.ipynb @@ -4,7 +4,7 @@ "metadata": { "colab": { "provenance": [], - "authorship_tag": "ABX9TyOoGS+lY+EhGthebSO4smpj", + "authorship_tag": "ABX9TyOZaNcBrdZ9yCHhjLOwSi69", "include_colab_link": true }, "kernelspec": { @@ -205,7 +205,8 @@ " self.linear3 = nn.Linear(hidden_size, hidden_size)\n", " self.linear4 = nn.Linear(hidden_size, hidden_size)\n", " self.linear5 = nn.Linear(hidden_size, hidden_size)\n", - " self.linear6 = nn.Linear(hidden_size, output_size)\n", + " self.linear6 = nn.Linear(hidden_size, hidden_size)\n", + " self.linear7 = nn.Linear(hidden_size, output_size)\n", "\n", " def count_params(self):\n", " return sum([p.view(-1).shape[0] for p in self.parameters()])\n", @@ -220,11 +221,11 @@ " print_variance(\"After second residual connection\",res2)\n", " res3 = res2 + self.linear4(res2.relu())\n", " print_variance(\"After third residual connection\",res3)\n", - " res4 = res3 + self.linear4(res3.relu())\n", + " res4 = res3 + self.linear5(res3.relu())\n", " print_variance(\"After fourth residual connection\",res4)\n", - " res5 = res4 + self.linear4(res4.relu())\n", + " res5 = res4 + self.linear6(res4.relu())\n", " print_variance(\"After fifth residual connection\",res5)\n", - " return self.linear6(res5)" + " return self.linear7(res5)" ], "metadata": { "id": "FslroPJJffrh" @@ -266,13 +267,14 @@ "# Use the torch function nn.BatchNorm1d\n", "class ResidualNetworkWithBatchNorm(torch.nn.Module):\n", " def __init__(self, input_size, output_size, hidden_size=100):\n", - " super(ResidualNetworkWithBatchNorm, self).__init__()\n", + " super(ResidualNetwork, self).__init__()\n", " self.linear1 = nn.Linear(input_size, hidden_size)\n", " self.linear2 = nn.Linear(hidden_size, hidden_size)\n", " self.linear3 = nn.Linear(hidden_size, hidden_size)\n", " self.linear4 = nn.Linear(hidden_size, hidden_size)\n", " self.linear5 = nn.Linear(hidden_size, hidden_size)\n", - " self.linear6 = nn.Linear(hidden_size, output_size)\n", + " self.linear6 = nn.Linear(hidden_size, hidden_size)\n", + " self.linear7 = nn.Linear(hidden_size, output_size)\n", "\n", " def count_params(self):\n", " return sum([p.view(-1).shape[0] for p in self.parameters()])\n", @@ -287,11 +289,11 @@ " print_variance(\"After second residual connection\",res2)\n", " res3 = res2 + self.linear4(res2.relu())\n", " print_variance(\"After third residual connection\",res3)\n", - " res4 = res3 + self.linear4(res3.relu())\n", + " res4 = res3 + self.linear5(res3.relu())\n", " print_variance(\"After fourth residual connection\",res4)\n", - " res5 = res4 + self.linear4(res4.relu())\n", + " res5 = res4 + self.linear6(res4.relu())\n", " print_variance(\"After fifth residual connection\",res5)\n", - " return self.linear6(res5)" + " return self.linear7(res5)" ], "metadata": { "id": "5JvMmaRITKGd"