Created using Colaboratory
This commit is contained in:
305
Notebooks/Chap20/20_1_Random_Data.ipynb
Normal file
305
Notebooks/Chap20/20_1_Random_Data.ipynb
Normal file
@@ -0,0 +1,305 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"authorship_tag": "ABX9TyPkSYbEjOcEmLt8tU6HxNuR",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap20/20_1_Random_Data.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# **Notebook 20.1: Random Data**\n",
|
||||
"\n",
|
||||
"This notebook investigates training the network with random data, as illustrated in figure 20.1.\n",
|
||||
"\n",
|
||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||
"\n",
|
||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "t9vk9Elugvmi"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Run this if you're in a Colab to make a local copy of the MNIST 1D repository\n",
|
||||
"!git clone https://github.com/greydanus/mnist1d"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "D5yLObtZCi9J"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import os\n",
|
||||
"import torch, torch.nn as nn\n",
|
||||
"from torch.utils.data import TensorDataset, DataLoader\n",
|
||||
"from torch.optim.lr_scheduler import StepLR\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import mnist1d\n",
|
||||
"import random\n",
|
||||
"from IPython.display import display, clear_output"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YrXWAH7sUWvU"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"args = mnist1d.data.get_dataset_args()\n",
|
||||
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
|
||||
"\n",
|
||||
"# The training and test input and outputs are in\n",
|
||||
"# data['x'], data['y'], data['x_test'], and data['y_test']\n",
|
||||
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
|
||||
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
|
||||
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "twI72ZCrCt5z"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Define the network"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_sFvRDGrl4qe"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"D_i = 40 # Input dimensions\n",
|
||||
"D_k = 300 # Hidden dimensions\n",
|
||||
"D_o = 10 # Output dimensions\n",
|
||||
"\n",
|
||||
"model = nn.Sequential(\n",
|
||||
"nn.Linear(D_i, D_k),\n",
|
||||
"nn.ReLU(),\n",
|
||||
"nn.Linear(D_k, D_k),\n",
|
||||
"nn.ReLU(),\n",
|
||||
"nn.Linear(D_k, D_k),\n",
|
||||
"nn.ReLU(),\n",
|
||||
"nn.Linear(D_k, D_k),\n",
|
||||
"nn.ReLU(),\n",
|
||||
"nn.Linear(D_k, D_o))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FslroPJJffrh"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# He initialization of weights\n",
|
||||
"def weights_init(layer_in):\n",
|
||||
" if isinstance(layer_in, nn.Linear):\n",
|
||||
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
||||
" layer_in.bias.data.fill_(0.0)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "YgLaex1pfhqz"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def train_model(train_data_x, train_data_y, n_epoch):\n",
|
||||
" # choose cross entropy loss function (equation 5.24 in the loss notes)\n",
|
||||
" loss_function = nn.CrossEntropyLoss()\n",
|
||||
" # construct SGD optimizer and initialize learning rate and momentum\n",
|
||||
" optimizer = torch.optim.SGD(model.parameters(), lr = 0.02, momentum=0.9)\n",
|
||||
" # object that decreases learning rate by half every 20 epochs\n",
|
||||
" scheduler = StepLR(optimizer, step_size=20, gamma=0.5)\n",
|
||||
" # create 100 dummy data points and store in data loader class\n",
|
||||
" x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n",
|
||||
" y_train = torch.tensor(train_data_y.astype('long'))\n",
|
||||
"\n",
|
||||
" # load the data into a class that creates the batches\n",
|
||||
" data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
||||
"\n",
|
||||
" # Initialize model weights\n",
|
||||
" model.apply(weights_init)\n",
|
||||
"\n",
|
||||
" # store the loss and the % correct at each epoch\n",
|
||||
" losses_train = np.zeros((n_epoch))\n",
|
||||
"\n",
|
||||
" for epoch in range(n_epoch):\n",
|
||||
" # loop over batches\n",
|
||||
" for i, data in enumerate(data_loader):\n",
|
||||
" # retrieve inputs and labels for this batch\n",
|
||||
" x_batch, y_batch = data\n",
|
||||
" # zero the parameter gradients\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" # forward pass -- calculate model output\n",
|
||||
" pred = model(x_batch)\n",
|
||||
" # compute the loss\n",
|
||||
" loss = loss_function(pred, y_batch)\n",
|
||||
" # backward pass\n",
|
||||
" loss.backward()\n",
|
||||
" # SGD update\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" # Run whole dataset to get statistics -- normally wouldn't do this\n",
|
||||
" pred_train = model(x_train)\n",
|
||||
" _, predicted_train_class = torch.max(pred_train.data, 1)\n",
|
||||
" losses_train[epoch] = loss_function(pred_train, y_train).item()\n",
|
||||
" if epoch % 5 == 0:\n",
|
||||
" clear_output(wait=True)\n",
|
||||
" display(\"Epoch %d, train loss %3.3f\"%(epoch, losses_train[epoch]))\n",
|
||||
"\n",
|
||||
" # tell scheduler to consider updating learning rate\n",
|
||||
" scheduler.step()\n",
|
||||
"\n",
|
||||
" return losses_train"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "NYw8I_3mmX5c"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Load in the data\n",
|
||||
"train_data_x = data['x'].transpose()\n",
|
||||
"train_data_y = data['y']\n",
|
||||
"# Print out sizes\n",
|
||||
"print(\"Train data: %d examples (columns), each of which has %d dimensions (rows)\"%((train_data_x.shape[1],train_data_x.shape[0])))"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "4FE3HQ_vedXO"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Compute loss for proper data and plot\n",
|
||||
"n_epoch = 60\n",
|
||||
"loss_true_labels = train_model(train_data_x, train_data_y, n_epoch)\n",
|
||||
"# Plot the results\n",
|
||||
"fig, ax = plt.subplots()\n",
|
||||
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
|
||||
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
|
||||
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
|
||||
"ax.legend()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "b56wdODqemF1"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# TODO -- Randomize the input data (train_data_x), but retain overall mean and variance\n",
|
||||
"# Replace this line\n",
|
||||
"train_data_x_randomized = np.copy(train_data_x)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "SbPCiiUKgTLw"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Compute loss for true labels and plot\n",
|
||||
"n_epoch = 60\n",
|
||||
"loss_randomized_data = train_model(train_data_x_randomized, train_data_y, n_epoch)\n",
|
||||
"# Plot the results\n",
|
||||
"fig, ax = plt.subplots()\n",
|
||||
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
|
||||
"ax.plot(loss_randomized_data,'b-',label='random_data')\n",
|
||||
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
|
||||
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
|
||||
"ax.legend()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "y7CcCJvvjLnn"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# TODO -- Permute the labels\n",
|
||||
"# Replace this line:\n",
|
||||
"train_data_y_permuted = np.copy(train_data_y)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ojaMTrzKj_74"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Compute loss for true labels and plot\n",
|
||||
"n_epoch = 60\n",
|
||||
"loss_permuted_labels = train_model(train_data_x, train_data_y_permuted, n_epoch)\n",
|
||||
"# Plot the results\n",
|
||||
"fig, ax = plt.subplots()\n",
|
||||
"ax.plot(loss_true_labels,'r-',label='true_labels')\n",
|
||||
"ax.plot(loss_randomized_data,'b-',label='random_data')\n",
|
||||
"ax.plot(loss_permuted_labels,'g-',label='random_labels')\n",
|
||||
"# ax.set_ylim(0,0.7); ax.set_xlim(0,n_epoch)\n",
|
||||
"ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n",
|
||||
"ax.legend()\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "LaYCSjyMo9LQ"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user