Revert "Remove duplicate weight initialization"
This reverts commit 87cf590af9.
This commit is contained in:
@@ -1,10 +1,25 @@
|
|||||||
{
|
{
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": [],
|
||||||
|
"include_colab_link": true
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"name": "python3",
|
||||||
|
"display_name": "Python 3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab_type": "text",
|
"id": "view-in-github",
|
||||||
"id": "view-in-github"
|
"colab_type": "text"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap09/9_5_Augmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap09/9_5_Augmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
@@ -12,9 +27,6 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
|
||||||
"id": "el8l05WQEO46"
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"# **Notebook 9.5: Augmentation**\n",
|
"# **Notebook 9.5: Augmentation**\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -23,27 +35,25 @@
|
|||||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n"
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "el8l05WQEO46"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "syvgxgRr3myY"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Run this if you're in a Colab to install MNIST 1D repository\n",
|
"# Run this if you're in a Colab to install MNIST 1D repository\n",
|
||||||
"!pip install git+https://github.com/greydanus/mnist1d"
|
"!pip install git+https://github.com/greydanus/mnist1d"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "syvgxgRr3myY"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "ckrNsYd13pMe"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import torch, torch.nn as nn\n",
|
"import torch, torch.nn as nn\n",
|
||||||
"from torch.utils.data import TensorDataset, DataLoader\n",
|
"from torch.utils.data import TensorDataset, DataLoader\n",
|
||||||
@@ -52,15 +62,15 @@
|
|||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"import mnist1d\n",
|
"import mnist1d\n",
|
||||||
"import random"
|
"import random"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "ckrNsYd13pMe"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "D_Woo9U730lZ"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"args = mnist1d.data.get_dataset_args()\n",
|
"args = mnist1d.data.get_dataset_args()\n",
|
||||||
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
|
"data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)\n",
|
||||||
@@ -70,15 +80,15 @@
|
|||||||
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
|
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
|
||||||
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
|
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
|
||||||
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
|
"print(\"Length of each example: {}\".format(data['x'].shape[-1]))"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "D_Woo9U730lZ"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "JfIFWFIL33eF"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"D_i = 40 # Input dimensions\n",
|
"D_i = 40 # Input dimensions\n",
|
||||||
"D_k = 200 # Hidden dimensions\n",
|
"D_k = 200 # Hidden dimensions\n",
|
||||||
@@ -99,17 +109,17 @@
|
|||||||
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
||||||
" layer_in.bias.data.fill_(0.0)\n",
|
" layer_in.bias.data.fill_(0.0)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Initialize model weights\n",
|
"# Call the function you just defined\n",
|
||||||
"model.apply(weights_init)"
|
"model.apply(weights_init)"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "JfIFWFIL33eF"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "YFfVbTPE4BkJ"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# choose cross entropy loss function (equation 5.24)\n",
|
"# choose cross entropy loss function (equation 5.24)\n",
|
||||||
"loss_function = torch.nn.CrossEntropyLoss()\n",
|
"loss_function = torch.nn.CrossEntropyLoss()\n",
|
||||||
@@ -126,6 +136,9 @@
|
|||||||
"# load the data into a class that creates the batches\n",
|
"# load the data into a class that creates the batches\n",
|
||||||
"data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
"data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Initialize model weights\n",
|
||||||
|
"model.apply(weights_init)\n",
|
||||||
|
"\n",
|
||||||
"# loop over the dataset n_epoch times\n",
|
"# loop over the dataset n_epoch times\n",
|
||||||
"n_epoch = 50\n",
|
"n_epoch = 50\n",
|
||||||
"# store the loss and the % correct at each epoch\n",
|
"# store the loss and the % correct at each epoch\n",
|
||||||
@@ -156,15 +169,15 @@
|
|||||||
" errors_train[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
|
" errors_train[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
|
||||||
" errors_test[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
|
" errors_test[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
|
||||||
" print(f'Epoch {epoch:5d}, train error {errors_train[epoch]:3.2f}, test error {errors_test[epoch]:3.2f}')"
|
" print(f'Epoch {epoch:5d}, train error {errors_train[epoch]:3.2f}, test error {errors_test[epoch]:3.2f}')"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "YFfVbTPE4BkJ"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "FmGDd4vB8LyM"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Plot the results\n",
|
"# Plot the results\n",
|
||||||
"fig, ax = plt.subplots()\n",
|
"fig, ax = plt.subplots()\n",
|
||||||
@@ -175,24 +188,24 @@
|
|||||||
"ax.set_title('Train Error %3.2f, Test Error %3.2f'%(errors_train[-1],errors_test[-1]))\n",
|
"ax.set_title('Train Error %3.2f, Test Error %3.2f'%(errors_train[-1],errors_test[-1]))\n",
|
||||||
"ax.legend()\n",
|
"ax.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "FmGDd4vB8LyM"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
|
||||||
"id": "55XvoPDO8Qp-"
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"The best test performance is about 33%. Let's see if we can improve on that by augmenting the data."
|
"The best test performance is about 33%. Let's see if we can improve on that by augmenting the data."
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "55XvoPDO8Qp-"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "IP6z2iox8MOF"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"def augment(input_vector):\n",
|
"def augment(input_vector):\n",
|
||||||
" # Create output vector\n",
|
" # Create output vector\n",
|
||||||
@@ -208,15 +221,15 @@
|
|||||||
" data_out = np.array(data_out)\n",
|
" data_out = np.array(data_out)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return data_out"
|
" return data_out"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "IP6z2iox8MOF"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "bzN0lu5J95AJ"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"n_data_orig = data['x'].shape[0]\n",
|
"n_data_orig = data['x'].shape[0]\n",
|
||||||
"# We'll double the amount of data\n",
|
"# We'll double the amount of data\n",
|
||||||
@@ -234,15 +247,15 @@
|
|||||||
" # Augment the point and store\n",
|
" # Augment the point and store\n",
|
||||||
" augmented_x[c_augment,:] = augment(data['x'][random_data_index,:])\n",
|
" augmented_x[c_augment,:] = augment(data['x'][random_data_index,:])\n",
|
||||||
" augmented_y[c_augment] = data['y'][random_data_index]\n"
|
" augmented_y[c_augment] = data['y'][random_data_index]\n"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "bzN0lu5J95AJ"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "hZUNrXpS_kRs"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# choose cross entropy loss function (equation 5.24)\n",
|
"# choose cross entropy loss function (equation 5.24)\n",
|
||||||
"loss_function = torch.nn.CrossEntropyLoss()\n",
|
"loss_function = torch.nn.CrossEntropyLoss()\n",
|
||||||
@@ -292,15 +305,15 @@
|
|||||||
" errors_train_aug[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
|
" errors_train_aug[epoch] = 100 - 100 * (predicted_train_class == y_train).float().sum() / len(y_train)\n",
|
||||||
" errors_test_aug[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
|
" errors_test_aug[epoch]= 100 - 100 * (predicted_test_class == y_test).float().sum() / len(y_test)\n",
|
||||||
" print(f'Epoch {epoch:5d}, train error {errors_train_aug[epoch]:3.2f}, test error {errors_test_aug[epoch]:3.2f}')"
|
" print(f'Epoch {epoch:5d}, train error {errors_train_aug[epoch]:3.2f}, test error {errors_test_aug[epoch]:3.2f}')"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "hZUNrXpS_kRs"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "IcnAW4ixBnuc"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Plot the results\n",
|
"# Plot the results\n",
|
||||||
"fig, ax = plt.subplots()\n",
|
"fig, ax = plt.subplots()\n",
|
||||||
@@ -312,31 +325,21 @@
|
|||||||
"ax.set_title('TrainError %3.2f, Test Error %3.2f'%(errors_train_aug[-1],errors_test_aug[-1]))\n",
|
"ax.set_title('TrainError %3.2f, Test Error %3.2f'%(errors_train_aug[-1],errors_test_aug[-1]))\n",
|
||||||
"ax.legend()\n",
|
"ax.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "IcnAW4ixBnuc"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
|
||||||
"id": "jgsR7ScJHc9b"
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"Hopefully, you should see an improvement in performance when we augment the data."
|
"Hopefully, you should see an improvement in performance when we augment the data."
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"id": "jgsR7ScJHc9b"
|
||||||
"include_colab_link": true,
|
|
||||||
"provenance": []
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
"nbformat": 4,
|
]
|
||||||
"nbformat_minor": 0
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user