Fixed problems with MNIST1D
This commit is contained in:
@@ -44,7 +44,8 @@
|
|||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# Run this if you're in a Colab to install MNIST 1D repository\n",
|
"# Run this if you're in a Colab to install MNIST 1D repository\n",
|
||||||
"!pip install git+https://github.com/greydanus/mnist1d"
|
"!pip install git+https://github.com/greydanus/mnist1d\n",
|
||||||
|
"!git clone https://github.com/greydanus/mnist1d"
|
||||||
],
|
],
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
@@ -95,6 +96,12 @@
|
|||||||
"id": "I-vm_gh5xTJs"
|
"id": "I-vm_gh5xTJs"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
|
"from mnist1d.data import get_dataset, get_dataset_args\n",
|
||||||
|
"from mnist1d.utils import set_seed, to_pickle, from_pickle\n",
|
||||||
|
"\n",
|
||||||
|
"import sys ; sys.path.append('./mnist1d/notebooks')\n",
|
||||||
|
"from train import get_model_args, train_model\n",
|
||||||
|
"\n",
|
||||||
"args = mnist1d.get_dataset_args()\n",
|
"args = mnist1d.get_dataset_args()\n",
|
||||||
"data = mnist1d.get_dataset(args=args) # by default, this will download a pre-made dataset from the GitHub repo\n",
|
"data = mnist1d.get_dataset(args=args) # by default, this will download a pre-made dataset from the GitHub repo\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -210,7 +217,7 @@
|
|||||||
" # we would return [1,1,0,0,1]\n",
|
" # we would return [1,1,0,0,1]\n",
|
||||||
" # Remember that these are torch tensors and not numpy arrays\n",
|
" # Remember that these are torch tensors and not numpy arrays\n",
|
||||||
" # Replace this function:\n",
|
" # Replace this function:\n",
|
||||||
" mask = torch.ones_like(scores)\n",
|
" mask = torch.ones_like(absolute_weights)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return mask"
|
" return mask"
|
||||||
@@ -237,7 +244,6 @@
|
|||||||
"def find_lottery_ticket(model, dataset, args, sparsity_schedule, criteria_fn=None, **kwargs):\n",
|
"def find_lottery_ticket(model, dataset, args, sparsity_schedule, criteria_fn=None, **kwargs):\n",
|
||||||
"\n",
|
"\n",
|
||||||
" criteria_fn = lambda init_params, final_params: final_params.abs()\n",
|
" criteria_fn = lambda init_params, final_params: final_params.abs()\n",
|
||||||
"\n",
|
|
||||||
" init_params = model.get_layer_vecs()\n",
|
" init_params = model.get_layer_vecs()\n",
|
||||||
" stats = {'train_losses':[], 'test_losses':[], 'train_accs':[], 'test_accs':[]}\n",
|
" stats = {'train_losses':[], 'test_losses':[], 'train_accs':[], 'test_accs':[]}\n",
|
||||||
" models = []\n",
|
" models = []\n",
|
||||||
@@ -253,7 +259,7 @@
|
|||||||
" model.set_layer_masks(masks)\n",
|
" model.set_layer_masks(masks)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # training process\n",
|
" # training process\n",
|
||||||
" results = mnist1d.train_model(dataset, model, args)\n",
|
" results = train_model(dataset, model, args)\n",
|
||||||
" model = results['checkpoints'][-1]\n",
|
" model = results['checkpoints'][-1]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # store stats\n",
|
" # store stats\n",
|
||||||
@@ -291,7 +297,8 @@
|
|||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# train settings\n",
|
"# train settings\n",
|
||||||
"model_args = mnist1d.get_model_args()\n",
|
"from train import get_model_args, train_model\n",
|
||||||
|
"model_args = get_model_args()\n",
|
||||||
"model_args.total_steps = 1501\n",
|
"model_args.total_steps = 1501\n",
|
||||||
"model_args.hidden_size = 500\n",
|
"model_args.hidden_size = 500\n",
|
||||||
"model_args.print_every = 5000 # print never\n",
|
"model_args.print_every = 5000 # print never\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user