Files
udlbook/Notebooks/Chap20/20_1_Random_Data.ipynb
Mark Gotham 9649ce382b "TO DO" > "TODO
In [commit 6072ad4](6072ad4), @KajvanRijn kindly changed all "TO DO" to "TODO" in the code blocks. That's useful. In addition, it should be changed (as here) in the instructions. Then there's no doubt or issue for anyone searching all instances.
2025-02-11 15:11:06 +00:00

305 lines
10 KiB
Plaintext

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNgBRvfIlngVobKuLE6leM+",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap20/20_1_Random_Data.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 20.1: Random Data**\n",
"\n",
"This notebook investigates training the network with random data, as illustrated in figure 20.1.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
"\n"
],
"metadata": {
"id": "t9vk9Elugvmi"
}
},
{
"cell_type": "code",
"source": [
"# Run this if you're in a Colab to install MNIST 1D repository\n",
"!pip install git+https://github.com/greydanus/mnist1d"
],
"metadata": {
"id": "D5yLObtZCi9J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import os\n",
"import torch, torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"from torch.optim.lr_scheduler import StepLR\n",
"import matplotlib.pyplot as plt\n",
"import mnist1d\n",
"import random\n",
"from IPython.display import display, clear_output"
],
"metadata": {
"id": "YrXWAH7sUWvU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"args = mnist1d.data.get_dataset_args()\n",
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
"\n",
"# The training and test input and outputs are in\n",
"# data['x'], data['y'], data['x_test'], and data['y_test']\n",
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
],
"metadata": {
"id": "twI72ZCrCt5z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Define the network"
],
"metadata": {
"id": "_sFvRDGrl4qe"
}
},
{
"cell_type": "code",
"source": [
"D_i = 40 # Input dimensions\n",
"D_k = 300 # Hidden dimensions\n",
"D_o = 10 # Output dimensions\n",
"\n",
"model = nn.Sequential(\n",
"nn.Linear(D_i, D_k),\n",
"nn.ReLU(),\n",
"nn.Linear(D_k, D_k),\n",
"nn.ReLU(),\n",
"nn.Linear(D_k, D_k),\n",
"nn.ReLU(),\n",
"nn.Linear(D_k, D_k),\n",
"nn.ReLU(),\n",
"nn.Linear(D_k, D_o))"
],
"metadata": {
"id": "FslroPJJffrh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# He initialization of weights\n",
"def weights_init(layer_in):\n",
" if isinstance(layer_in, nn.Linear):\n",
" nn.init.kaiming_uniform_(layer_in.weight)\n",
" layer_in.bias.data.fill_(0.0)"
],
"metadata": {
"id": "YgLaex1pfhqz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train_model(train_data_x, train_data_y, n_epoch):\n",
" # choose cross entropy loss function (equation 5.24 in the loss notes)\n",
" loss_function = nn.CrossEntropyLoss()\n",
" # construct SGD optimizer and initialize learning rate and momentum\n",
" optimizer = torch.optim.SGD(model.parameters(), lr = 0.02, momentum=0.9)\n",
" # object that decreases learning rate by half every 20 epochs\n",
" scheduler = StepLR(optimizer, step_size=20, gamma=0.5)\n",
" # create 100 dummy data points and store in data loader class\n",
" x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n",
" y_train = torch.tensor(train_data_y.astype('long'))\n",
"\n",
" # load the data into a class that creates the batches\n",
" data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
"\n",
" # Initialize model weights\n",
" model.apply(weights_init)\n",
"\n",
" # store the loss and the % correct at each epoch\n",
" losses_train = np.zeros((n_epoch))\n",
"\n",
" for epoch in range(n_epoch):\n",
" # loop over batches\n",
" for i, data in enumerate(data_loader):\n",
" # retrieve inputs and labels for this batch\n",
" x_batch, y_batch = data\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
" # forward pass -- calculate model output\n",
" pred = model(x_batch)\n",
" # compute the loss\n",
" loss = loss_function(pred, y_batch)\n",
" # backward pass\n",
" loss.backward()\n",
" # SGD update\n",
" optimizer.step()\n",
"\n",
" # Run whole dataset to get statistics -- normally wouldn't do this\n",
" pred_train = model(x_train)\n",
" _, predicted_train_class = torch.max(pred_train.data, 1)\n",
" losses_train[epoch] = loss_function(pred_train, y_train).item()\n",
" if epoch % 5 == 0:\n",
" clear_output(wait=True)\n",
" display(\"Epoch %d, train loss %3.3f\"%(epoch, losses_train[epoch]))\n",
"\n",
" # tell scheduler to consider updating learning rate\n",
" scheduler.step()\n",
"\n",
" return losses_train"
],
"metadata": {
"id": "NYw8I_3mmX5c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load in the data\n",
"train_data_x = data['x'].transpose()\n",
"train_data_y = data['y']\n",
"# Print out sizes\n",
"print(\"Train data: %d examples (columns), each of which has %d dimensions (rows)\"%((train_data_x.shape[1],train_data_x.shape[0])))"
],
"metadata": {
"id": "4FE3HQ_vedXO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute loss for proper data and plot\n",
"n_epoch = 60\n",
"loss_true_labels = train_model(train_data_x, train_data_y, n_epoch)\n",
"# Plot the results\n",
"fig, ax = plt.subplots()\n",
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
"ax.legend()\n",
"plt.show()"
],
"metadata": {
"id": "b56wdODqemF1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# TODO -- Randomize the input data (train_data_x), but retain overall mean and variance\n",
"# Replace this line\n",
"train_data_x_randomized = np.copy(train_data_x)"
],
"metadata": {
"id": "SbPCiiUKgTLw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute loss for true labels and plot\n",
"n_epoch = 60\n",
"loss_randomized_data = train_model(train_data_x_randomized, train_data_y, n_epoch)\n",
"# Plot the results\n",
"fig, ax = plt.subplots()\n",
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
"ax.plot(loss_randomized_data,'b-',label='random_data')\n",
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
"ax.legend()\n",
"plt.show()"
],
"metadata": {
"id": "y7CcCJvvjLnn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# TODO -- Permute the labels\n",
"# Replace this line:\n",
"train_data_y_permuted = np.copy(train_data_y)"
],
"metadata": {
"id": "ojaMTrzKj_74"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute loss for true labels and plot\n",
"n_epoch = 60\n",
"loss_permuted_labels = train_model(train_data_x, train_data_y_permuted, n_epoch)\n",
"# Plot the results\n",
"fig, ax = plt.subplots()\n",
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
"ax.plot(loss_randomized_data,'b-',label='random_data')\n",
"ax.plot(loss_permuted_labels,'g-',label='random_labels')\n",
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
"ax.legend()\n",
"plt.show()"
],
"metadata": {
"id": "LaYCSjyMo9LQ"
},
"execution_count": null,
"outputs": []
}
]
}