Created using Colaboratory
This commit is contained in:
@@ -4,7 +4,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
"provenance": [],
|
"provenance": [],
|
||||||
"authorship_tag": "ABX9TyN1qtywBuyezaVMnc9MI7x2",
|
"authorship_tag": "ABX9TyNJodaaCLMRWL9vTl8B/iLI",
|
||||||
"include_colab_link": true
|
"include_colab_link": true
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
@@ -188,9 +188,9 @@
|
|||||||
"scheduler = StepLR(optimizer, step_size=20, gamma=0.5)\n",
|
"scheduler = StepLR(optimizer, step_size=20, gamma=0.5)\n",
|
||||||
"# create 100 dummy data points and store in data loader class\n",
|
"# create 100 dummy data points and store in data loader class\n",
|
||||||
"x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n",
|
"x_train = torch.tensor(train_data_x.transpose().astype('float32'))\n",
|
||||||
"y_train = torch.tensor(train_data_y.astype('long'))\n",
|
"y_train = torch.tensor(train_data_y.astype('long')).long()\n",
|
||||||
"x_val= torch.tensor(val_data_x.transpose().astype('float32'))\n",
|
"x_val= torch.tensor(val_data_x.transpose().astype('float32'))\n",
|
||||||
"y_val = torch.tensor(val_data_y.astype('long'))\n",
|
"y_val = torch.tensor(val_data_y.astype('long')).long()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# load the data into a class that creates the batches\n",
|
"# load the data into a class that creates the batches\n",
|
||||||
"data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
"data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=100, shuffle=True, worker_init_fn=np.random.seed(1))\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user