{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyM4FAP7pqe8LpKfmixZN07G", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Convolution III -- MNIST\n", "\n", "This notebook builds a complete convolutional network for 2D images. It uses the MNIST dataset of handwritten digits, the classic benchmark for image classification. The network takes a 28x28 grayscale image and classifies it into one of 10 classes, one for each digit from 0 to 9.\n" ], "metadata": { "id": "t9vk9Elugvmi" } }, { "cell_type": "code", "source": [ "import torch\n", "import torchvision\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import matplotlib.pyplot as plt\n", "import random" ], "metadata": { "id": "YrXWAH7sUWvU" }, "execution_count": 1, "outputs": [] }, { "cell_type": "code", "source": [ "# Run this once to download MNIST and load the train and test sets into\n", "# DataLoader objects that provide the batches\n", "batch_size_train = 64\n", "batch_size_test = 1000\n", "\n", "# Convert each image to a tensor in [0, 1], then standardize it with the\n", "# mean (0.1307) and standard deviation (0.3081) of the MNIST training pixels\n", "transform = torchvision.transforms.Compose([\n", "    torchvision.transforms.ToTensor(),\n", "    torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", "])\n", "\n", "train_loader = torch.utils.data.DataLoader(\n", "    torchvision.datasets.MNIST('/files/', train=True, download=True, transform=transform),\n", "    batch_size=batch_size_train, shuffle=True)\n", "\n", "test_loader = torch.utils.data.DataLoader(\n", "    torchvision.datasets.MNIST('/files/', train=False, download=True, transform=transform),\n", "    batch_size=batch_size_test, shuffle=True)" ], "metadata": { "id": 
"wScBGXXFVadm" }, "execution_count": 2, "outputs": [] }, { "cell_type": "code", "source": [ "# Draw six examples from one batch of the test data\n", "# (this batch is reused at the end to show the network's predictions)\n", "example_data, example_targets = next(iter(test_loader))\n", "\n", "fig = plt.figure()\n", "for i in range(6):\n", "    plt.subplot(2, 3, i + 1)\n", "    plt.tight_layout()\n", "    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n", "    plt.title(\"Ground Truth: {}\".format(example_targets[i].item()))\n", "    plt.xticks([])\n", "    plt.yticks([])\n", "plt.show()" ], "metadata": { "id": "8bKADvLHbiV5", "colab": { "base_uri": "https://localhost:8080/", "height": 284 }, "outputId": "ec39acb8-05e0-48ae-cb6e-1b8578d00e68" }, "execution_count": 3, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZRUxdnH8V+JC4uKKCJuoIiCG2KEmBcEN4youCHqy/tGNBESlSOeuAS3aHBFUBHj9qIoERUxrIIaFRVPcEFwFyTqQQTBsEqigFuo949urlUl3dNL9cyd4fs5Z47PY92+t2am6Gfurdt1jbVWAACUa7Oa7gAAoG6goAAAoqCgAACioKAAAKKgoAAAoqCgAACiqNMFxRizhzHGGmM2r4FjLzDGdKvu4yIOxg5KtSmPnbILijHmv40xM40xa4wxy7LxBcYYE6ODlWKM+dr5Wm+MWefk/1vkvkYZY26I2LcjjTHvG2NWG2NWGmMmGmN2jbX/tGDsxB872X3uaIx5zBjzL2PMl8aYR2PuPw0YOxV53zHGmKuMMQuNMf82xjxujNm2mH2UVVCMMZdIGi5pqKTmknaSdJ6kzpK2zPGaeuUcMxZr7dYbviQtlHSi8/+Sf4A18VeGpLmSjrXWbidpF0kfS7q3BvpRMYydipog6Z+SWkhqJunWGupHRTB2KqaPpLOU+TnuIqmBpD8XtQdrbUlfkhpLWiPptCq2G6XMm+HT2e27SdpX0nRJqyXNkXSSs/10SX2d/BxJM5zcKjN4Ps6+/m5JJttWT5l/PCskzZfUP7v95lX0cYGkbtn4CEmfSxqozD/K0WEfnH60lvRbSd9L+k7S15KmOPu8VNJ7kv4laayk+iX8nLeSdLOkuaX+rtL2xdip3NiR9Mvs6+vV9O+ZsVPrxs44SZc5eSdJ30hqWOjvp5wzlP9S5s1ucgHb/o+kGyVtI2mmpCmSnlPmr6cLJT1qjGlTxLF7SOooqZ2kMyQdm/3//bJtB0vqIKlXEft0NZe0vaSWyvzicrLWjpD0qKQhNvNXxolO8xmSukvaM9vXczY0ZC9nHZZrv8aYFsaY1ZLWKTNAhpT2raQSY0cVGzu/kPQPSX/JXi6dZYw5vMTvJY0YO6rc+44kE8RbSdq70G+gnILSVNIKa+0PydGNeTXb4XXGmK7OtpOtta9Ya9dLai9pa0mDrbXfWWtflDRVUu8ijj3YWrvaWrtQ0kvZfUqZH+Qd1tpF1tpVyvxlX4r1kq611n5rrV1X4j4k6U5r7ZJsX6Y4/ZS1djtr7YxcL7TWLrSZS15NJV0taV4Z/Ugbxk7VSh07uylzlvKSMm9Qt0mabIxpWkZf0oSxU7VSx87fJPXN3lTQWJmzJUlqWOiByykoKyU1da/1WWs7Zd8EVwb7XuTEu0halP0lb/CZpGImnf/pxGuVGSjJvoP9lmK5tfabEl/rytXPgmUHxV+UeVOoqWvysTF2qlbq2FknaYG1dqS19ntr7ePKfF+dI/QpDRg7VSt17DwoaYwyl//mKFM0pcyluIKUU1Bek/StpJML2NZd0niJpN2NMe6xW0hanI3XyK+IzYvo0xeSdg/2W4pwCWavT8aYsE+VXrJ5c2VO04u64yLFGDu5ty/XexvZZ11aUpyxk3v7slhr11trr7XW7mGt3U2ZorJYP/6MqlRyQbHWrpY0SNI9xphexphtjDGbGWPaS2qU56UzlamafzDGbGGMOULSiZIez7a/I6mnMaahMaa1pHOL6NYTkgYYY3YzxjSRdHmR31Yu70ra3xjT3hhTX9KfgvalklpFOpaMMT2NMW2yP88dJd0u6e3s2Uqtx9jxRB07kiZKamKMOdsYU88Y00uZy2CvRDxGjWHseGK/72xvjNkre/vwfsq871wXnNXlVdZtw9baIZIulvQHZb65pZL+T5lrb6/meM13yvwij1Pmroh7JPWx1m6YIximzJ0LS
5W51FPMPfT3S3pWmV/EW8rcPlk2a+1Hkq6TNE2ZuzzCa5AjJe2XvY47qZB9Zu8775KjeVdlrmd+Jel9Za6tnlpK39OKsZOIOnayf3ScpMyNHP9S5s3tZGvtihK/hdRh7CRiv+801Y93xT0j6cHs5H/BNtz2BgBAWer00isAgOpDQQEAREFBAQBEQUEBAERBQQEARFHUJ6+NMdwSlkLW2rQv2c24SacV1toda7oT+TB2UmujY4czFGDTVeoSIcBGxw4FBQAQBQUFABAFBQUAEAUFBQAQBQUFABAFBQUAEAUFBQAQBQUFABBFXXlGOVAtDj/8cC8fNmxYEr/wwgte22WXXVYtfQLSgjMUAEAUFBQAQBQUFABAFJvMHMruu+/u5VdeeaWXn3feeUl81113eW0XXnhh5TqGWqVHjx5e3qZNmyR+/vnnq7s7QKpwhgIAiIKCAgCIok5d8jr77LO9/Kqrrkrili1bem2bb+5/6+vXr0/i888/32vbbDO/7vbv37+sfqL2CMdNnz59vHz48OFJHF5GBTY1nKEAAKKgoAAAoqCgAACiqHVzKO48ydVXX+217bHHHl4+bty4JL733nu9tvvvv9/LDzjggCR+5ZVXvLZjjz22pL6i9nNvJ5ekhg0bejnzJnC54+PMM8/02i6//HIvX7duXRJPnDjRa7vpppu8/Pvvv4/VxYriDAUAEAUFBQAQReouebVt29bLjzzySC+/5pprknjt2rVe24gRI7z8tttuS+L58+fnPW64L1d462jv3r2TeMyYMXn3i9rn5z//eRKHt6IvWrSouruDFNtqq628/MEHH0zi008/3Wszxni5tTaJDzzwQK9thx128PIBAwaU1c/qwhkKACAKCgoAIAoKCgAgilTMobz//vtJvOuuu3ptjRs39vJPP/00ie+44w6vLVwluBhffvllEi9YsMBrC29Hrl+/fsnHQfo0aNDAyydPnpzEO+64o9c2ffr06ugSUqp169ZefuONN3p5r169cr525MiRXv6b3/wm57Z77713Cb2reZyhAACioKAAAKKgoAAAoqi2OZT9998/icePH++17bXXXkkcLhXvzplI/hPz5s2bF61/X3/9dRIvX77cawvnUFC3hMuruPMmc+fO9douvfTSko+z5ZZbJnGLFi28Nnfpn0mTJpV8DMT3y1/+MonD965wKR7XtGnTvHzUqFFenm8O5ZZbbimih+nBGQoAIAoKCgAgimq75OXebrfnnnt6be5lrvCU8qKLLvLyL774ogK9k7p3757EHTt2rMgxkA7bbbedl5988sk5tw3H35IlSwo+zp/+9CcvP+mkk5K4Xbt2Xpu7hA+XvGqWe1ld8lcmDy9xhbcCu5e1PvjgA68tXF4ln3C5p9qCMxQAQBQUFABAFBQUAEAU1TaH4i5n0b9/f6/NvU1z9OjRXlul5kzCJRLCJzqi7jruuOO8/LDDDvPyr776KolXrlyZd1+NGjVK4mHDhnlt4RP7Fi5cmMQzZszw2o4++ugk/vWvf+21PfTQQ3n7gLjCuY7mzZsn8cUXX+y1hb/zfD755BMvX7p06UaPIUnNmjUreL9pwhkKACAKCgoAIAoKCgAgihpZvv6BBx6oluO4j9EM50wGDx7s5dtss03B+w2XsEbtctBBB+Vtd+f73nvvvbzbustyhEtpzJkzx8vdeZJwDLlzKq1atcp7TFTW8ccf7+XuY58feeSRkvcbzt01adIkid3HAddmnKEAAKKgoAAAokjFExtLFT5Nr2fPnl7u3p7srnZcLvfWwVmzZnltLJuRTu6K0X369PHa1q5d6+XnnHNOEru3BUvSY4895uXucirhfsKn+a1YsSKJw0texpgcPUdNc28jX7duXcGvC5dPCZficVefris4QwEAREFBAQBEQUEBAERRq+dQOnXq5OX33HNPyftyn8wXLqGf74l+4ZIJSKdzzz03icNlLT788EMv33rrrZM4XPbkhBNO8HJ33qRv375e2
xNPPJGzP+F+3NtGp06dmvN1qLzVq1d7eefOnZN4v/3289reeOMNL3fnxp555hmvLXxKZ125VdjFGQoAIAoKCgAgCgoKACCKWj2HEl7rnD17tpe3bds2iV977TWvLVx2+u23307iZcuWeW2vvvqql7vXRs866yyv7b777quq26gGDRo08PJu3brl3Dacs3j44YeT2P2cifTT5VROP/30JP7oo4/y9qlr165JfMEFF3ht8+fPT2J3Pg/Vb8KECV7uzneFS6/cfPPNXn7llVcmcbiEjrtcvSRttdVWSdy4cePSOpsynKEAAKKgoAAAoqjVl7xefvllLz/00EO93H0S37x587w2dxmMqoRP13vhhReS+JBDDvHamjZtWtIxENe2227r5R07dsy57ZFHHunlHTp0SOK33nrLawtXoi3md3zMMcckccOGDb22K664IondpT5Q/fI9IXPkyJF5c1e4UvUZZ5zh5aecckoSh6ufh08RHTp0aM7jpAlnKACAKCgoAIAoKCgAgChq9RxKVcK5j1KFy5IvXLgwid2n8EnSgAEDkviaa66JcnwUr3v37gVv686ZhAYOHOjlxcyZhEu8uLefPv30017biBEjCt4vqtfjjz+exOGS9OESOu5cSPg7zrf0fbgMS219KixnKACAKCgoAIAoKCgAgCjq9BxKpbjLkvfq1ctrc6+5z5w502t76qmnKtsxJGJdg27fvr2Xv/TSSwW/Nlymo127dkk8ZMiQ8jqGauPOfYSP8Q3zYrhzsaFtttnGy7fffvskXrVqVcnHrDTOUAAAUVBQAABRcMmrBO4yLkuWLPHa2rRpk8Th09245FV93nzzzYK3/d3vfufl7rItd999d97X7rLLLkl89dVXe22HH364l7urVE+aNKng/qFumjx5chKHK0yH7x3uqubDhw+vbMfKwBkKACAKCgoAIAoKCgAgCuZQNmLzzf0fyx577OHlTz75ZBK7cyaSv4RCvqUWUFnPPvusl8+aNSuJw6Xs+/Xr5+V9+/ZN4n322cdrO/jgg7181KhRSRwun/Hll196+aBBg5L4m2++ydV1bCLc94c777zTawuf/NqzZ88kZg4FAFDnUVAAAFGk4pKX+6nQH374wWsL8wYNGpR0jBYtWnj5ueeem3PbJk2aeLl7y14ovMzhfjr6rrvuKqaLiCi83Dht2rQkDi95hasNv/POOwUfx12J2j2GJN16661ePnv27IL3i03LsmXLvNwY4+VdunSpzu6UjDMUAEAUFBQAQBQUFABAFKmYQ5k/f34SL1682GtbtGiRlx9//PHV0qd83NtBw9VGmTdJpxtuuCGJwyd59ujRw8vPP//8nPtxbxOW/OUzpkyZUkYPsSn7+OOPvXzNmjVe3rBhwySuX7++15amW9A5QwEAREFBAQBEQUEBAERhws9R5N3YmMI3LsIll1ySxJ06dfLawieXHX300ZXogmfs2LFePmfOHC+fOnVqEr/77rsV709VrLWm6q1qTqXGDcr2prW2Q9Wb1ZxNdew89thjXn7mmWcmsbs0kCQ99NBD1dKnwEbHDmcoAIAoKCgAgChScdvwbbfdttEYADZFCxcurOkulIQzFABAFBQUAEAUFBQAQBSpmEMBAPzok08+ydkWLtOSJpyhAACioKAAAKKgoAAAokjF0isoD0uvoEQsvYJSsfQKAKByKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKIpdemWFpM8q0RGUrGVNd6AAjJt0YuygVBsdO0V9DgUAgFy45AUAiIKCAgCIgoICAIiCggIAiIKCAgCIgoICAIiCggIAiIKCAgCIgoICAIiCggIAiIKCAgCIgoICAIiCggIAiKJOFxRjzB7GGGuMKXaZ/hjHXmCM6Vbdx0UcjB2UalMeO2UXFGPMfxtjZhpj1hhjlmXjC4wxJkYHK8UY87Xztd4Ys87J/7fIfY0yxtwQsW9HZPvk9vHsWPtPC8ZORcaOMcZcZYxZaIz5tzHmc
WPMtrH2nxaMnfhjJ7vPC40xn2bHzmxjzGHFvL6sgmKMuUTScElDJTWXtJOk8yR1lrRljtfUK+eYsVhrt97wJWmhpBOd//fohu1q4q+MrCVuH621f6mhflQEY6di+kg6S5mf4y6SGkj6cw30o2IYO5VhjDlU0mBJvSQ1ljRS0sSifnbW2pK+sgdcI+m0KrYbJeleSU9nt+8maV9J0yWtljRH0knO9tMl9XXycyTNcHKrzOD5OPv6u/Xjg8LqSbpVmae8zZfUP7v95lX0cYGkbtn4CEmfSxoo6Z+SRod9cPrRWtJvJX0v6TtJX0ua4uzzUknvSfqXpLGS6hf4sz1C0uel/m7S/sXYqejYGSfpMifvJOkbSQ1r+vfO2En92DlT0htO3ih7vJ0L/f2Uc4byX5K2kjS5gG3/R9KNkraRNFPSFEnPSWom6UJJjxpj2hRx7B6SOkpqJ+kMScdm/3+/bNvBkjooU2lL0VzS9so85vK3+Ta01o6Q9KikITbzV8aJTvMZkrpL2jPb13M2NBhjVldxOtnMGLM0e/o5zBjTqLRvJZUYO6ro2DFBvJWkvYv4HtKMsaOKjZ1nJNUzxhyaPSv5jaR3lClwBSmnoDSVtMJa+4PT2VezHV5njOnqbDvZWvuKtXa9pPaStpY02Fr7nbX2RUlTJfUu4tiDrbWrrbULJb2U3aeU+UHeYa1dZK1dJenmEr+39ZKutdZ+a61dV+I+JOlOa+2SbF+mOP2UtXY7a+2MHK+bl912Z0lHSTpE0u1l9CNtGDtVK3Xs/E1S3+zEcGNl/uKVpIZl9CVNGDtVK3XsfCVpvKQZkr6VdK2k39rs6UohyikoKyU1da/1WWs7WWu3y7a5+17kxLtIWpT9JW/wmaRdizi2WzHXKjNQkn0H+y3FcmvtNyW+1pWrn3lZa/9prZ1rrV1vrf1U0h8knRahP2nB2KlaSWNH0oOSxihzCWeOMm98UuZySl3A2KlaqWPnXEm/lrS/MnNRv5I01RizS6EHLqegvKZMFTu5gG3dCrdE0u7GGPfYLSQtzsZr5P811byIPn0hafdgv6UIK7LXJ2NM2KeCK3gZ/alLt3gzdnJvX5bsHyHXWmv3sNbupkxRWawff0a1HWMn9/blai9pqrX2o+w4+psy31unQndQ8puUtXa1pEGS7jHG9DLGbGOM2cwY016ZyZxcZipTNf9gjNnCGHOEpBMlPZ5tf0dST2NMQ2NMa2WqZqGekDTAGLObMaaJpMuL/LZyeVfS/saY9saY+pL+FLQvldQq0rFkjDnSGNMyewvo7srceVHINeNagbHjiT12tjfG7JUdO/spc6n0uuAv81qLseOJOnYkzZJ0gjGmVXb8HCNpH0kfFLqDsv7qtdYOkXSxMpdklma//k+Z67av5njNd8r8Io9T5q6IeyT1sdbOy24yTJk7F5ZK+osyE0+Ful/Ss8r8It6SNKG472jjrLUfSbpO0jRl7vIIr0GOlLRf9jrupEL2mb3vvEuO5oOV+fmtyf73fUkDSul7WjF2ErHHTlP9eGfTM5IezE7g1hmMnUTssfOwMgV2uqR/S7pT0u+cn1HV+y9ivgUAgJzq0nV5AEANoqAAAKKgoAAAoqCgAACioKAAAKIoakVLYwy3hKWQtTbtS3YzbtJphbV2x5ruRD6MndTa6NjhDAXYdJW6RAiw0bFDQQEAREFBAQBEQUEBAERBQQEAREFBAQBEQUEBAERBQQEAREFBAQBEQUEBAERBQQEAREFBAQBEQUEBAERR1GrDALApa9asmZd/++23Xt62bdskPu200/Luy5gfFwlv2rSp19anT5+cr+vdu7eXP/HEE3mPU504QwEAREFBAQBEYawt/Pk1aXjYzU477ZTEnTp18tpOOOGEgvezzz77ePmyZcuSeNiwYV7bhx9+6OWrVq0q+DjVgQdsoURvWms71HQn8qmJsRO+N/z+979P4vBS1JIlS7x8zz33LPg47iWvYt6Hw/ejAw88sODXR
rTRscMZCgAgCgoKACAKCgoAIIrU3za8ww47ePmzzz6bxPvtt5/XFt7Ct/nmP357a9eu9dqaNGni5e71zJ49e3ptixcv9vJBgwYl8QMPPJCz76gddtxxRy+//vrrk3jffff12rp06eLl7rXvGTNmeG0TJ0708jvuuKOsfqIywtt7R44c6eVbb711ErvvE1JxcyaxzJkzp9qPWSjOUAAAUVBQAABRpP6S1yGHHOLl7dq1S+LrrrvOa3vkkUe8fPvtt0/iuXPnem1HHXVUzmOGt+GFn46dPn167g4j9U499VQvv/322728RYsWSRzezpkvP+yww7y2zp07e3mjRo2S+MYbbyyix6ikfv36ebl7iSumcePGefn++++fxOGl1Xw+++yzaH2KjTMUAEAUFBQAQBQUFABAFKmfQzn66KNzti1atMjLP/nkk4L3++STT5bUhtqpa9euSTx+/HivLZwXcW8xv+mmm7y2lStX5jzGscce6+WnnHKKl1900UVJPHr0aK9t4cKFOfeL9ArnZrt165Zz2xUrVnj54MGDk7iYOZQ04wwFABAFBQUAEAUFBQAQRernUHr06JGzberUqdXYE9RmV1xxRRJX9dmSjh07JvG8efMKPoa7LJD008+huEu8hE/oYw6l5qxevdrL//Of/3i5u4TTZpv5f4OHn1E7/fTTk3jEiBFe2/r1673c/VxSuKRLPuHnptKEMxQAQBQUFABAFKl/YmN4q527nEp4+lmO1q1bJ/ExxxzjtYXLv7geeughL3/llVei9alQPLHxp7p37+7lTz31lNsfry1cBuWPf/xjScd0L2FIUtu2bb38zTffLGm/FcQTGzciHA8DBw50++O15Xv/dMecJE2bNs3L3dWnq3ofnjRpUhL/6le/8tq++eabvK+tEJ7YCACoHAoKACAKCgoAIIrUz6EsX77cy90nOBYzh9KqVSsvD+c+DjrooCSuV6+e17Zu3Tovd9vDPrjLb7zxxhsF968czKH81H333eflffv2TeJ//OMfXpt7m7D006d7Fiq8tj1q1CgvP+CAA5K4mNuRK4g5lI3Yeeedvfycc85J4htuuMFrK+b9M+TOx1S1n/bt2yfxBx98UPIxI2IOBQBQORQUAEAUFBQAQBSpX3olvHf7zDPPTOLw8yKffvqplx933HFJPHToUK/tiy++8PKxY8cm8fDhw722cIlq9xGhf/3rX7222267LYm7dOkipIN7vTqcIyl1zkTyl1O58sorcx5T8pesD+dtkB7he8PNN9+cxC+99JLX5i7pI0knnHBC5TpWC3CGAgCIgoICAIgi9Ze8Bg0a5OUnnXRSEk+YMCHva93VPadPn+619erVy8u//vrrgvvkbvvAAw94bWPGjEni3r1752xD9Srn9s583Esebdq0yXtMd4XhcLXhcIkhpNPrr7/u5eH7yKmnnprEjz32WLTj3nXXXUl87bXXem0vv/xytOOUizMUAEAUFBQAQBQUFABAFKlfeiXkXrPu1KmT1xYukeLObzz33HOV7ViWO28zZ84cr+3AAw+syDFZeuWn3Nt5JX8ZnJYtW3ptH374oZf//e9/z7nfrl27erk7b1LV0uaLFi1K4g4d/FUramgOhaVXytS4cWMvv+qqq5L4kksuyftad9mm8GmOxbj++uuT+JZbbvHawvfEiFh6BQBQORQUAEAUFBQAQBS1bg4l7ZYuXZrE4dL24XX9WJhDqZr7+YBwCfLw8yP5lhXPN09S1RyK+2hhdzmPGsQcSpG22GILLx8/fryXH3/88QXvy13uafbs2V5b+FkTd7mnkDvu3EcFS/7S+5L01VdfFdy/KjCHAgCoHAoKACCK1C+9Utu4lzLC1WdRcyZOnJjEM2bM8Nrcy2GS1K9fv4L327Zt2yRu1KhR3m1XrlxZ8H6RTuHHD4pZUfyzzz7zcvcW4wULFnht4XIqQ4YMSeIjjjgi5zFOPvlkL//FL37h5c8//3whXS0ZZygAgCgoKACAK
CgoAIAouG04ss6dOyexe91ekpo1a1aRY3LbcM2ZNWtWEv/sZz/z2sJ/W82bN0/ilCxXz23DG9GqVSsvHzBgQBJfeOGFXlsx758HHXSQl4dLM+XTpEmTJA7fV9x5nLA/4RL6ffr0KfiYVeC2YQBA5VBQAABRUFAAAFHwORSgCOFnTRo2bJjE4dIroZTMm6AKo0eP9vJDDz20pP2MGzfOy8PHJBTjyy+/TOKpU6d6bfk+C1Op5Z5y4QwFABAFBQUAEAWXvIAiuEutSP5KxeEtm8XcUoqa4z7xUPrpciWucAXx8EmL77zzThKHtxjneypjuIpxeAv6aaedlsQdO3bM2afwGLfffnvOY1YCZygAgCgoKACAKCgoAIAoUj+H8uqrr3q5uwT4Lbfc4rWFy5LXBPcWvqpuI0Xt07VrVy93f8fh75vbhGuHcE4i39xXOEcRbuu+P9166615j+uOl+22285rK+bJj26f1q5d67VFfEJjQThDAQBEQUEBAESR+kteDz/8sJcPHz48icOnnL399ttevmbNmor1a4P69et7+eWXX57Es2fPrvjxUb3c24Sl/JdHJkyYUOnuIGWOOuqogrd1L3mVc4u5e6ty//79vbbXX3+95P2WgjMUAEAUFBQAQBQUFABAFKmfQ7nvvvu8vGXLlkk8cOBAr23nnXf28qFDhybxzJkzK9A7f85EkrbddtskDm9rRt3jXgcPl+W4//77q7s7KMHcuXO9vFu3bjXUk9zcFYanTZvmtY0dOzaJly9fXm192hjOUAAAUVBQAABRUFAAAFGkfg4ldMUVVyRx+DmUQYMGeXmPHj2S+PPPP/faxo8fn/MY77//vpeHyyL07NkziTt37uy1uU9pe/HFF3MeA3WD+/mBfMuTI73c9xRJmjJlipdPnjw5icMndhbj3nvv9XJ3mZRnnnnGawvndVatWpXEP/zwQ8l9qDTOUAAAUVBQAABRmGI+8m+MSfUj6HbddVcvd1eGdS9TSf4T0Ir16KOPJnG4vMbEiRNL3m+prLWpXtY47eOmGOGl0lNOOSWJw6V/OnToUC19KsOb1tpUd7IujZ06ZqNjhzMUAEAUFBQAQBQUFABAFLXutuF8Fi9e7OVjxozZaAyUKpxzdPPwVoE0o6YAAACWSURBVE9gU8MZCgAgCgoKACAKCgoAIIo69TmUTRWfQ0GJ+BwKSsXnUAAAlUNBAQBEQUEBAERBQQEAREFBAQBEQUEBAERR7NIrKyR9VomOoGQta7oDBWDcpBNjB6Xa6Ngp6nMoAADkwiUvAEAUFBQAQBQUFABAFBQUAEAUFBQAQBQUFABAFBQUAEAUFBQAQBQUFABAFP8PSTV0oIqXHNoAAAAASUVORK5CYII=\n" }, "metadata": {} } ] }, { "cell_type": "markdown", "source": [ "Define the network. This is a more typical way to define a network than the sequential structure. We define a class for the network, and define the parameters in the constructor. Then we use a function called forward to actually run the network. " ], "metadata": { "id": "_sFvRDGrl4qe" } }, { "cell_type": "code", "source": [ "from os import X_OK\n", "# TODO Change this class to implement\n", "# 1. A valid convolution with kernel size 5, 1 input channel and 10 output channels\n", "# 2. A max pooling operation over a 2x2 area\n", "# 3. A Relu\n", "# 4. 
A valid convolution with kernel size 5, 10 input channels and 20 output channels\n", "# 5. A 2D dropout layer\n", "# 6. A max pooling operation over a 2x2 area\n", "# 7. A ReLU\n", "# 8. A flattening operation\n", "# 9. A fully connected layer mapping from (whatever dimensionality we are at; find out using .shape) to 50\n", "# 10. A ReLU\n", "# 11. A fully connected layer mapping from 50 to 10 dimensions\n", "# 12. A log-softmax function (paired with F.nll_loss in the training loop)\n", "\n", "# Replace this class, which implements a minimal network (it still does reasonably well)\n", "class Net(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Valid convolution: 1 channel in, 3 channels out, stride 1, kernel size 5\n", "        # (a 28x28 input gives a 3x24x24 output)\n", "        self.conv1 = nn.Conv2d(1, 3, kernel_size=5)\n", "        # Dropout that zeroes entire channels of the convolution output\n", "        self.drop = nn.Dropout2d()\n", "        # Fully connected layer: after 2x2 max pooling the features are 3x12x12 = 432\n", "        self.fc1 = nn.Linear(432, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = self.drop(x)\n", "        x = F.max_pool2d(x, 2)\n", "        x = F.relu(x)\n", "        x = x.flatten(1)\n", "        x = self.fc1(x)\n", "        # Log-probabilities over the 10 classes (dim=1 is the class dimension)\n", "        x = F.log_softmax(x, dim=1)\n", "        return x" ], "metadata": { "id": "EQkvw2KOPVl7" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "# He (Kaiming) initialization of the fully connected layer weights;\n", "# convolutional layers keep PyTorch's default initialization\n", "def weights_init(layer_in):\n", "    if isinstance(layer_in, nn.Linear):\n", "        nn.init.kaiming_uniform_(layer_in.weight)\n", "        layer_in.bias.data.fill_(0.0)" ], "metadata": { "id": "qWZtkCZcU_dg" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "# Create network\n", "model = Net()\n", "# Initialize model weights\n", "model.apply(weights_init)\n", "# Define optimizer\n", "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n" ], "metadata": { "id": "FslroPJJffrh" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "def train(epoch):\n", "    model.train()\n", "    # Loop over the mini-batches of training data\n", "    for batch_idx, (data, target) in enumerate(train_loader):\n", "        optimizer.zero_grad()\n", "        output = model(data)\n", "        loss = F.nll_loss(output, target)\n", "        loss.backward()\n", "        optimizer.step()\n", "        # Print progress every 10 batches\n", "        if batch_idx % 10 == 0:\n", "            print('Train Epoch: {} [{}/{}]\\tLoss: {:.6f}'.format(\n", "                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()))" ], "metadata": { "id": "xKQd9PzkQ766" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "def test():\n", "    model.eval()\n", "    test_loss = 0\n", "    correct = 0\n", "    with torch.no_grad():\n", "        for data, target in test_loader:\n", "            output = model(data)\n", "            # Sum the loss over the batch; it is averaged over the whole test set below\n", "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n", "            pred = output.argmax(dim=1, keepdim=True)\n", "            correct += pred.eq(target.view_as(pred)).sum().item()\n", "    test_loss /= len(test_loader.dataset)\n", "    print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", "        test_loss, correct, len(test_loader.dataset),\n", "        100. * correct / len(test_loader.dataset)))" ], "metadata": { "id": "Byn-f7qWRLxX" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "# Get initial performance\n", "test()\n", "# Train for three epochs, testing after each one\n", "n_epochs = 3\n", "for epoch in range(1, n_epochs + 1):\n", "    train(epoch)\n", "    test()" ], "metadata": { "id": "YgLaex1pfhqz", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "96b4a0e4-94b3-4eba-ea1a-53f6f5ee8792" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", " x = F.log_softmax(x)\n", "/usr/local/lib/python3.8/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", " warnings.warn(warning.format(ret))\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "Test set: Avg. 
loss: 2.6743, Accuracy: 593/10000 (6%)\n", "\n", "Train Epoch: 1 [0/60000]\tLoss: 3.022157\n", "Train Epoch: 1 [640/60000]\tLoss: 2.232255\n", "Train Epoch: 1 [1280/60000]\tLoss: 2.339070\n", "Train Epoch: 1 [1920/60000]\tLoss: 2.123211\n", "Train Epoch: 1 [2560/60000]\tLoss: 1.977328\n", "Train Epoch: 1 [3200/60000]\tLoss: 1.794155\n", "Train Epoch: 1 [3840/60000]\tLoss: 1.607425\n", "Train Epoch: 1 [4480/60000]\tLoss: 1.265900\n", "Train Epoch: 1 [5120/60000]\tLoss: 1.380399\n", "Train Epoch: 1 [5760/60000]\tLoss: 1.443841\n", "Train Epoch: 1 [6400/60000]\tLoss: 0.962643\n", "Train Epoch: 1 [7040/60000]\tLoss: 1.050701\n", "Train Epoch: 1 [7680/60000]\tLoss: 0.882971\n", "Train Epoch: 1 [8320/60000]\tLoss: 1.056942\n", "Train Epoch: 1 [8960/60000]\tLoss: 1.204123\n", "Train Epoch: 1 [9600/60000]\tLoss: 0.875873\n", "Train Epoch: 1 [10240/60000]\tLoss: 0.952260\n", "Train Epoch: 1 [10880/60000]\tLoss: 0.751124\n", "Train Epoch: 1 [11520/60000]\tLoss: 0.867441\n", "Train Epoch: 1 [12160/60000]\tLoss: 0.956779\n", "Train Epoch: 1 [12800/60000]\tLoss: 0.679759\n", "Train Epoch: 1 [13440/60000]\tLoss: 0.949922\n", "Train Epoch: 1 [14080/60000]\tLoss: 0.859521\n", "Train Epoch: 1 [14720/60000]\tLoss: 0.948663\n", "Train Epoch: 1 [15360/60000]\tLoss: 0.584370\n", "Train Epoch: 1 [16000/60000]\tLoss: 0.865371\n", "Train Epoch: 1 [16640/60000]\tLoss: 0.726936\n", "Train Epoch: 1 [17280/60000]\tLoss: 0.598072\n", "Train Epoch: 1 [17920/60000]\tLoss: 0.849646\n", "Train Epoch: 1 [18560/60000]\tLoss: 0.793431\n", "Train Epoch: 1 [19200/60000]\tLoss: 0.720634\n", "Train Epoch: 1 [19840/60000]\tLoss: 0.770133\n", "Train Epoch: 1 [20480/60000]\tLoss: 0.888984\n", "Train Epoch: 1 [21120/60000]\tLoss: 0.595875\n", "Train Epoch: 1 [21760/60000]\tLoss: 0.750772\n", "Train Epoch: 1 [22400/60000]\tLoss: 0.829312\n", "Train Epoch: 1 [23040/60000]\tLoss: 0.896843\n", "Train Epoch: 1 [23680/60000]\tLoss: 0.762636\n", "Train Epoch: 1 [24320/60000]\tLoss: 0.806846\n", "Train Epoch: 1 
[24960/60000]\tLoss: 0.767578\n", "Train Epoch: 1 [25600/60000]\tLoss: 0.746623\n", "Train Epoch: 1 [26240/60000]\tLoss: 0.742440\n", "Train Epoch: 1 [26880/60000]\tLoss: 1.142117\n", "Train Epoch: 1 [27520/60000]\tLoss: 0.772795\n", "Train Epoch: 1 [28160/60000]\tLoss: 0.555436\n", "Train Epoch: 1 [28800/60000]\tLoss: 0.763470\n", "Train Epoch: 1 [29440/60000]\tLoss: 0.632003\n", "Train Epoch: 1 [30080/60000]\tLoss: 0.836483\n", "Train Epoch: 1 [30720/60000]\tLoss: 0.704554\n", "Train Epoch: 1 [31360/60000]\tLoss: 0.862588\n", "Train Epoch: 1 [32000/60000]\tLoss: 0.582613\n", "Train Epoch: 1 [32640/60000]\tLoss: 0.784028\n", "Train Epoch: 1 [33280/60000]\tLoss: 0.758522\n", "Train Epoch: 1 [33920/60000]\tLoss: 0.778791\n", "Train Epoch: 1 [34560/60000]\tLoss: 0.849997\n", "Train Epoch: 1 [35200/60000]\tLoss: 0.836082\n", "Train Epoch: 1 [35840/60000]\tLoss: 0.448398\n", "Train Epoch: 1 [36480/60000]\tLoss: 0.729314\n", "Train Epoch: 1 [37120/60000]\tLoss: 0.811088\n", "Train Epoch: 1 [37760/60000]\tLoss: 0.592963\n", "Train Epoch: 1 [38400/60000]\tLoss: 0.642293\n", "Train Epoch: 1 [39040/60000]\tLoss: 0.784302\n", "Train Epoch: 1 [39680/60000]\tLoss: 0.694944\n", "Train Epoch: 1 [40320/60000]\tLoss: 0.720275\n", "Train Epoch: 1 [40960/60000]\tLoss: 0.536233\n", "Train Epoch: 1 [41600/60000]\tLoss: 0.715839\n", "Train Epoch: 1 [42240/60000]\tLoss: 0.557930\n", "Train Epoch: 1 [42880/60000]\tLoss: 0.652230\n", "Train Epoch: 1 [43520/60000]\tLoss: 0.686960\n", "Train Epoch: 1 [44160/60000]\tLoss: 0.562727\n", "Train Epoch: 1 [44800/60000]\tLoss: 0.728505\n", "Train Epoch: 1 [45440/60000]\tLoss: 0.874868\n", "Train Epoch: 1 [46080/60000]\tLoss: 0.713117\n", "Train Epoch: 1 [46720/60000]\tLoss: 0.727794\n", "Train Epoch: 1 [47360/60000]\tLoss: 0.747727\n", "Train Epoch: 1 [48000/60000]\tLoss: 0.631520\n", "Train Epoch: 1 [48640/60000]\tLoss: 0.515534\n", "Train Epoch: 1 [49280/60000]\tLoss: 0.695285\n", "Train Epoch: 1 [49920/60000]\tLoss: 0.690564\n", "Train Epoch: 1 
[50560/60000]\tLoss: 0.696663\n", "Train Epoch: 1 [51200/60000]\tLoss: 0.637634\n", "Train Epoch: 1 [51840/60000]\tLoss: 0.722715\n", "Train Epoch: 1 [52480/60000]\tLoss: 0.832013\n", "Train Epoch: 1 [53120/60000]\tLoss: 0.594781\n", "Train Epoch: 1 [53760/60000]\tLoss: 0.613957\n", "Train Epoch: 1 [54400/60000]\tLoss: 0.836092\n", "Train Epoch: 1 [55040/60000]\tLoss: 0.635827\n", "Train Epoch: 1 [55680/60000]\tLoss: 0.623362\n", "Train Epoch: 1 [56320/60000]\tLoss: 0.540473\n", "Train Epoch: 1 [56960/60000]\tLoss: 0.780923\n", "Train Epoch: 1 [57600/60000]\tLoss: 0.476055\n", "Train Epoch: 1 [58240/60000]\tLoss: 0.905469\n", "Train Epoch: 1 [58880/60000]\tLoss: 0.700290\n", "Train Epoch: 1 [59520/60000]\tLoss: 0.500782\n", "\n", "Test set: Avg. loss: 0.2170, Accuracy: 9355/10000 (94%)\n", "\n", "Train Epoch: 2 [0/60000]\tLoss: 0.530213\n", "Train Epoch: 2 [640/60000]\tLoss: 0.954322\n", "Train Epoch: 2 [1280/60000]\tLoss: 0.627641\n", "Train Epoch: 2 [1920/60000]\tLoss: 0.694282\n", "Train Epoch: 2 [2560/60000]\tLoss: 0.490609\n", "Train Epoch: 2 [3200/60000]\tLoss: 0.518218\n", "Train Epoch: 2 [3840/60000]\tLoss: 0.511994\n", "Train Epoch: 2 [4480/60000]\tLoss: 0.575610\n", "Train Epoch: 2 [5120/60000]\tLoss: 0.760527\n", "Train Epoch: 2 [5760/60000]\tLoss: 0.618076\n", "Train Epoch: 2 [6400/60000]\tLoss: 0.551507\n", "Train Epoch: 2 [7040/60000]\tLoss: 0.661573\n", "Train Epoch: 2 [7680/60000]\tLoss: 0.503254\n", "Train Epoch: 2 [8320/60000]\tLoss: 0.611196\n", "Train Epoch: 2 [8960/60000]\tLoss: 0.568107\n", "Train Epoch: 2 [9600/60000]\tLoss: 0.680320\n", "Train Epoch: 2 [10240/60000]\tLoss: 0.749674\n", "Train Epoch: 2 [10880/60000]\tLoss: 0.766421\n", "Train Epoch: 2 [11520/60000]\tLoss: 0.720416\n", "Train Epoch: 2 [12160/60000]\tLoss: 0.552917\n", "Train Epoch: 2 [12800/60000]\tLoss: 0.642536\n", "Train Epoch: 2 [13440/60000]\tLoss: 0.564653\n", "Train Epoch: 2 [14080/60000]\tLoss: 0.562467\n", "Train Epoch: 2 [14720/60000]\tLoss: 0.683435\n", "Train 
Epoch: 2 [15360/60000]\tLoss: 0.638271\n", "Train Epoch: 2 [16000/60000]\tLoss: 0.667720\n", "Train Epoch: 2 [16640/60000]\tLoss: 0.417489\n", "Train Epoch: 2 [17280/60000]\tLoss: 0.661206\n", "Train Epoch: 2 [17920/60000]\tLoss: 0.586723\n", "Train Epoch: 2 [18560/60000]\tLoss: 0.577134\n", "Train Epoch: 2 [19200/60000]\tLoss: 0.882659\n", "Train Epoch: 2 [19840/60000]\tLoss: 0.705308\n", "Train Epoch: 2 [20480/60000]\tLoss: 0.621367\n", "Train Epoch: 2 [21120/60000]\tLoss: 0.451295\n", "Train Epoch: 2 [21760/60000]\tLoss: 0.589745\n", "Train Epoch: 2 [22400/60000]\tLoss: 0.653456\n", "Train Epoch: 2 [23040/60000]\tLoss: 0.404559\n", "Train Epoch: 2 [23680/60000]\tLoss: 0.613846\n", "Train Epoch: 2 [24320/60000]\tLoss: 0.720263\n", "Train Epoch: 2 [24960/60000]\tLoss: 0.446476\n", "Train Epoch: 2 [25600/60000]\tLoss: 0.905395\n", "Train Epoch: 2 [26240/60000]\tLoss: 0.574859\n", "Train Epoch: 2 [26880/60000]\tLoss: 0.779760\n", "Train Epoch: 2 [27520/60000]\tLoss: 0.447516\n", "Train Epoch: 2 [28160/60000]\tLoss: 0.553814\n", "Train Epoch: 2 [28800/60000]\tLoss: 0.724654\n", "Train Epoch: 2 [29440/60000]\tLoss: 0.451007\n", "Train Epoch: 2 [30080/60000]\tLoss: 0.357663\n", "Train Epoch: 2 [30720/60000]\tLoss: 0.534665\n", "Train Epoch: 2 [31360/60000]\tLoss: 0.912386\n", "Train Epoch: 2 [32000/60000]\tLoss: 0.635334\n", "Train Epoch: 2 [32640/60000]\tLoss: 0.611335\n", "Train Epoch: 2 [33280/60000]\tLoss: 0.498800\n", "Train Epoch: 2 [33920/60000]\tLoss: 0.726310\n", "Train Epoch: 2 [34560/60000]\tLoss: 0.618861\n", "Train Epoch: 2 [35200/60000]\tLoss: 0.498235\n", "Train Epoch: 2 [35840/60000]\tLoss: 0.556707\n", "Train Epoch: 2 [36480/60000]\tLoss: 0.828103\n", "Train Epoch: 2 [37120/60000]\tLoss: 0.459869\n", "Train Epoch: 2 [37760/60000]\tLoss: 0.699695\n", "Train Epoch: 2 [38400/60000]\tLoss: 0.746511\n", "Train Epoch: 2 [39040/60000]\tLoss: 0.620254\n", "Train Epoch: 2 [39680/60000]\tLoss: 0.685517\n", "Train Epoch: 2 [40320/60000]\tLoss: 0.444510\n", "Train 
Epoch: 2 [40960/60000]\tLoss: 0.607820\n", "Train Epoch: 2 [41600/60000]\tLoss: 0.453002\n", "Train Epoch: 2 [42240/60000]\tLoss: 0.575601\n", "Train Epoch: 2 [42880/60000]\tLoss: 0.521206\n", "Train Epoch: 2 [43520/60000]\tLoss: 0.505593\n", "Train Epoch: 2 [44160/60000]\tLoss: 0.494645\n", "Train Epoch: 2 [44800/60000]\tLoss: 0.445350\n", "Train Epoch: 2 [45440/60000]\tLoss: 1.022786\n", "Train Epoch: 2 [46080/60000]\tLoss: 0.934101\n", "Train Epoch: 2 [46720/60000]\tLoss: 0.581446\n", "Train Epoch: 2 [47360/60000]\tLoss: 0.565760\n", "Train Epoch: 2 [48000/60000]\tLoss: 0.418244\n", "Train Epoch: 2 [48640/60000]\tLoss: 1.008578\n", "Train Epoch: 2 [49280/60000]\tLoss: 0.604322\n", "Train Epoch: 2 [49920/60000]\tLoss: 0.721556\n", "Train Epoch: 2 [50560/60000]\tLoss: 0.521967\n", "Train Epoch: 2 [51200/60000]\tLoss: 0.410529\n", "Train Epoch: 2 [51840/60000]\tLoss: 0.719665\n", "Train Epoch: 2 [52480/60000]\tLoss: 0.718958\n", "Train Epoch: 2 [53120/60000]\tLoss: 0.678785\n", "Train Epoch: 2 [53760/60000]\tLoss: 0.497077\n", "Train Epoch: 2 [54400/60000]\tLoss: 0.617133\n", "Train Epoch: 2 [55040/60000]\tLoss: 0.468108\n", "Train Epoch: 2 [55680/60000]\tLoss: 0.576436\n", "Train Epoch: 2 [56320/60000]\tLoss: 0.433144\n", "Train Epoch: 2 [56960/60000]\tLoss: 0.715489\n", "Train Epoch: 2 [57600/60000]\tLoss: 0.386602\n", "Train Epoch: 2 [58240/60000]\tLoss: 0.258836\n", "Train Epoch: 2 [58880/60000]\tLoss: 0.300112\n", "Train Epoch: 2 [59520/60000]\tLoss: 0.549713\n", "\n", "Test set: Avg. 
loss: 0.1792, Accuracy: 9485/10000 (95%)\n", "\n", "Train Epoch: 3 [0/60000]\tLoss: 0.411068\n", "Train Epoch: 3 [640/60000]\tLoss: 0.617724\n", "Train Epoch: 3 [1280/60000]\tLoss: 0.573067\n", "Train Epoch: 3 [1920/60000]\tLoss: 0.783076\n", "Train Epoch: 3 [2560/60000]\tLoss: 0.673770\n", "Train Epoch: 3 [3200/60000]\tLoss: 0.554354\n", "Train Epoch: 3 [3840/60000]\tLoss: 0.984257\n", "Train Epoch: 3 [4480/60000]\tLoss: 0.530157\n", "Train Epoch: 3 [5120/60000]\tLoss: 0.612170\n", "Train Epoch: 3 [5760/60000]\tLoss: 0.679237\n", "Train Epoch: 3 [6400/60000]\tLoss: 0.645367\n", "Train Epoch: 3 [7040/60000]\tLoss: 0.436184\n", "Train Epoch: 3 [7680/60000]\tLoss: 0.745752\n", "Train Epoch: 3 [8320/60000]\tLoss: 0.474170\n", "Train Epoch: 3 [8960/60000]\tLoss: 0.337120\n", "Train Epoch: 3 [9600/60000]\tLoss: 0.746035\n", "Train Epoch: 3 [10240/60000]\tLoss: 0.606994\n", "Train Epoch: 3 [10880/60000]\tLoss: 0.590475\n", "Train Epoch: 3 [11520/60000]\tLoss: 0.506705\n", "Train Epoch: 3 [12160/60000]\tLoss: 0.531224\n", "Train Epoch: 3 [12800/60000]\tLoss: 0.573767\n", "Train Epoch: 3 [13440/60000]\tLoss: 0.489704\n", "Train Epoch: 3 [14080/60000]\tLoss: 0.381763\n", "Train Epoch: 3 [14720/60000]\tLoss: 0.788660\n", "Train Epoch: 3 [15360/60000]\tLoss: 0.398151\n", "Train Epoch: 3 [16000/60000]\tLoss: 0.673685\n", "Train Epoch: 3 [16640/60000]\tLoss: 0.442040\n", "Train Epoch: 3 [17280/60000]\tLoss: 0.400728\n", "Train Epoch: 3 [17920/60000]\tLoss: 0.665893\n", "Train Epoch: 3 [18560/60000]\tLoss: 0.680546\n", "Train Epoch: 3 [19200/60000]\tLoss: 0.699877\n", "Train Epoch: 3 [19840/60000]\tLoss: 0.656768\n", "Train Epoch: 3 [20480/60000]\tLoss: 0.564827\n", "Train Epoch: 3 [21120/60000]\tLoss: 0.575950\n", "Train Epoch: 3 [21760/60000]\tLoss: 0.414778\n", "Train Epoch: 3 [22400/60000]\tLoss: 0.561424\n", "Train Epoch: 3 [23040/60000]\tLoss: 0.688791\n", "Train Epoch: 3 [23680/60000]\tLoss: 0.567774\n", "Train Epoch: 3 [24320/60000]\tLoss: 0.525123\n", "Train Epoch: 3 
[24960/60000]\tLoss: 0.697168\n", "Train Epoch: 3 [25600/60000]\tLoss: 0.623199\n", "Train Epoch: 3 [26240/60000]\tLoss: 0.404254\n", "Train Epoch: 3 [26880/60000]\tLoss: 0.565381\n", "Train Epoch: 3 [27520/60000]\tLoss: 0.609798\n", "Train Epoch: 3 [28160/60000]\tLoss: 0.753811\n", "Train Epoch: 3 [28800/60000]\tLoss: 0.389456\n", "Train Epoch: 3 [29440/60000]\tLoss: 0.436780\n", "Train Epoch: 3 [30080/60000]\tLoss: 0.636898\n", "Train Epoch: 3 [30720/60000]\tLoss: 0.599510\n", "Train Epoch: 3 [31360/60000]\tLoss: 0.629284\n", "Train Epoch: 3 [32000/60000]\tLoss: 0.505906\n", "Train Epoch: 3 [32640/60000]\tLoss: 0.741626\n", "Train Epoch: 3 [33280/60000]\tLoss: 0.642386\n", "Train Epoch: 3 [33920/60000]\tLoss: 0.635572\n", "Train Epoch: 3 [34560/60000]\tLoss: 0.639967\n", "Train Epoch: 3 [35200/60000]\tLoss: 0.720257\n", "Train Epoch: 3 [35840/60000]\tLoss: 0.478147\n", "Train Epoch: 3 [36480/60000]\tLoss: 0.685794\n", "Train Epoch: 3 [37120/60000]\tLoss: 0.421208\n", "Train Epoch: 3 [37760/60000]\tLoss: 0.584400\n", "Train Epoch: 3 [38400/60000]\tLoss: 0.702878\n", "Train Epoch: 3 [39040/60000]\tLoss: 0.667306\n", "Train Epoch: 3 [39680/60000]\tLoss: 0.369802\n", "Train Epoch: 3 [40320/60000]\tLoss: 0.655229\n", "Train Epoch: 3 [40960/60000]\tLoss: 0.490676\n", "Train Epoch: 3 [41600/60000]\tLoss: 0.446377\n", "Train Epoch: 3 [42240/60000]\tLoss: 0.393590\n", "Train Epoch: 3 [42880/60000]\tLoss: 0.618019\n", "Train Epoch: 3 [43520/60000]\tLoss: 0.411980\n", "Train Epoch: 3 [44160/60000]\tLoss: 0.757768\n", "Train Epoch: 3 [44800/60000]\tLoss: 0.506138\n", "Train Epoch: 3 [45440/60000]\tLoss: 0.457559\n", "Train Epoch: 3 [46080/60000]\tLoss: 0.427676\n", "Train Epoch: 3 [46720/60000]\tLoss: 0.525319\n", "Train Epoch: 3 [47360/60000]\tLoss: 0.454945\n", "Train Epoch: 3 [48000/60000]\tLoss: 0.300189\n", "Train Epoch: 3 [48640/60000]\tLoss: 0.571119\n", "Train Epoch: 3 [49280/60000]\tLoss: 0.796717\n", "Train Epoch: 3 [49920/60000]\tLoss: 0.410930\n", "Train Epoch: 3 
[50560/60000]\tLoss: 0.679963\n", "Train Epoch: 3 [51200/60000]\tLoss: 0.625742\n", "Train Epoch: 3 [51840/60000]\tLoss: 0.506195\n", "Train Epoch: 3 [52480/60000]\tLoss: 0.527920\n", "Train Epoch: 3 [53120/60000]\tLoss: 0.477663\n", "Train Epoch: 3 [53760/60000]\tLoss: 0.429325\n", "Train Epoch: 3 [54400/60000]\tLoss: 0.561766\n", "Train Epoch: 3 [55040/60000]\tLoss: 0.622166\n", "Train Epoch: 3 [55680/60000]\tLoss: 0.454855\n", "Train Epoch: 3 [56320/60000]\tLoss: 0.564210\n", "Train Epoch: 3 [56960/60000]\tLoss: 0.408687\n", "Train Epoch: 3 [57600/60000]\tLoss: 0.639709\n", "Train Epoch: 3 [58240/60000]\tLoss: 0.516085\n", "Train Epoch: 3 [58880/60000]\tLoss: 0.692927\n", "Train Epoch: 3 [59520/60000]\tLoss: 0.301529\n", "\n", "Test set: Avg. loss: 0.1659, Accuracy: 9509/10000 (95%)\n", "\n" ] } ] }, { "cell_type": "code", "source": [ "# Run the trained network on the example batch loaded earlier and show its predictions\n", "model.eval()  # evaluation mode (affects layers such as dropout)\n", "with torch.no_grad():  # gradients are not needed for prediction\n", " output = model(example_data)\n", "predictions = output.argmax(dim=1)  # predicted digit = class with the highest score\n", "\n", "fig = plt.figure()\n", "for i in range(6):\n", " plt.subplot(2,3,i+1)\n", " plt.tight_layout()\n", " plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n", " plt.title(\"Prediction: {}\".format(predictions[i].item()))\n", " plt.xticks([])\n", " plt.yticks([])\n", "plt.show()" ], "metadata": { "id": "o7fRUAy9Se1B", "colab": { "base_uri": "https://localhost:8080/", "height": 320 }, "outputId": "de0d09a1-092d-4653-b6f6-c2d191c33441" }, "execution_count": 10, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ ":34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", " x = F.log_softmax(x)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZBU1d3/8c/XILIp4BaFCIpGcAE3iAYE96hAFFFCrCdBEqWC4RF/iUm5mxBxXxCLKI9KJGJEjCwqxkdFxQriBmpURCOFLIJB1pQs5gl6fn90e3PPcbqnu+d0T8/M+1VF1fc75/a9Z3oO8517z+1zzTknAADqaof67gAAoHGgoAAAoqCgAACioKAAAKKgoAAAoqCgAACiaPAFxcwmm9nYbNzXzD4ocT8TzezquL1DNWPsoBSMm9wqUlDMbJmZbTOzzWa2JvsDaRP7OM65vzrnuhbQn+FmNi947Ujn3LWx+1TDsX9oZh+Y2T/N7FMz+6OZ7VLu4zZUjJ2vHb+Lmc02s8/MbJ2Z3VyJ4zY0jBvv2DuZ2TgzW21mG83sLjPbsRzHquQZyvedc20kHSmpp6Srwg3MrFkF+1NfXpLUxznXVlIXSc0kja3fLlU9xo4kM2su6VlJz0vaS9K3JD1Yr52qboybjMuU+f4PlXSgMu/H196LGCp+ycs5t0rSU8p8czIzZ2ajzOxDSR9mvzbQzN4ys01mNt/Menz1ejM7wszeyP6FNk1Si1Tb8Wb2cSrfx8xmmNlaM1tvZhPM7CBJEyV9N/vXy6bstslpbDYfYWZLzGyDmT1uZh1Sbc7MRprZh9k+/t7MrMDvf6Vzbl3qS19IOqCY97CpaupjR9JwSaudc7c757Y45z53zr1d9BvZxDBu9H1JdzrnNjjn1kq6U9JPi30fC1HxgmJm+0jqL+nN1JcHSTpa0sFmdoSkP0j6maTdJP2PpMezp23NJc2SNEXSrpL+LOnsHMf5hqTZkpZL2ldSR0kPO+cWSxop6WXnXBvnXLsaXnuipBsk/UDS3tl9PBxsNlBSL0k9studmn1tp+wPvFOe9+BYM/unpM+y/b8j17b4D8aOjpG0zMyesszlrrlm1j3Htshi3GQOEcTfMrO2ebYvjXOu7P8kLZO0WdImZd6ouyS1zLY5SSemtr1b0rXB6z+QdJykfpJWS7JU23xJY7Px8ZI+zsbflbRWUrMa+jNc0rzga5NT+5kk6eZUWxtJ/5a0b6rPx6baH5F0WQnvS0dJv5V0YCV+Dg3xH2PHO84z2X2dLqm5pF9LWiqpeX3/nKrtH+PGO85YZS6176HMpdJXs/vbO/b7Xsnrh4Occ3NytK1MxZ0lnWdmF6W+1lxSB2XehFUu+y5lLc+xz30kLXfObS+hrx0kvfFV4pzbbGbrlSkAy7Jf/kdq+63KDICiOOdWmdn/KvOXyJEl9LOpYOxkbFPml9JTkmRmtypzLfwgSX8roa+NHeMm4zpJ7SS9Jelfku6VdISkNSX0M69quW04/cNaKek651y71L9Wzrmpkj6R1DG4dpjrNG+lpE5W86RbbUssr1ZmkEmSzKy1MqfCq2r7RkrQTNL+ZdhvU9GUxs7bBRwfhWky48Y5t80599/OuY7OuS6S1kta6Jz7sq77DlVLQUm7V9JIMzvaMlqb2QAz21nSy5K2SxptZjua2WBJ38mxn9eUGQw3ZvfRwsz6ZNvWKHMNsXmO106V9BMzO9zMdpJ0vaRXnXPL6vrNmdl/fXWt08w6K/PXw3N13S8kNfKxo8wdXceY2cnZ6/X/T9I6SYsj7Lspa9Tjxsw6mlmH7Pd2jKSrJf2mrvutSdUVFOfcAkkjJE2QtFHSEmWuP8o593+SBmfzDZKGSpqRYz9fKHN3wwGSVkj6OLu9lLntcpGkf5jZuhpeO0eZN326MgNkf0k/LKT/2QmyzXkmyA6WNN/MtihzXfOD7PeLOmrsY8c59
4GkHylzx9BGSWdKOiP7vaFEjX3cZPc1X9IWSX9UZu7lmUL2XSzzLw0CAFCaqjtDAQA0TBQUAEAUFBQAQBQUFABAFBQUAEAURX1S3sy4JawKOecKXSSuXjBuqtY659we9d2JfBg7VavGscMZCtB05VpCBKhNjWOHggIAiIKCAgCIgoICAIiCggIAiIKCAgCIgoICAIiCggIAiIKCAgCIopLPlAcavOOOO87Lx40bl8TPPec/ePPXv/51RfoEVAvOUAAAUVBQAABRUFAAAFE0mTmUffbZx8uvuOIKLx85cmQST5gwwWu76KKLytcxNCgDBw708q5duybxs88+W+nuAFWFMxQAQBQUFABAFI3qktd5553n5VdeeWUSd+7c2Wtr1sz/1r/88sskvvDCC722HXbw6+6oUaPq1E80HOG4GTZsmJePHz8+icPLqEBTwxkKACAKCgoAIAoKCgAgigY3h5KeJ7nqqqu8tn333dfLH3300SS+++67vbZ7773Xyw899NAkfumll7y2U089taS+ouFL304uSa1atfJy5k2Qlh4fQ4cO9douu+wyL9+2bVsSz5w502u7/vrrvfzf//53rC6WFWcoAIAoKCgAgCiq7pJXt27dvPyEE07w8muuuSaJt27d6rXdc889Xn7bbbcl8dKlS/MeN9xXWnjr6LnnnpvEU6dOzbtfNDzf+c53kji8FX3lypWV7g6q2E477eTlf/jDH5J4yJAhXpuZeblzLom7d+/ute22225ePnr06Dr1s1I4QwEAREFBAQBEQUEBAERRFXMo77zzThJ37NjRa2vbtq2Xf/TRR0l8xx13eG3hKsHF2LhxYxIvW7bMawtvR27RokXJx0H1admypZc/9thjSbzHHnt4bXPnzq1El1ClDjjgAC+/7rrrvPycc87J+dpJkyZ5+U9/+tOc2377298uoXf1jzMUAEAUFBQAQBQUFABAFBWbQznkkEOSePr06V7b/vvvn8ThUvHpORPJf2Le+++/H61/mzdvTuK1a9d6beEcChqXcHmV9LzJe++957X96le/Kvk4zZs3T+JOnTp5bemlf2bNmlXyMRDf9773vSQOf3eFS/GkzZkzx8snT57s5fnmUG666aYielg9OEMBAERBQQEARFGxS17p2+32228/ry19mSs8pbz44ou9/JNPPilD76TTTjstiXv16lWWY6A6tGvXzsvPPPPMnNuG42/16tUFH+e3v/2tl59xxhlJ3KNHD68tvYQPl7zqV/qyuuSvTB5e4gpvBU5f1nr33Xe9tnB5lXzC5Z4aCs5QAABRUFAAAFFQUAAAUVRsDiW9nMWoUaO8tvRtmlOmTPHayjVnEi6RED7REY3X6aef7uXHHnusl3/22WdJvH79+rz7at26dRKPGzfOawuf2LdixYoknjdvntd20kknJfFPfvITr+3+++/P2wfEFc517LXXXkn8y1/+0msLf+b5LFmyxMvXrFlT4zEkac899yx4v9WEMxQAQBQUFABAFBQUAEAU9bJ8/X333VeR46QfoxnOmdx4441evvPOOxe833AJazQshx12WN729Hzf22+/nXfb9LIc4VIaixYt8vL0PEk4htJzKl26dMl7TJRX//79vTz92OcHH3yw5P2Gc3ft27dP4vTjgBsyzlAAAFFQUAAAUVTFExtLFT5Nb/DgwV6evj05vdpxXaVvHXz99de9NpbNqE7pFaOHDRvmtW3dutXLhw8fnsTp24Il6aGHHvLy9HIq4X7Cp/mtW7cuicNLXmaWo+eob+nbyLdt21bw68LlU8KleNKrTzcWnKEAAKKgoAAAoqCgAACiaNBzKL179/byu+66q+R9pZ/MFy6hn++JfuGSCahO559/fhKHy1osXrzYy9u0aZPE4bInAwYM8PL0vMkFF1zgtT3yyCM5+xPuJ33b6OzZs3O+DuW3adMmL+/Tp08SH3zwwV7ba6+95uXpubGnnnrKawuf0tlYbhVO4wwFABAFBQUAEAUFB
QAQRYOeQwmvdS5YsMDLu3XrlsQvv/yy1xYuO/3mm28m8aeffuq1zZ8/38vT10Z//OMfe20TJ06srduogJYtW3r5ySefnHPbcM7igQceSOL050ykry+nMmTIkCT++9//nrdP/fr1S+Kf//znXtvSpUuTOD2fh8qbMWOGl6fnu8KlV2644QYvv+KKK5I4XEInvVy9JO20005J3LZt29I6W2U4QwEAREFBAQBE0aAveb344otefvTRR3t5+kl877//vteWXgajNuHT9Z577rkkPuqoo7y23XffvaRjIK5ddtnFy3v16pVz2xNOOMHLe/bsmcRvvPGG1xauRFvMz/iUU05J4latWnltl19+eRKnl/pA5eV7QuakSZPy5mnhStU/+MEPvHzQoEFJHK5+Hj5F9JZbbsl5nGrCGQoAIAoKCgAgCgoKACCKBj2HUptw7qNU4bLkK1asSOL0U/gkafTo0Ul8zTXXRDk+infaaacVvG16ziR06aWXenkxcybhEi/p20//8pe/eG333HNPwftFZT388MNJHC5JHy6hk54LCX/G+Za+D5dhaahPheUMBQAQBQUFABAFBQUAEEWjnkMpl/Sy5Oecc47Xlr7m/uqrr3ptTz75ZHk7hkSsa9CHH364l7/wwgsFvzZcpqNHjx5JfPPNN9etY6iY9NxH+BjfMC9Gei42tPPOO3v5rrvumsQbNmwo+ZjlxhkKACAKCgoAIAoueZUgvYzL6tWrvbauXbsmcfh0Ny55Vc7ChQsL3vZnP/uZl6eXbfn973+f97UdOnRI4quuusprO+6447w8vUr1rFmzCu4fGqfHHnssicMVpsPfHelVzcePH1/ejtUBZygAgCgoKACAKCgoAIAomEOpQbNm/tuy7777evnjjz+exOk5E8lfQiHfUgsor6efftrLX3/99SQOl7IfMWKEl19wwQVJfOCBB3ptRxxxhJdPnjw5icPlMzZu3OjlY8aMSeLPP/88V9fRRKR/P9x5551eW/jk18GDBycxcygAgEaPggIAiKIqLnmlPxW6fft2ry3MW7ZsWdIxOnXq5OXnn39+zm3bt2/v5elb9kLhZY70p6MnTJhQTBcRUXi5cc6cOUkcXvIKVxt+6623Cj5OeiXq9DEk6dZbb/XyBQsWFLxfNC2ffvqpl5uZl/ft27eS3SkZZygAgCgoKACAKCgoAIAoqmIOZenSpUm8atUqr23lypVe3r9//4r0KZ/07aDhaqPMm1SnsWPHJnH4JM+BAwd6+YUXXphzP+nbhCV/+YwnnniiDj1EU/bhhx96+ZYtW7y8VatWSdyiRQuvrZpuQecMBQAQBQUFABAFBQUAEIWFn6PIu7FZ4RsX4ZJLLkni3r17e23hk8tOOumkcnTBM23aNC9ftGiRl8+ePTuJ//a3v5W9P7VxzlntW9Wfco0b1NlC51zP2jerP0117Dz00ENePnTo0CROLw0kSffff39F+hSocexwhgIAiIKCAgCIoipuG77ttttqjAGgKVqxYkV9d6EknKEAAKKgoAAAoqCgAACiqIo5FADAfyxZsiRnW7hMSzXhDAUAEAUFBQAQBQUFABBFVSy9grph6RWUiKVXUCqWXgEAlA8FBQAQBQUFABAFBQUAEAUFBQAQBQUFABBFsUuvrJO0vBwdQck613cHCsC4qU6MHZSqxrFT1OdQAADIhUteAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKCgoAIAoGnxBMbPJZjY2G/c1sw9K3M9EM7s6bu9QzRg7KAXjJreKFBQzW2Zm28xss5mtyf5A2sQ+jnPur865rgX0Z7iZzQteO9I5d23sPuU49hfZ9+Krf8eX+7gNFWPHO/ZOZjbOzFab2UYzu8vMdiz3cRsixo13bDOzsWa2ysz+aWZzzeyQchyrkmco33fOtZF0pKSekq4KNzCzYh/41VC97Jxrk/o3t747VOUYO
xmXKfP9HyrpQGXej6+9F0gwbjKGSPqppL6SdpX0sqQp5ThQxS95OedWSXpKmf8UMjNnZqPM7ENJH2a/NtDM3jKzTWY238x6fPV6MzvCzN4ws8/MbJqkFqm2483s41S+j5nNMLO1ZrbezCaY2UGSJkr6bvavl03ZbZPT2Gw+wsyWmNkGM3vczDqk2pyZjTSzD7N9/L2ZWbneM2QwdvR9SXc65zY459ZKulOZXxTIg3Gj/STNc84tdc59IelBSQcX+z4WouIFxcz2kdRf0pupLw+SdLSkg83sCEl/kPQzSbtJ+h9Jj1vmdL+5pFnKVNddJf1Z0tk5jvMNSbOVeXzovpI6SnrYObdY0kj95yyhXQ2vPVHSDZJ+IGnv7D4eDjYbKKmXpB7Z7U7NvrZT9gfeKc/bcISZrTOzv5vZ1U3kr6Q6Y+xkDhHE3zKztnm2b/IYN3pY0v5mdqBlLpGeJ+l/c2xbN865sv+TtEzSZkmblHmj7pLUMtvmJJ2Y2vZuSdcGr/9A0nGS+klareyji7Nt8yWNzcbHS/o4G39X0lpJzWroz3BlKnb6a5NT+5kk6eZUWxtJ/5a0b6rPx6baH5F0WYHvRRdl/mLYQVJ3Se9JurwSP4eG+I+x4x1nrKSXJO0haS9Jr2b3t3d9/5yq7R/jxjtOc0njs/vYLukjSfuV432v5F/Gg5xzc3K0rUzFnSWdZ2YXpb7WXFIHZd6QVS77LmUtz7HPfSQtd85tL6GvHSS98VXinNtsZuuV+YtjWfbL/0htv1WZAVAr59zSVPqOmf1O0q+V+esENWPsZFwnqZ2ktyT9S9K9ko6QtKaEfjYFjJuMa5Q5s9knu48fSXrezA5xzm0toa85Vcttw+kf1kpJ1znn2qX+tXLOTZX0iaSOwbXDXKd5KyV1ynE5ydXwtbTVygwySZKZtVbmVHhVbd9ICZz8yxgoTpMZO865bc65/3bOdXTOdZG0XtJC59yXdd13E9Rkxo2kwyVNc8597Jzb7pybLKm9yjCPUi0FJe1eSSPN7GjLaG1mA8xsZ2XuTtguabSZ7WhmgyV9J8d+XlNmMNyY3UcLM+uTbVujzLXn5jleO1XST8zscDPbSdL1kl51zi2r6zdnZqeb2TezcTdJV0t6rK77haTGP3Y6mlmH7Pd2jDJj5zd13S8a97iR9LqkIWb2TTPbwcx+LGlHSUsi7NtTdQXFObdA0ghJEyRtVOabHp5t+z9Jg7P5BklDJc3IsZ8vlLkr5gBJKyR9nN1ekp6XtEjSP8xsXQ2vnaPMf9bpygyQ/SX9sJD+ZyfINueZIDtJ0ttmtkXSX7L9v76QfSO/JjB29lfm+v0WSX9U5hr6M4XsG7k1gXFzk6S/KXOpdJOkX0g62zm3qZD9F8P8S4MAAJSm6s5QAAANEwUFABAFBQUAEAUFBQAQBQUFABBFUZ+UNzNuCatCzrmq/mAk46ZqrXPO7VHfnciHsVO1ahw7nKEATVeuJUSA2tQ4digoAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKCgoAIAoKCgAgCgoKACAKIpabRgAmrI999zTy//1r395ebdu3ZL47LPPzrsvs/8sEr777rt7bcOGDcv5unPPPdfLH3nkkbzHqSTOUAAAUVBQAABRmHOFP7+mGh52881vfjOJe/fu7bUNGDCg4P0ceOCBXv7pp58m8bhx47y2xYsXe/mGDRsKPk4l8IAtlGihc65nfXcin/oYO+Hvhl/84hdJHF6KWr16tZfvt99+BR8nfcmrmN/D4e+j7t27F/zaiGocO5yhAACioKAAAKKgoAAAoqj624Z32203L3/66aeT+OCDD/bawlv4mjX7z7e3detWr619+/Zenr6eOXjwYK9t1apVXj5mzJgkvu+++3L2HQ3DHnvs4eXXXnttEh900EFeW9++fb08fe173rx5XtvMmTO9/I477qhTP1Ee4e29kyZN8vI2bdokcfr3hFTcn
EksixYtqvgxC8UZCgAgCgoKACCKqr/kddRRR3l5jx49kvh3v/ud1/bggw96+a677prE7733ntd24okn5jxmeBte+OnYuXPn5u4wqt5ZZ53l5bfffruXd+rUKYnD2znz5ccee6zX1qdPHy9v3bp1El933XVF9BjlNGLECC9PX+KK6dFHH/XyQw45JInDS6v5LF++PFqfYuMMBQAQBQUFABAFBQUAEEXVz6GcdNJJOdtWrlzp5UuWLCl4v48//nhJbWiY+vXrl8TTp0/32sJ5kfQt5tdff73Xtn79+pzHOPXUU7180KBBXn7xxRcn8ZQpU7y2FStW5Nwvqlc4N3vyySfn3HbdunVefuONNyZxMXMo1YwzFABAFBQUAEAUFBQAQBRVP4cycODAnG2zZ8+uYE/QkF1++eVJXNtnS3r16pXE77//fsHHSC8LJH39cyjpJV7CJ/Qxh1J/Nm3a5OVffPGFl6eXcNphB/9v8PAzakOGDEnie+65x2v78ssvvTz9uaRwSZd8ws9NVRPOUAAAUVBQAABRVP0TG8Nb7dLLqYSnn3VxwAEHJPEpp5zitYXLv6Tdf//9Xv7SSy9F61OheGLj15122mle/uSTT6b747WFy6BcffXVJR0zfQlDkrp16+blCxcuLGm/ZcQTG2sQjodLL7003R+vLd/vz/SYk6Q5c+Z4eXr16dp+D8+aNSuJf/SjH3ltn3/+ed7XlglPbAQAlA8FBQAQBQUFABBF1c+hrF271svTT3AsZg6lS5cuXh7OfRx22GFJ/I1vfMNr27Ztm5en28M+pJffeO211wruX10wh/J1EydO9PILLrggiT/44AOvLX2bsPT1p3sWKry2PXnyZC8/9NBDk7iY25HLiDmUGuy9995ePnz48CQeO3as11bM789Qej6mtv0cfvjhSfzuu++WfMyImEMBAJQPBQUAEAUFBQAQRdUvvRLeuz106NAkDj8v8tFHH3n56aefnsS33HKL1/bJJ594+bRp05J4/PjxXlu4RHX6EaF//vOfvbbbbrstifv27StUh/T16nCOpNQ5E8lfTuWKK67IeUzJX7I+nLdB9Qh/N9xwww1J/MILL3ht6SV9JGnAgAHl61gDwBkKACAKCgoAIIqqv+Q1ZswYLz/jjDOSeMaMGXlfm17dc+7cuV7bOeec4+WbN28uuE/pbe+77z6vberUqUl87rnn5mxDZdXl9s580pc8unbtmveY6RWGw9WGwyWGUJ1eeeUVLw9/j5x11llJ/NBDD0U77oQJE5L4N7/5jdf24osvRjtOXXGGAgCIgoICAIiCggIAiKLql14Jpa9Z9+7d22sLl0hJz28888wz5e1YVnreZtGiRV5b9+7dy3JMll75uvTtvJK/DE7nzp29tsWLF3v5X//615z77devn5en501qW9p85cqVSdyzp79qRT3NobD0Sh21bdvWy6+88sokvuSSS/K+Nr1sU/g0x2Jce+21SXzTTTd5beHvxIhYegUAUD4UFABAFBQUAEAUDW4OpdqtWbMmicOl7cPr+rEwh1K79OcDwiXIw8+P5FtWPN88SW1zKOlHC6eX86hHzKEUaccdd/Ty6dOne3n//v0L3ld6uacFCxZ4beFnTdLLPYXS4y79qGDJX3pfkj777LOC+1cL5lAAAOVDQQEARFH1S680NOlLGeHqs6g/M2fOTOJ58+Z5benLYZI0YsSIgvfbrVu3JG7dunXebdevX1/wflGdwo8fFLOi+PLly708fYvxsmXLvLZwOZWbb745iY8//vicxzjzzDO9/JhjjvHyZ599tpCulowzFABAFBQUAEAUFBQAQBTcNhxZnz59kjh93V6S9txzz7Ick9uG68/rr7+exEceeaTXFv7f2muvvZK4Spar57bhGnTp0sXLR48encQXXXSR11bM78/DDjvMy8OlmfJp3759Eoe/V9LzOGF/wiX0hw0bVvAxa8FtwwCA8qGgAACioKAAAKLgcyhAEcLPmrRq1SqJw
6VXQlUyb4JaTJkyxcuPPvrokvbz6KOPenn4mIRibNy4MYlnz57tteX7LEy5lnvKhTMUAEAUFBQAQBRc8gKKkF5qRfJXKg5v2SzmllLUn/QTD6WvL1eSFq4gHj5p8a233kri8BbjfE9lDFcxDm9BP/vss5O4V69eOfsUHuP222/Pecxy4AwFABAFBQUAEAUFBQAQRdXPocyfP9/L00uA33TTTV5buCx5fUjfwlfbbaRoePr16+fl6Z9x+PPmNuGGIZyTyDf3Fc5RhNumfz/deuuteY+bHi/t2rXz2op58mO6T1u3bvXaIj6hsSCcoQAAoqCgAACiqPpLXg888ICXjx8/PonDp5y9+eabXr5ly5ay9esrLVq08PLLLrssiRcsWFD246Oy0rcJS/kvj8yYMaPc3UGVOfHEEwveNn3Jqy63mKdvVR41apTX9sorr5S831JwhgIAiIKCAgCIgoICAIii6udQJk6c6OWdO3dO4ksvvdRr23vvvb38lltuSeJXX321DL3z50wkaZdddkni8LZmND7p6+Dhshz33ntvpbuDErz33ntefvLJJ9dTT3JLrzA8Z84cr23atGlJvHbt2or1qSacoQAAoqCgAACioKAAAKKo+jmU0OWXX57E4edQxowZ4+UDBw5M4o8//thrmz59es5jvPPOO14eLoswePDgJO7Tp4/Xln5K2/PPP5/zGGgc0p8fyLc8OapX+neKJD3xxBNe/thjjyVx+MTOYtx9991enl4m5amnnvLawnmdDRs2JPH27dtL7kO5cYYCAIiCggIAiMKK+ci/mVX1I+g6duzo5emVYdOXqST/CWjF+tOf/pTE4fIaM2fOLHm/pXLOVfWyxtU+booRXiodNGhQEodL//Ts2bMifaqDhc65qu5kYxo7jUyNY4czFABAFBQUAEAUFBQAQBQN7rbhfFatWuXlU6dOrTEGShXOOabz8FZPoKnhDAUAEAUFBQAQBQUFABBFo/ocSlPF51BQIj6HglLxORQAQPlQUAAAUVBQAABRUFAAAFFQUAAAUVBQAABRFLv0yjpJy8vREZSsc313oACMm+rE2EGpahw7RX0OBQCAXLjkBQCIgoICAIiCggIAiIKCAgCIgoICAIiCggIAiIKCAo1z6QYAAAATSURBVACIgoICAIiCggIAiOL/AyWSUozBZgCEAAAAAElFTkSuQmCC\n" }, "metadata": {} } ] }, { "cell_type": "markdown", "source": [ "I reckon you probably know enough to download imagenet and reproduce AlexNet. You probably have to do that locally (or pay Google for CoLab space) as there are more than a million images involved though. Good project for the Xmas holidays... build AlexNet, build VGG, build ResNet200." ], "metadata": { "id": "g-9j4p8flcpB" } } ] }