{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap20/20_3_Lottery_Tickets.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dKUcDM76bHx3"
},
"source": [
"# **Notebook 20.3: Lottery tickets**\n",
"\n",
"This notebook investigates the phenomenon of lottery tickets as discussed in section 20.2.7. It is heavily based on the MNIST-1D code hosted by Sam Greydanus at https://github.com/greydanus/mnist1d.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Sg2i1QmhKW5d"
},
"source": [
"# Run this if you're in Colab to install the MNIST-1D repository\n",
"!pip install git+https://github.com/greydanus/mnist1d\n",
"!git clone https://github.com/greydanus/mnist1d"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Lottery tickets\n",
"\n",
"Lottery tickets were first identified by [Frankle and Carbin (2018)](https://arxiv.org/abs/1803.03635). They noted that after training a network, they could set the smallest-magnitude weights to zero, clamp them there, and retrain to get a network that was sparser (had fewer non-zero parameters) but could actually perform better. So within the neural network there lie smaller sub-networks which are superior. If we knew in advance what these were, we could train them from scratch, but there is currently no way of finding this out."
],
"metadata": {
"id": "97g8gY5XdcKR"
}
},
{
"cell_type": "code",
"metadata": {
"id": "KaQo7QhvXvid"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"import mnist1d\n",
"import copy"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nre26wEOfZsM"
},
"source": [
"## Get the MNIST1D dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "I-vm_gh5xTJs"
},
"source": [
"from mnist1d.data import get_dataset, get_dataset_args\n",
"from mnist1d.utils import set_seed, to_pickle, from_pickle\n",
"\n",
"import sys ; sys.path.append('./mnist1d/notebooks')\n",
"from train import get_model_args, train_model\n",
"\n",
"args = mnist1d.get_dataset_args()\n",
"data = mnist1d.get_dataset(args=args) # by default, this will download a pre-made dataset from the GitHub repo\n",
"\n",
"print(\"Examples in training set: {}\".format(len(data['y'])))\n",
"print(\"Examples in test set: {}\".format(len(data['y_test'])))\n",
"print(\"Length of each input: {}\".format(data['x'].shape[-1]))\n",
"print(\"Number of classes: {}\".format(len(data['templates']['y'])))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "O2vy0FKjfDwr"
},
"source": [
"## Make an MLP that can be masked\n",
"These parameter-wise binary masks are how we will represent sparsity in this project. There's not a great PyTorch API for this yet, so here's a temporary solution."
]
},
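{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the idea (a NumPy sketch, not the PyTorch implementation used in this notebook), a parameter-wise binary mask can be applied by flattening a layer's weights and biases into one long vector, multiplying by the mask, and cutting the result back into the original shapes. The helper names `flatten_params` and `unflatten_params`, the layer size, and the mask values are all made up for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def flatten_params(params):\n",
"    # Concatenate every parameter array into one long vector\n",
"    return np.concatenate([p.ravel() for p in params])\n",
"\n",
"def unflatten_params(vec, shapes):\n",
"    # Cut the long vector back into arrays with the original shapes\n",
"    params, pointer = [], 0\n",
"    for shape in shapes:\n",
"        size = int(np.prod(shape))\n",
"        params.append(vec[pointer:pointer + size].reshape(shape))\n",
"        pointer += size\n",
"    return params\n",
"\n",
"# Hypothetical layer with 3 inputs and 2 outputs: a weight matrix and a bias vector\n",
"weight = np.arange(1.0, 7.0).reshape(2, 3)\n",
"bias = np.array([0.5, -0.5])\n",
"shapes = [weight.shape, bias.shape]\n",
"\n",
"# One mask entry per parameter; zeros mark the pruned positions\n",
"mask = np.array([1, 0, 1, 1, 0, 1, 1, 0], dtype=float)\n",
"\n",
"masked_vec = mask * flatten_params([weight, bias])\n",
"masked_weight, masked_bias = unflatten_params(masked_vec, shapes)\n",
"print(masked_weight)\n",
"print(masked_bias)\n",
"```\n",
"\n",
"Keeping the mask as one flat vector per layer means a single mask covers both the weights and the biases of that layer."
]
},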
{
"cell_type": "code",
"metadata": {
"id": "uBx5gNW-mqH_"
},
"source": [
"# Class to represent a linear layer where some of the parameters are forced to zero.\n",
"class SparseLinear(torch.nn.Module):\n",
"    def __init__(self, x_size, y_size):\n",
"        super(SparseLinear, self).__init__()\n",
"        self.linear = torch.nn.Linear(x_size, y_size)\n",
"        param_vec = torch.cat([p.flatten() for p in self.parameters()])\n",
"        # one mask entry per parameter (weights and biases); 1 = keep, 0 = pruned\n",
"        self.mask = torch.ones_like(param_vec)\n",
"\n",
"    def forward(self, x, apply_mask=True):\n",
"        if apply_mask:\n",
"            self.apply_mask()\n",
"        return self.linear(x)\n",
"\n",
"    def update_mask(self, new_mask):\n",
"        self.mask = new_mask\n",
"        self.apply_mask()\n",
"\n",
"    def apply_mask(self):\n",
"        # re-zero the pruned parameters in place\n",
"        self.vec2param(self.param2vec())\n",
"\n",
"    def param2vec(self):\n",
"        vec = torch.cat([p.flatten() for p in self.parameters()])\n",
"        return self.mask * vec\n",
"\n",
"    def vec2param(self, vec):\n",
"        pointer = 0\n",
"        for param in self.parameters():\n",
"            param_len = param.numel()\n",
"            new_param = vec[pointer:pointer+param_len].reshape(param.shape)\n",
"            param.data = new_param.data\n",
"            pointer += param_len\n",
"\n",
"# A three-layer MLP with a residual connection around the second layer, where the linear layers are sparse\n",
"class SparseMLP(torch.nn.Module):\n",
"    def __init__(self, input_size, output_size, hidden_size=100):\n",
"        super(SparseMLP, self).__init__()\n",
"        self.linear1 = SparseLinear(input_size, hidden_size)\n",
"        self.linear2 = SparseLinear(hidden_size, hidden_size)\n",
"        self.linear3 = SparseLinear(hidden_size, output_size)\n",
"        self.layers = [self.linear1, self.linear2, self.linear3]\n",
"\n",
"    def forward(self, x):\n",
"        h = torch.relu(self.linear1(x))\n",
"        h = h + torch.relu(self.linear2(h))\n",
"        h = self.linear3(h)\n",
"        return h\n",
"\n",
"    def get_layer_masks(self):\n",
"        return [l.mask for l in self.layers]\n",
"\n",
"    def set_layer_masks(self, new_masks):\n",
"        for i, l in enumerate(self.layers):\n",
"            l.update_mask(new_masks[i])\n",
"\n",
"    def get_layer_vecs(self):\n",
"        return [l.param2vec() for l in self.layers]\n",
"\n",
"    def set_layer_vecs(self, vecs):\n",
"        for i, l in enumerate(self.layers):\n",
"            l.vec2param(vecs[i])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2hwmH2vIbHin"
},
"source": [
"Now we need a routine that takes the weights from the model and returns a mask that identifies the positions with the lowest magnitude. These will be the weights that we mask."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Md2F9WDgYSqT"
},
"source": [
"# absolute_weights: absolute values of the parameters of one layer, flattened into a long vector\n",
"# percent_sparse: fraction (between 0 and 1) of the entries to set to zero\n",
"def get_mask(absolute_weights, percent_sparse):\n",
"    # TODO -- Write a function that returns a mask that has a zero\n",
"    # everywhere for the lowest \"percent_sparse\" of the absolute weights.\n",
"    # E.g. if absolute_weights contains [5,6,0,1,7] and percent_sparse is 0.4 (i.e., 40%),\n",
"    # we would return [1,1,0,0,1]\n",
"    # Remember that these are torch tensors and not numpy arrays\n",
"    # Replace this function:\n",
"    mask = torch.ones_like(absolute_weights)\n",
"\n",
"\n",
"    return mask"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0McGMV-a3Xo"
},
"source": [
"## The prune-and-retrain cycle\n",
"This is the core method for finding a lottery ticket. We train a model for a fixed number of training steps, prune it, and then re-train and re-prune. Before each re-training run, the parameters are reset to their initial values and the new mask is applied. We repeat this cycle until we achieve the desired level of sparsity."
]
},
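{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cycle can be sketched on a toy problem. This is only an illustration in NumPy under simplifying assumptions: a linear model trained by gradient descent stands in for the network, and the names `train` and `magnitude_mask`, the synthetic data, and the three-step schedule are invented for the example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Toy regression problem (hypothetical): only the first 5 of 20 inputs matter\n",
"X = rng.normal(size=(200, 20))\n",
"true_w = np.zeros(20)\n",
"true_w[:5] = rng.normal(size=5)\n",
"y = X @ true_w\n",
"\n",
"def train(w, mask, steps=200, lr=0.05):\n",
"    # Gradient descent on mean squared error; pruned weights are held at zero\n",
"    w = w * mask\n",
"    for _ in range(steps):\n",
"        grad = 2.0 * X.T @ (X @ w - y) / len(y)\n",
"        w = (w - lr * grad) * mask\n",
"    return w\n",
"\n",
"def magnitude_mask(w, fraction):\n",
"    # Zero out the requested fraction of entries with the smallest magnitudes\n",
"    num_pruned = int(round(fraction * w.size))\n",
"    mask = np.ones_like(w)\n",
"    mask[np.argsort(np.abs(w))[:num_pruned]] = 0.0\n",
"    return mask\n",
"\n",
"init_w = rng.normal(size=20)   # initial weights, reused in every round\n",
"mask = np.ones_like(init_w)\n",
"w = train(init_w, mask)        # first round: dense model\n",
"for fraction in [0.25, 0.5, 0.75]:\n",
"    mask = magnitude_mask(w, fraction)   # prune based on the trained weights\n",
"    w = train(init_w, mask)              # reset to the initial weights and retrain\n",
"    loss = np.mean((X @ w - y) ** 2)\n",
"    print(fraction, int(mask.sum()), loss)\n",
"```\n",
"\n",
"Because weights pruned in one round are exactly zero after retraining, they are always among the smallest in the next round, so the masks are nested: each round removes additional weights on top of the previous ones."
]
},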
{
"cell_type": "code",
"metadata": {
"id": "5idcbyA3Ylz_"
},
"source": [
"def find_lottery_ticket(model, dataset, args, sparsity_schedule, criteria_fn=None, **kwargs):\n",
"\n",
"    # default pruning criterion: magnitude of the trained parameters\n",
"    if criteria_fn is None:\n",
"        criteria_fn = lambda init_params, final_params: final_params.abs()\n",
"    init_params = model.get_layer_vecs()\n",
"    stats = {'train_losses':[], 'test_losses':[], 'train_accs':[], 'test_accs':[]}\n",
"    models = []\n",
"    for i, percent_sparse in enumerate(sparsity_schedule):\n",
"\n",
"        # layer-wise pruning, where pruning heuristic is determined by criteria_fn\n",
"        final_params = model.get_layer_vecs()\n",
"        scores = [criteria_fn(ip, fp) for ip, fp in zip(init_params, final_params)]\n",
"        masks = [get_mask(s, percent_sparse) for s in scores]\n",
"\n",
"        # update model with mask and init parameters\n",
"        model.set_layer_vecs(init_params)\n",
"        model.set_layer_masks(masks)\n",
"\n",
"        # training process\n",
"        results = train_model(dataset, model, args)\n",
"        model = results['checkpoints'][-1]\n",
"\n",
"        # store stats\n",
"        stats['train_losses'].append(results['train_losses'])\n",
"        stats['test_losses'].append(results['test_losses'])\n",
"        stats['train_accs'].append(results['train_acc'])\n",
"        stats['test_accs'].append(results['test_acc'])\n",
"\n",
"        # print progress\n",
"        if (i+1) % 1 == 0:\n",
"            print('\\tretrain #{}, sparsity {:.2f}, final_train_loss {:.3e}, max_acc {:.1f}, last_acc {:.1f}, mean_acc {:.1f}'\n",
"                  .format(i+1, percent_sparse, results['train_losses'][-1], np.max(results['test_acc']),\n",
"                  results['test_acc'][-1], np.mean(results['test_acc']) ))\n",
"        models.append(copy.deepcopy(model))\n",
"\n",
"    stats = {k: np.stack(v) for k, v in stats.items()}\n",
"    return models, stats"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "m4lokvdD4DKI"
},
"source": [
"## Choose hyperparameters"
]
},
{
"cell_type": "code",
"metadata": {
"id": "OUe7-b-7Yl2c"
},
"source": [
"# train settings\n",
"from train import get_model_args, train_model\n",
"model_args = get_model_args()\n",
"model_args.total_steps = 1501\n",
"model_args.hidden_size = 500\n",
"model_args.print_every = 5000 # print never\n",
"model_args.eval_every = 100\n",
"model_args.learning_rate = 2e-2\n",
"model_args.device = str('cpu')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVgDM5rI4J65"
},
"source": [
"Find the lottery ticket by repeatedly training and then pruning weights based on their magnitudes. We'll remove roughly another 1% of the weights each time. This is going to take half an hour or so. Go and have lunch or whatever."
]
},
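{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the schedule concrete, here is a small sketch of the arithmetic for a 100-entry schedule from 0 to 1 and a hidden size of 500. The input size of 40 and output size of 10 are assumptions about the MNIST-1D defaults, and `layer_sizes` is a name made up for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"num_retrains = 100\n",
"sparsity_schedule = np.linspace(0, 1.0, num_retrains)\n",
"\n",
"# Consecutive entries differ by 1/(num_retrains - 1), i.e. slightly more than 1%\n",
"step = sparsity_schedule[1] - sparsity_schedule[0]\n",
"print(step)\n",
"\n",
"# Assumed layer sizes: 40 inputs, 500 hidden units, 10 outputs (weights plus biases)\n",
"layer_sizes = [40 * 500 + 500, 500 * 500 + 500, 500 * 10 + 10]\n",
"\n",
"# Parameters left unmasked in each layer at a few points in the schedule\n",
"for sparsity in sparsity_schedule[[0, 25, 50, 75, 99]]:\n",
"    kept = [int(round((1.0 - sparsity) * n)) for n in layer_sizes]\n",
"    print(round(float(sparsity), 3), kept)\n",
"```\n",
"\n",
"Note that the final entry of the schedule is exactly 1.0, at which point every parameter is masked."
]
},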
{
"cell_type": "code",
"source": [
"# sparsity settings - we will train 100 models with progressively increasing sparsity\n",
"num_retrains = 100\n",
"sparsity_schedule = np.linspace(0,1.,num_retrains)\n",
"\n",
"print(\"Magnitude pruning\")\n",
"mnist1d.set_seed(model_args.seed)\n",
"model = SparseMLP(model_args.input_size, model_args.output_size, hidden_size=model_args.hidden_size)\n",
"\n",
"criteria_fn = lambda init_params, final_params: final_params.abs()\n",
"lott_models, lott_stats = find_lottery_ticket(model, data, model_args, sparsity_schedule, criteria_fn=criteria_fn, prune_print_every=1)"
],
"metadata": {
"id": "M25YpCuS1Gn0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_losses = lott_stats['test_losses'][:,-1]\n",
"test_accs = lott_stats['test_accs'][:,-1]\n",
"\n",
"fig,ax = plt.subplots()\n",
"ax.plot(sparsity_schedule, test_losses,'r-')\n",
"ax.plot((sparsity_schedule[0], sparsity_schedule[-1]),(test_losses[0], test_losses[0]),'k--',label='dense')\n",
"ax.set_xlabel('Sparsity')\n",
"ax.set_ylabel('Loss')\n",
"ax.set_xlim(0,1)\n",
"ax.legend()\n",
"plt.show()\n",
"\n",
"fig,ax = plt.subplots()\n",
"ax.plot(sparsity_schedule, 100-test_accs,'r-')\n",
"ax.plot((sparsity_schedule[0], sparsity_schedule[-1]),(100-test_accs[0], 100-test_accs[0]),'k--',label='dense')\n",
"ax.set_xlabel('Sparsity')\n",
"ax.set_ylabel('Percent Error')\n",
"ax.set_xlim(0,1)\n",
"ax.set_ylim(0,100)\n",
"ax.legend()\n",
"plt.show()\n"
],
"metadata": {
"id": "TCs-kt6-3xHB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"You should see that the test loss and the percent error both decrease at first as the network gets sparser. The dashed line represents the original dense (unpruned) network. We have identified a simpler network that was \"inside\" the original network and for which the results are superior. Of course, if we make it too sparse, then it gets worse again.\n",
"\n",
"This phenomenon is explored much further in the original notebook by Sam Greydanus, which can be found [here](https://github.com/greydanus/mnist1d)."
],
"metadata": {
"id": "CEj5_ZEHcRpw"
}
}
]
}