Notebook 9.5

This commit is contained in:
Youcef Rahal
2024-05-12 15:27:57 -04:00
parent 2ac42e70d3
commit 0233131b07

View File

@@ -95,7 +95,7 @@
"D_k = 200 # Hidden dimensions\n", "D_k = 200 # Hidden dimensions\n",
"D_o = 10 # Output dimensions\n", "D_o = 10 # Output dimensions\n",
"\n", "\n",
"# Define a model with two hidden layers of size 100\n", "# Define a model with two hidden layers of size 200\n",
"# And ReLU activations between them\n", "# And ReLU activations between them\n",
"model = nn.Sequential(\n", "model = nn.Sequential(\n",
"nn.Linear(D_i, D_k),\n", "nn.Linear(D_i, D_k),\n",
@@ -186,7 +186,7 @@
"ax.plot(errors_test,'b-',label='test')\n", "ax.plot(errors_test,'b-',label='test')\n",
"ax.set_ylim(0,100); ax.set_xlim(0,n_epoch)\n", "ax.set_ylim(0,100); ax.set_xlim(0,n_epoch)\n",
"ax.set_xlabel('Epoch'); ax.set_ylabel('Error')\n", "ax.set_xlabel('Epoch'); ax.set_ylabel('Error')\n",
"ax.set_title('TrainError %3.2f, Test Error %3.2f'%(errors_train[-1],errors_test[-1]))\n", "ax.set_title('Train Error %3.2f, Test Error %3.2f'%(errors_train[-1],errors_test[-1]))\n",
"ax.legend()\n", "ax.legend()\n",
"plt.show()" "plt.show()"
], ],
@@ -233,7 +233,7 @@
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"n_data_orig = data['x'].shape[0]\n", "n_data_orig = data['x'].shape[0]\n",
"# We'll double the amount o fdata\n", "# We'll double the amount of data\n",
"n_data_augment = n_data_orig+4000\n", "n_data_augment = n_data_orig+4000\n",
"augmented_x = np.zeros((n_data_augment, D_i))\n", "augmented_x = np.zeros((n_data_augment, D_i))\n",
"augmented_y = np.zeros(n_data_augment)\n", "augmented_y = np.zeros(n_data_augment)\n",
@@ -343,4 +343,4 @@
} }
} }
] ]
} }