From 0d135f1ee7fc6452b1afede82ee4f73f8683266f Mon Sep 17 00:00:00 2001 From: udlbook <110402648+udlbook@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:55:44 -0400 Subject: [PATCH] Fixed problems with MNIST1D --- Notebooks/Chap20/20_3_Lottery_Tickets.ipynb | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/Notebooks/Chap20/20_3_Lottery_Tickets.ipynb b/Notebooks/Chap20/20_3_Lottery_Tickets.ipynb index 182cfa3..aafad0f 100644 --- a/Notebooks/Chap20/20_3_Lottery_Tickets.ipynb +++ b/Notebooks/Chap20/20_3_Lottery_Tickets.ipynb @@ -44,7 +44,8 @@ }, "source": [ "# Run this if you're in a Colab to install MNIST 1D repository\n", - "!pip install git+https://github.com/greydanus/mnist1d" + "!pip install git+https://github.com/greydanus/mnist1d\n", + "!git clone https://github.com/greydanus/mnist1d" ], "execution_count": null, "outputs": [] @@ -95,6 +96,12 @@ "id": "I-vm_gh5xTJs" }, "source": [ + "from mnist1d.data import get_dataset, get_dataset_args\n", + "from mnist1d.utils import set_seed, to_pickle, from_pickle\n", + "\n", + "import sys ; sys.path.append('./mnist1d/notebooks')\n", + "from train import get_model_args, train_model\n", + "\n", "args = mnist1d.get_dataset_args()\n", "data = mnist1d.get_dataset(args=args) # by default, this will download a pre-made dataset from the GitHub repo\n", "\n", @@ -210,7 +217,7 @@ " # we would return [1,1,0,0,1]\n", " # Remember that these are torch tensors and not numpy arrays\n", " # Replace this function:\n", - " mask = torch.ones_like(scores)\n", + " mask = torch.ones_like(absolute_weights)\n", "\n", "\n", " return mask" @@ -237,7 +244,6 @@ "def find_lottery_ticket(model, dataset, args, sparsity_schedule, criteria_fn=None, **kwargs):\n", "\n", " criteria_fn = lambda init_params, final_params: final_params.abs()\n", - "\n", " init_params = model.get_layer_vecs()\n", " stats = {'train_losses':[], 'test_losses':[], 'train_accs':[], 'test_accs':[]}\n", " models = []\n", @@ -253,7 +259,7 @@ " model.set_layer_masks(masks)\n", "\n", " # training process\n", - " results = mnist1d.train_model(dataset, model, args)\n", + " results = train_model(dataset, model, args)\n", " model = results['checkpoints'][-1]\n", "\n", " # store stats\n", @@ -291,7 +297,8 @@ }, "source": [ "# train settings\n", - "model_args = mnist1d.get_model_args()\n", + "from train import get_model_args, train_model\n", + "model_args = get_model_args()\n", "model_args.total_steps = 1501\n", "model_args.hidden_size = 500\n", "model_args.print_every = 5000 # print never\n",