Add CUDA support to notebook 10.5
This commit is contained in:
committed by
GitHub
parent
75646c2c8e
commit
cc9c695ff7
@@ -1,26 +1,10 @@
|
|||||||
{
|
{
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"provenance": [],
|
|
||||||
"authorship_tag": "ABX9TyORZF8xy4X1yf4oRhRq8Rtm",
|
|
||||||
"include_colab_link": true
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"name": "python3",
|
|
||||||
"display_name": "Python 3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "view-in-github",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "view-in-github"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap10/10_5_Convolution_For_MNIST.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap10/10_5_Convolution_For_MNIST.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
@@ -28,6 +12,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "t9vk9Elugvmi"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# **Notebook 10.5: Convolution for MNIST**\n",
|
"# **Notebook 10.5: Convolution for MNIST**\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -37,14 +24,18 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"If you are using Google Colab, you can change your runtime to an instance with GPU support to speed up training, e.g. a T4 GPU. If you do this, the cell below should output ``device(type='cuda')``\n",
|
||||||
|
"\n",
|
||||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n"
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "t9vk9Elugvmi"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "YrXWAH7sUWvU"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"import torchvision\n",
|
"import torchvision\n",
|
||||||
@@ -52,16 +43,18 @@
|
|||||||
"import torch.nn.functional as F\n",
|
"import torch.nn.functional as F\n",
|
||||||
"import torch.optim as optim\n",
|
"import torch.optim as optim\n",
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"import random"
|
"import random\n",
|
||||||
],
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"metadata": {
|
"device"
|
||||||
"id": "YrXWAH7sUWvU"
|
]
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "wScBGXXFVadm"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Run this once to load the train and test data straight into a dataloader class\n",
|
"# Run this once to load the train and test data straight into a dataloader class\n",
|
||||||
"# that will provide the batches\n",
|
"# that will provide the batches\n",
|
||||||
@@ -73,7 +66,7 @@
|
|||||||
"batch_size_train = 64\n",
|
"batch_size_train = 64\n",
|
||||||
"batch_size_test = 1000\n",
|
"batch_size_test = 1000\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# TODO Change this directory to point towards an existing directory\n",
|
"# TODO Change this directory to point towards an existing directory (No change needed if using Google Colab)\n",
|
||||||
"myDir = '/files/'\n",
|
"myDir = '/files/'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"train_loader = torch.utils.data.DataLoader(\n",
|
"train_loader = torch.utils.data.DataLoader(\n",
|
||||||
@@ -93,15 +86,15 @@
|
|||||||
" (0.1307,), (0.3081,))\n",
|
" (0.1307,), (0.3081,))\n",
|
||||||
" ])),\n",
|
" ])),\n",
|
||||||
" batch_size=batch_size_test, shuffle=True)"
|
" batch_size=batch_size_test, shuffle=True)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "wScBGXXFVadm"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "8bKADvLHbiV5"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Let's draw some of the training data\n",
|
"# Let's draw some of the test data\n",
|
||||||
"examples = enumerate(test_loader)\n",
|
"examples = enumerate(test_loader)\n",
|
||||||
@@ -112,28 +105,27 @@
|
|||||||
" plt.subplot(2,3,i+1)\n",
|
" plt.subplot(2,3,i+1)\n",
|
||||||
" plt.tight_layout()\n",
|
" plt.tight_layout()\n",
|
||||||
" plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
|
" plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
|
||||||
" plt.title(\"Ground Truth: {}\".format(example_targets[i]))\n",
|
" plt.title(\"Ground Truth: {}\".format(example_targets[i]))\n",
" plt.xticks([])\n",
|
||||||
" plt.xticks([])\n",
|
|
||||||
" plt.yticks([])\n",
|
" plt.yticks([])\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "8bKADvLHbiV5"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Define the network. This is a more typical way to define a network than the sequential structure. We define a class for the network, and define the parameters in the constructor. Then we use a function called forward to actually run the network. It's easy to see how you might use residual connections in this format."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "_sFvRDGrl4qe"
|
"id": "_sFvRDGrl4qe"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Define the network. This is a more typical way to define a network than the sequential structure. We define a class for the network, and define the parameters in the constructor. Then we use a function called forward to actually run the network. It's easy to see how you might use residual connections in this format."
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "EQkvw2KOPVl7"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from os import X_OK\n",
|
"from os import X_OK\n",
|
||||||
"# TODO Change this class to implement\n",
|
"# TODO Change this class to implement\n",
|
||||||
@@ -174,52 +166,54 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n"
|
"\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "EQkvw2KOPVl7"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "qWZtkCZcU_dg"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# He initialization of weights\n",
|
"# He initialization of weights\n",
|
||||||
"def weights_init(layer_in):\n",
|
"def weights_init(layer_in):\n",
|
||||||
" if isinstance(layer_in, nn.Linear):\n",
|
" if isinstance(layer_in, nn.Linear):\n",
|
||||||
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
" nn.init.kaiming_uniform_(layer_in.weight)\n",
|
||||||
" layer_in.bias.data.fill_(0.0)"
|
" layer_in.bias.data.fill_(0.0)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "qWZtkCZcU_dg"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "FslroPJJffrh"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Create network\n",
|
"# Create network\n",
|
||||||
"model = Net()\n",
|
"model = Net().to(device)\n",
|
||||||
"# Initialize model weights\n",
|
"# Initialize model weights\n",
|
||||||
"model.apply(weights_init)\n",
|
"model.apply(weights_init)\n",
|
||||||
"# Define optimizer\n",
|
"# Define optimizer\n",
|
||||||
"optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)"
|
"optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "FslroPJJffrh"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "xKQd9PzkQ766"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Main training routine\n",
|
"# Main training routine\n",
|
||||||
"def train(epoch):\n",
|
"def train(epoch):\n",
|
||||||
" model.train()\n",
|
" model.train()\n",
|
||||||
" # Get each\n",
|
" # Get each batch\n",
|
||||||
" for batch_idx, (data, target) in enumerate(train_loader):\n",
|
" for batch_idx, (data, target) in enumerate(train_loader):\n",
|
||||||
|
" data = data.to(device)\n",
|
||||||
|
" target = target.to(device)\n",
|
||||||
" optimizer.zero_grad()\n",
|
" optimizer.zero_grad()\n",
|
||||||
" output = model(data)\n",
|
" output = model(data)\n",
|
||||||
" loss = F.nll_loss(output, target)\n",
|
" loss = F.nll_loss(output, target)\n",
|
||||||
@@ -229,15 +223,15 @@
|
|||||||
" if batch_idx % 10 == 0:\n",
|
" if batch_idx % 10 == 0:\n",
|
||||||
" print('Train Epoch: {} [{}/{}]\\tLoss: {:.6f}'.format(\n",
|
" print('Train Epoch: {} [{}/{}]\\tLoss: {:.6f}'.format(\n",
|
||||||
" epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()))"
|
" epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()))"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "xKQd9PzkQ766"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "Byn-f7qWRLxX"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Run on test data\n",
|
"# Run on test data\n",
|
||||||
"def test():\n",
|
"def test():\n",
|
||||||
@@ -246,6 +240,8 @@
|
|||||||
" correct = 0\n",
|
" correct = 0\n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad():\n",
|
||||||
" for data, target in test_loader:\n",
|
" for data, target in test_loader:\n",
|
||||||
|
" data = data.to(device)\n",
|
||||||
|
" target = target.to(device)\n",
|
||||||
" output = model(data)\n",
|
" output = model(data)\n",
|
||||||
" test_loss += F.nll_loss(output, target, size_average=False).item()\n",
|
" test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
|
||||||
" pred = output.data.max(1, keepdim=True)[1]\n",
|
" pred = output.data.max(1, keepdim=True)[1]\n",
|
||||||
@@ -254,15 +250,15 @@
|
|||||||
" print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
|
" print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
|
||||||
" test_loss, correct, len(test_loader.dataset),\n",
|
" test_loss, correct, len(test_loader.dataset),\n",
|
||||||
" 100. * correct / len(test_loader.dataset)))"
|
" 100. * correct / len(test_loader.dataset)))"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "Byn-f7qWRLxX"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "YgLaex1pfhqz"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Get initial performance\n",
|
"# Get initial performance\n",
|
||||||
"test()\n",
|
"test()\n",
|
||||||
@@ -271,15 +267,15 @@
|
|||||||
"for epoch in range(1, n_epochs + 1):\n",
|
"for epoch in range(1, n_epochs + 1):\n",
|
||||||
" train(epoch)\n",
|
" train(epoch)\n",
|
||||||
" test()"
|
" test()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "YgLaex1pfhqz"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "o7fRUAy9Se1B"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Run network on data we got before and show predictions\n",
|
"# Run network on data we got before and show predictions\n",
|
||||||
"output = model(example_data)\n",
|
"output = model(example_data.to(device))\n",
|
||||||
@@ -294,12 +290,23 @@
|
|||||||
" plt.xticks([])\n",
|
" plt.xticks([])\n",
|
||||||
" plt.yticks([])\n",
|
" plt.yticks([])\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "o7fRUAy9Se1B"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"authorship_tag": "ABX9TyORZF8xy4X1yf4oRhRq8Rtm",
|
||||||
|
"include_colab_link": true,
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user