In [commit 6072ad4](6072ad4), @KajvanRijn kindly changed all "TO DO" to "TODO" in the code blocks. That's useful. In addition, it should be changed (as here) in the instructions, so that anyone searching for "TODO" finds every instance without doubt.
330 lines
12 KiB
Plaintext
330 lines
12 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": [],
|
|
"authorship_tag": "ABX9TyPx2mM2zTHmDJeKeiE1RymT",
|
|
"include_colab_link": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap11/11_3_Batch_Normalization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# **Notebook 11.3: Batch normalization**\n",
|
|
"\n",
|
|
"This notebook investigates the use of batch normalization in residual networks.\n",
|
|
"\n",
|
|
"Work through the cells below, running each cell in turn. In various places you will see the word \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
|
"\n",
|
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
|
|
"\n"
|
|
],
|
|
"metadata": {
|
|
"id": "t9vk9Elugvmi"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Run this if you're in Colab to install the MNIST-1D repository\n",
|
|
"!pip install git+https://github.com/greydanus/mnist1d"
|
|
],
|
|
"metadata": {
|
|
"id": "D5yLObtZCi9J"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"import torch, torch.nn as nn\n",
|
|
"from torch.utils.data import TensorDataset, DataLoader\n",
|
|
"from torch.optim.lr_scheduler import StepLR\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import mnist1d\n",
|
|
"import random"
|
|
],
|
|
"metadata": {
|
|
"id": "YrXWAH7sUWvU"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"args = mnist1d.data.get_dataset_args()\n",
|
|
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
|
|
"\n",
|
|
"# The training and test input and outputs are in\n",
|
|
"# data['x'], data['y'], data['x_test'], and data['y_test']\n",
|
|
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
|
|
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
|
|
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
|
|
],
|
|
"metadata": {
|
|
"id": "twI72ZCrCt5z"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Load in the data\n",
|
|
"train_data_x = data['x'].transpose()\n",
|
|
"train_data_y = data['y']\n",
|
|
"val_data_x = data['x_test'].transpose()\n",
|
|
"val_data_y = data['y_test']\n",
|
|
"# Print out sizes\n",
|
|
"print(\"Train data: %d examples (columns), each of which has %d dimensions (rows)\"%((train_data_x.shape[1],train_data_x.shape[0])))\n",
|
|
"print(\"Validation data: %d examples (columns), each of which has %d dimensions (rows)\"%((val_data_x.shape[1],val_data_x.shape[0])))"
|
|
],
|
|
"metadata": {
|
|
"id": "8bKADvLHbiV5"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"def print_variance(name, data):\n",
|
|
" # First dimension (rows) is batch elements\n",
|
|
" # Second dimension (columns) is neurons.\n",
|
|
" np_data = data.detach().numpy()\n",
|
|
" # Compute the variance of each neuron across the batch (axis 0), then average these variances over the neurons\n",
|
|
" neuron_variance = np.mean(np.var(np_data, axis=0))\n",
|
|
" # Print out the name and the variance\n",
|
|
" print(\"%s variance=%f\"%(name,neuron_variance))"
|
|
],
|
|
"metadata": {
|
|
"id": "3bBpJIV-N-lt"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# He initialization of weights\n",
|
|
"def weights_init(layer_in):\n",
|
|
" if isinstance(layer_in, nn.Linear):\n",
|
|
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
|
" layer_in.bias.data.fill_(0.0)"
|
|
],
|
|
"metadata": {
|
|
"id": "YgLaex1pfhqz"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"def run_one_step_of_model(model, x_train, y_train):\n",
|
|
" # choose cross entropy loss function (equation 5.24 in the loss notes)\n",
|
|
" loss_function = nn.CrossEntropyLoss()\n",
|
|
" # construct SGD optimizer and initialize learning rate and momentum\n",
|
|
" optimizer = torch.optim.SGD(model.parameters(), lr = 0.05, momentum=0.9)\n",
|
|
"\n",
|
|
" # load the data into a class that creates the batches\n",
|
|
" data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=200, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
|
"\n",
|
|
" # Initialize model weights\n",
|
|
" model.apply(weights_init)\n",
|
|
"\n",
|
|
" # Get a batch\n",
|
|
" for i, data in enumerate(data_loader):\n",
|
|
" # retrieve inputs and labels for this batch\n",
|
|
" x_batch, y_batch = data\n",
|
|
" # zero the parameter gradients\n",
|
|
" optimizer.zero_grad()\n",
|
|
" # forward pass -- calculate model output\n",
|
|
" pred = model(x_batch)\n",
|
|
" # compute the loss\n",
|
|
" loss = loss_function(pred, y_batch)\n",
|
|
" # backward pass\n",
|
|
" loss.backward()\n",
|
|
" # SGD update\n",
|
|
" optimizer.step()\n",
|
|
" # Break out of this loop -- we just want to see the first\n",
|
|
" # iteration, but usually we would continue\n",
|
|
" break"
|
|
],
|
|
"metadata": {
|
|
"id": "DFlu45pORQEz"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# convert training data to torch tensors\n",
|
|
"x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n",
|
|
"y_train = torch.tensor(train_data_y.astype('long'))"
|
|
],
|
|
"metadata": {
|
|
"id": "i7Q0ScWgRe4G"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# This is a simple residual model with 5 residual branches in a row\n",
|
|
"class ResidualNetwork(torch.nn.Module):\n",
|
|
" def __init__(self, input_size, output_size, hidden_size=100):\n",
|
|
" super(ResidualNetwork, self).__init__()\n",
|
|
" self.linear1 = nn.Linear(input_size, hidden_size)\n",
|
|
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear6 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear7 = nn.Linear(hidden_size, output_size)\n",
|
|
"\n",
|
|
" def count_params(self):\n",
|
|
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" print_variance(\"Input\",x)\n",
|
|
" f = self.linear1(x)\n",
|
|
" print_variance(\"First preactivation\",f)\n",
|
|
" res1 = f+ self.linear2(f.relu())\n",
|
|
" print_variance(\"After first residual connection\",res1)\n",
|
|
" res2 = res1 + self.linear3(res1.relu())\n",
|
|
" print_variance(\"After second residual connection\",res2)\n",
|
|
" res3 = res2 + self.linear4(res2.relu())\n",
|
|
" print_variance(\"After third residual connection\",res3)\n",
|
|
" res4 = res3 + self.linear5(res3.relu())\n",
|
|
" print_variance(\"After fourth residual connection\",res4)\n",
|
|
" res5 = res4 + self.linear6(res4.relu())\n",
|
|
" print_variance(\"After fifth residual connection\",res5)\n",
|
|
" return self.linear7(res5)"
|
|
],
|
|
"metadata": {
|
|
"id": "FslroPJJffrh"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Define the model and run for one step\n",
|
|
"# Monitoring the variance at each point in the network\n",
|
|
"n_hidden = 100\n",
|
|
"n_input = 40\n",
|
|
"n_output = 10\n",
|
|
"model = ResidualNetwork(n_input, n_output, n_hidden)\n",
|
|
"run_one_step_of_model(model, x_train, y_train)"
|
|
],
|
|
"metadata": {
|
|
"id": "NYw8I_3mmX5c"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"Notice that the variance roughly doubles at each step so it increases exponentially as in figure 11.6b in the book."
|
|
],
|
|
"metadata": {
|
|
"id": "0kZUlWkkW8jE"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# TODO Adapt the residual network below to add a batch norm operation\n",
|
|
"# before the contents of each residual link as in figure 11.6c in the book\n",
|
|
"# Use the torch function nn.BatchNorm1d\n",
|
|
"class ResidualNetworkWithBatchNorm(torch.nn.Module):\n",
|
|
" def __init__(self, input_size, output_size, hidden_size=100):\n",
|
|
" super(ResidualNetworkWithBatchNorm, self).__init__()\n",
|
|
" self.linear1 = nn.Linear(input_size, hidden_size)\n",
|
|
" self.linear2 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear3 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear4 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear5 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear6 = nn.Linear(hidden_size, hidden_size)\n",
|
|
" self.linear7 = nn.Linear(hidden_size, output_size)\n",
|
|
"\n",
|
|
" def count_params(self):\n",
|
|
" return sum([p.view(-1).shape[0] for p in self.parameters()])\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" print_variance(\"Input\",x)\n",
|
|
" f = self.linear1(x)\n",
|
|
" print_variance(\"First preactivation\",f)\n",
|
|
" res1 = f+ self.linear2(f.relu())\n",
|
|
" print_variance(\"After first residual connection\",res1)\n",
|
|
" res2 = res1 + self.linear3(res1.relu())\n",
|
|
" print_variance(\"After second residual connection\",res2)\n",
|
|
" res3 = res2 + self.linear4(res2.relu())\n",
|
|
" print_variance(\"After third residual connection\",res3)\n",
|
|
" res4 = res3 + self.linear5(res3.relu())\n",
|
|
" print_variance(\"After fourth residual connection\",res4)\n",
|
|
" res5 = res4 + self.linear6(res4.relu())\n",
|
|
" print_variance(\"After fifth residual connection\",res5)\n",
|
|
" return self.linear7(res5)"
|
|
],
|
|
"metadata": {
|
|
"id": "5JvMmaRITKGd"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Define the model\n",
|
|
"n_hidden = 100\n",
|
|
"n_input = 40\n",
|
|
"n_output = 10\n",
|
|
"model = ResidualNetworkWithBatchNorm(n_input, n_output, n_hidden)\n",
|
|
"run_one_step_of_model(model, x_train, y_train)"
|
|
],
|
|
"metadata": {
|
|
"id": "2U3DnlH9Uw6c"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"Note that the variance now increases linearly as in figure 11.6c."
|
|
],
|
|
"metadata": {
|
|
"id": "R_ucFq9CXq8D"
|
|
}
|
|
}
|
|
]
|
|
} |