Compare commits

..

34 Commits

Author SHA1 Message Date
udlbook
bd12e774a4 Add files via upload 2024-03-06 17:33:19 -05:00
udlbook
e6c3938567 Created using Colaboratory 2024-03-05 12:12:54 -05:00
udlbook
50c93469d5 Created using Colaboratory 2024-03-05 09:24:49 -05:00
udlbook
666e2de7d8 Created using Colaboratory 2024-03-04 16:28:34 -05:00
udlbook
e947b261f8 Created using Colaboratory 2024-03-04 12:26:07 -05:00
udlbook
30801a1d2b Created using Colaboratory 2024-03-04 11:45:49 -05:00
udlbook
22d5bc320f Created using Colaboratory 2024-03-04 10:06:34 -05:00
udlbook
5c0fd0057f Created using Colaboratory 2024-03-04 09:43:56 -05:00
udlbook
9b2b30d4cc Update 17_3_Importance_Sampling.ipynb 2024-02-23 12:32:39 -05:00
udlbook
46e119fcf2 Add files via upload 2024-02-17 13:45:26 -05:00
udlbook
f197be3554 Created using Colaboratory 2024-02-17 12:37:25 -05:00
udlbook
0fa468cf2c Created using Colaboratory 2024-02-17 12:35:18 -05:00
udlbook
e11989bd78 Fixed ambiguity of variable name. 2024-02-17 10:07:40 -05:00
udlbook
566120cc48 Update index.html 2024-02-15 16:52:46 -05:00
udlbook
9f2449fcde Add files via upload 2024-02-15 16:51:27 -05:00
udlbook
025b677457 Merge pull request #150 from yrahal/main
Fix minor typos in Chapter 6 notebooks
2024-02-12 13:11:23 -05:00
Youcef Rahal
435971e3e2 Fix typos in 6_5_Adam.ipynb 2024-02-09 03:55:11 -05:00
Youcef Rahal
6e76cb9b96 Fix typos in 6_4_Momentum.ipynb 2024-02-07 20:17:49 -05:00
Youcef Rahal
732fc6f0b7 Fix issues typos in 6_3_Stochastic_Gradient_Descent.ipynb 2024-02-06 20:48:25 -05:00
udlbook
f2a3fab832 Created using Colaboratory 2024-02-06 18:45:05 -05:00
Youcef Rahal
8e3008673d Fix minor typos in 6_1_Line_Search.ipynb and 6_2_Gradient_Descent.ipynb 2024-02-04 11:03:14 -05:00
udlbook
07bcc98a85 Created using Colaboratory 2024-02-01 20:19:34 +00:00
udlbook
f4fa3e8397 Created using Colaboratory 2024-02-01 20:13:01 +00:00
udlbook
21cff37c72 Update index.html 2024-01-28 18:19:44 +00:00
udlbook
187c6a7352 Add files via upload 2024-01-28 10:01:17 +00:00
udlbook
8e4a0d4daf Add files via upload 2024-01-26 14:37:08 +00:00
udlbook
23b5affab3 Update 5_1_Least_Squares_Loss.ipynb 2024-01-25 16:01:23 +00:00
udlbook
4fb8ffe622 Merge pull request #144 from yrahal/main
Fix some typos in Notebooks/Chap05/5_1_Least_Squares_Loss.ipynb
2024-01-25 15:59:23 +00:00
Youcef Rahal
2adc1da566 Fix some typpos in Notebooks/Chap05/5_1_Least_Squares_Loss.ipynb 2024-01-25 10:16:46 -05:00
Youcef Rahal
6e4551a69f Fix some typpos in Notebooks/Chap05/5_1_Least_Squares_Loss.ipynb 2024-01-25 10:14:01 -05:00
udlbook
65c685706a Update 9_2_Implicit_Regularization.ipynb 2024-01-25 09:46:01 +00:00
udlbook
934f5f7748 Created using Colaboratory 2024-01-24 10:56:22 -05:00
udlbook
365cb41bba Update index.html 2024-01-23 10:54:43 +00:00
udlbook
4855761fb2 Update index.html 2024-01-19 15:04:48 -05:00
22 changed files with 2377 additions and 520 deletions

View File

@@ -0,0 +1,401 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyP9fLqBQPgcYJB1KXs3Scp/",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Blogs/BorealisGradientFlow.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Gradient flow\n",
"\n",
"This notebook replicates some of the results in the the Borealis AI [blog](https://www.borealisai.com/research-blogs/gradient-flow/) on gradient flow. \n"
],
"metadata": {
"id": "ucrRRJ4dq8_d"
}
},
{
"cell_type": "code",
"source": [
"# Import relevant libraries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.linalg import expm\n",
"from matplotlib import cm\n",
"from matplotlib.colors import ListedColormap"
],
"metadata": {
"id": "_IQFHZEMZE8T"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Create the three data points that are used to train the linear model in the blog. Each input point is a column in $\\mathbf{X}$ and consists of the $x$ position in the plot and the value 1, which is used to allow the model to fit bias terms neatly."
],
"metadata": {
"id": "NwgUP3MSriiJ"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cJNZ2VIcYsD8"
},
"outputs": [],
"source": [
"X = np.array([[0.2, 0.4, 0.8],[1,1,1]])\n",
"y = np.array([[-0.1],[0.15],[0.3]])\n",
"D = X.shape[0]\n",
"I = X.shape[1]\n",
"\n",
"print(\"X=\\n\",X)\n",
"print(\"y=\\n\",y)"
]
},
{
"cell_type": "code",
"source": [
"# Draw the three data points\n",
"fig, ax = plt.subplots()\n",
"ax.plot(X[0:1,:],y.T,'ro')\n",
"ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n",
"ax.set_xlabel('x'); ax.set_ylabel('y')\n",
"plt.show()"
],
"metadata": {
"id": "FpFlD4nUZDRt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Compute the evolution of the residuals, loss, and parameters as a function of time."
],
"metadata": {
"id": "H2LBR1DasQej"
}
},
{
"cell_type": "code",
"source": [
"# Discretized time to evaluate quantities at\n",
"t_all = np.arange(0,20,0.01)\n",
"nT = t_all.shape[0]\n",
"\n",
"# Initial parameters, and initial function output at training points\n",
"phi_0 = np.array([[-0.05],[-0.4]])\n",
"f_0 = X.T @ phi_0\n",
"\n",
"# Precompute pseudoinverse term (not a very sensible numerical implementation, but it works...)\n",
"XXTInvX = np.linalg.inv(X@X.T)@X\n",
"\n",
"# Create arrays to hold function at data points over time, residual over time, parameters over time\n",
"f_all = np.zeros((I,nT))\n",
"f_minus_y_all = np.zeros((I,nT))\n",
"phi_t_all = np.zeros((D,nT))\n",
"\n",
"# For each time, compute function, residual, and parameters at each time.\n",
"for t in range(len(t_all)):\n",
" f = y + expm(-X.T@X * t_all[t]) @ (f_0-y)\n",
" f_all[:,t:t+1] = f\n",
" f_minus_y_all[:,t:t+1] = f-y\n",
" phi_t_all[:,t:t+1] = phi_0 - XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ (f_0-y)"
],
"metadata": {
"id": "wfF_oTS5Z4Wi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Plot the results that were calculated in the previous cell"
],
"metadata": {
"id": "9jSjOOFutJUE"
}
},
{
"cell_type": "code",
"source": [
"# Plot function at data points\n",
"fig, ax = plt.subplots()\n",
"ax.plot(t_all,np.squeeze(f_all[0,:]),'r-', label='$f[x_{0},\\phi]$')\n",
"ax.plot(t_all,np.squeeze(f_all[1,:]),'g-', label='$f[x_{1},\\phi]$')\n",
"ax.plot(t_all,np.squeeze(f_all[2,:]),'b-', label='$f[x_{2},\\phi]$')\n",
"ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])\n",
"ax.set_xlabel('t'); ax.set_ylabel('f')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()\n",
"\n",
"# Plot residual\n",
"fig, ax = plt.subplots()\n",
"ax.plot(t_all,np.squeeze(f_minus_y_all[0,:]),'r-', label='$f[x_{0},\\phi]-y_{0}$')\n",
"ax.plot(t_all,np.squeeze(f_minus_y_all[1,:]),'g-', label='$f[x_{1},\\phi]-y_{1}$')\n",
"ax.plot(t_all,np.squeeze(f_minus_y_all[2,:]),'b-', label='$f[x_{2},\\phi]-y_{2}$')\n",
"ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])\n",
"ax.set_xlabel('t'); ax.set_ylabel('f-y')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()\n",
"\n",
"# Plot loss (sum of residuals)\n",
"fig, ax = plt.subplots()\n",
"square_error = 0.5 * np.sum(f_minus_y_all * f_minus_y_all, axis=0)\n",
"ax.plot(t_all, square_error,'k-')\n",
"ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.0,0.25])\n",
"ax.set_xlabel('t'); ax.set_ylabel('Loss')\n",
"plt.show()\n",
"\n",
"# Plot parameters\n",
"fig, ax = plt.subplots()\n",
"ax.plot(t_all, np.squeeze(phi_t_all[0,:]),'c-',label='$\\phi_{0}$')\n",
"ax.plot(t_all, np.squeeze(phi_t_all[1,:]),'m-',label='$\\phi_{1}$')\n",
"ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-1,1])\n",
"ax.set_xlabel('t'); ax.set_ylabel('$\\phi$')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()"
],
"metadata": {
"id": "G9IwgwKltHz5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Define the model and the loss function"
],
"metadata": {
"id": "N6VaUq2swa8D"
}
},
{
"cell_type": "code",
"source": [
"# Model is just a straight line with intercept phi[0] and slope phi[1]\n",
"def model(phi,x):\n",
" y_pred = phi[0]+phi[1] * x\n",
" return y_pred\n",
"\n",
"# Loss function is 0.5 times sum of squares of residuals for training data\n",
"def compute_loss(data_x, data_y, model, phi):\n",
" pred_y = model(phi, data_x)\n",
" loss = 0.5 * np.sum((pred_y-data_y)*(pred_y-data_y))\n",
" return loss"
],
"metadata": {
"id": "LGHEVUWWiB4f"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Draw the loss function"
],
"metadata": {
"id": "hr3hs7pKwo0g"
}
},
{
"cell_type": "code",
"source": [
"def draw_loss_function(compute_loss, X, y, model, phi_iters):\n",
" # Define pretty colormap\n",
" my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
" my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n",
" r = np.floor(my_colormap_vals_dec/(256*256))\n",
" g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n",
" b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
" my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
"\n",
" # Make grid of intercept/slope values to plot\n",
" intercepts_mesh, slopes_mesh = np.meshgrid(np.arange(-1.0,1.0,0.005), np.arange(-1.0,1.0,0.005))\n",
" loss_mesh = np.zeros_like(slopes_mesh)\n",
" # Compute loss for every set of parameters\n",
" for idslope, slope in np.ndenumerate(slopes_mesh):\n",
" loss_mesh[idslope] = compute_loss(X, y, model, np.array([[intercepts_mesh[idslope]], [slope]]))\n",
"\n",
" fig,ax = plt.subplots()\n",
" fig.set_size_inches(8,8)\n",
" ax.contourf(intercepts_mesh,slopes_mesh,loss_mesh,256,cmap=my_colormap)\n",
" ax.contour(intercepts_mesh,slopes_mesh,loss_mesh,40,colors=['#80808080'])\n",
" ax.set_ylim([1,-1]); ax.set_xlim([-1,1])\n",
"\n",
" ax.plot(phi_iters[1,:], phi_iters[0,:],'g-')\n",
" ax.set_xlabel('Intercept'); ax.set_ylabel('Slope')\n",
" plt.show()"
],
"metadata": {
"id": "UCxa3tZ8a9kz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"draw_loss_function(compute_loss, X[0:1,:], y.T, model, phi_t_all)"
],
"metadata": {
"id": "pXLLBaSaiI2A"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Draw the evolution of the function"
],
"metadata": {
"id": "ZsremHW-xFi5"
}
},
{
"cell_type": "code",
"source": [
"fig, ax = plt.subplots()\n",
"ax.plot(X[0:1,:],y.T,'ro')\n",
"x_vals = np.arange(0,1,0.001)\n",
"ax.plot(x_vals, phi_t_all[0,0]*x_vals + phi_t_all[1,0],'r-', label='t=0.00')\n",
"ax.plot(x_vals, phi_t_all[0,10]*x_vals + phi_t_all[1,10],'g-', label='t=0.10')\n",
"ax.plot(x_vals, phi_t_all[0,30]*x_vals + phi_t_all[1,30],'b-', label='t=0.30')\n",
"ax.plot(x_vals, phi_t_all[0,200]*x_vals + phi_t_all[1,200],'c-', label='t=2.00')\n",
"ax.plot(x_vals, phi_t_all[0,1999]*x_vals + phi_t_all[1,1999],'y-', label='t=20.0')\n",
"ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n",
"ax.set_xlabel('x'); ax.set_ylabel('y')\n",
"plt.legend(loc=\"upper left\")\n",
"plt.show()"
],
"metadata": {
"id": "cv9ZrUoRkuhI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compute MAP and ML solutions\n",
"MLParams = np.linalg.inv(X@X.T)@X@y\n",
"sigma_sq_p = 3.0\n",
"sigma_sq = 0.05\n",
"MAPParams = np.linalg.inv(X@X.T+np.identity(X.shape[0])*sigma_sq/sigma_sq_p)@X@y"
],
"metadata": {
"id": "OU9oegSOof-o"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Finally, we predict both the mean and the uncertainty in the fitted model as a function of time"
],
"metadata": {
"id": "Ul__XvOgyYSA"
}
},
{
"cell_type": "code",
"source": [
"# Define x positions to make predictions (appending a 1 to each column)\n",
"x_predict = np.arange(0,1,0.01)[None,:]\n",
"x_predict = np.concatenate((x_predict,np.ones_like(x_predict)))\n",
"nX = x_predict.shape[1]\n",
"\n",
"# Create variables to store evolution of mean and variance of prediction over time\n",
"predict_mean_all = np.zeros((nT,nX))\n",
"predict_var_all = np.zeros((nT,nX))\n",
"\n",
"# Initial covariance\n",
"sigma_sq_p = 2.0\n",
"cov_init = sigma_sq_p * np.identity(2)\n",
"\n",
"# Run through each time computing a and b and hence mean and variance of prediction\n",
"for t in range(len(t_all)):\n",
" a = x_predict.T @(XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ y)\n",
" b = x_predict.T -x_predict.T@XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ X.T\n",
" predict_mean_all[t:t+1,:] = a.T\n",
" predict_cov = b@ cov_init @b.T\n",
" # We just want the diagonal of the covariance to plot the uncertainty\n",
" predict_var_all[t:t+1,:] = np.reshape(np.diag(predict_cov),(1,nX))"
],
"metadata": {
"id": "aMPADCuByKWr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Plot the mean and variance at various times"
],
"metadata": {
"id": "PZTj93KK7QH6"
}
},
{
"cell_type": "code",
"source": [
"def plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t, sigma_sq = 0.00001):\n",
" fig, ax = plt.subplots()\n",
" ax.plot(X[0:1,:],y.T,'ro')\n",
" ax.plot(x_predict[0:1,:].T, predict_mean_all[this_t:this_t+1,:].T,'r-')\n",
" lower = np.squeeze(predict_mean_all[this_t:this_t+1,:].T-np.sqrt(predict_var_all[this_t:this_t+1,:].T+np.sqrt(sigma_sq)))\n",
" upper = np.squeeze(predict_mean_all[this_t:this_t+1,:].T+np.sqrt(predict_var_all[this_t:this_t+1,:].T+np.sqrt(sigma_sq)))\n",
" ax.fill_between(np.squeeze(x_predict[0:1,:]), lower, upper, color='lightgray')\n",
" ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n",
" ax.set_xlabel('x'); ax.set_ylabel('y')\n",
" plt.show()\n",
"\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=0)\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=40)\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=80)\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=200)\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=500)\n",
"plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=1000)"
],
"metadata": {
"id": "bYAFxgB880-v"
},
"execution_count": null,
"outputs": []
}
]
}

1109
Blogs/BorealisNTK.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -185,7 +185,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Return probability under normal distribution for input x\n", "# Return probability under normal distribution\n",
"def normal_distribution(y, mu, sigma):\n", "def normal_distribution(y, mu, sigma):\n",
" # TODO-- write in the equation for the normal distribution\n", " # TODO-- write in the equation for the normal distribution\n",
" # Equation 5.7 from the notes (you will need np.sqrt() and np.exp(), and math.pi)\n", " # Equation 5.7 from the notes (you will need np.sqrt() and np.exp(), and math.pi)\n",
@@ -329,7 +329,7 @@
"mu_pred = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", "mu_pred = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
"# Set the standard deviation to something reasonable\n", "# Set the standard deviation to something reasonable\n",
"sigma = 0.2\n", "sigma = 0.2\n",
"# Compute the log likelihood\n", "# Compute the negative log likelihood\n",
"nll = compute_negative_log_likelihood(y_train, mu_pred, sigma)\n", "nll = compute_negative_log_likelihood(y_train, mu_pred, sigma)\n",
"# Let's double check we get the right answer before proceeding\n", "# Let's double check we get the right answer before proceeding\n",
"print(\"Correct answer = %9.9f, Your answer = %9.9f\"%(11.452419564,nll))" "print(\"Correct answer = %9.9f, Your answer = %9.9f\"%(11.452419564,nll))"
@@ -388,7 +388,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"Now let's investigate finding the maximum likelihood / minimum log likelihood / least squares solution. For simplicity, we'll assume that all the parameters are correct except one and look at how the likelihood, log likelihood, and sum of squares change as we manipulate the last parameter. We'll start with overall y offset, beta_1 (formerly phi_0)" "Now let's investigate finding the maximum likelihood / minimum negative log likelihood / least squares solution. For simplicity, we'll assume that all the parameters are correct except one and look at how the likelihood, negative log likelihood, and sum of squares change as we manipulate the last parameter. We'll start with overall y offset, beta_1 (formerly phi_0)"
], ],
"metadata": { "metadata": {
"id": "OgcRojvPWh4V" "id": "OgcRojvPWh4V"
@@ -431,7 +431,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Now let's plot the likelihood, negative log likelihood, and least squares as a function the value of the offset beta1\n", "# Now let's plot the likelihood, negative log likelihood, and least squares as a function of the value of the offset beta1\n",
"fig, ax = plt.subplots(1,2)\n", "fig, ax = plt.subplots(1,2)\n",
"fig.set_size_inches(10.5, 5.5)\n", "fig.set_size_inches(10.5, 5.5)\n",
"fig.tight_layout(pad=10.0)\n", "fig.tight_layout(pad=10.0)\n",
@@ -530,7 +530,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Now let's plot the likelihood, negative log likelihood, and least squares as a function the value of the standard divation sigma\n", "# Now let's plot the likelihood, negative log likelihood, and least squares as a function of the value of the standard deviation sigma\n",
"fig, ax = plt.subplots(1,2)\n", "fig, ax = plt.subplots(1,2)\n",
"fig.set_size_inches(10.5, 5.5)\n", "fig.set_size_inches(10.5, 5.5)\n",
"fig.tight_layout(pad=10.0)\n", "fig.tight_layout(pad=10.0)\n",
@@ -581,7 +581,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"Obviously, to fit the full neural model we would vary all of the 10 parameters of the network in $\\boldsymbol\\beta_{0},\\boldsymbol\\omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\omega_{1}$ (and maybe $\\sigma$) until we find the combination that have the maximum likelihood / minimum negative log likelihood / least squares.<br><br>\n", "Obviously, to fit the full neural model we would vary all of the 10 parameters of the network in $\\boldsymbol\\beta_{0},\\boldsymbol\\Omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\Omega_{1}$ (and maybe $\\sigma$) until we find the combination that have the maximum likelihood / minimum negative log likelihood / least squares.<br><br>\n",
"\n", "\n",
"Here we just varied one at a time as it is easier to see what is going on. This is known as **coordinate descent**.\n" "Here we just varied one at a time as it is easier to see what is going on. This is known as **coordinate descent**.\n"
], ],

View File

@@ -4,7 +4,6 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyOSb+W2AOFVQm8FZcHAb2Jq",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -199,7 +198,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"The left is model output and the right is the model output after the sigmoid has been applied, so it now lies in the range [0,1] and represents the probability, that y=1. The black dots show the training data. We'll compute the the likelihood and the negative log likelihood." "The left is model output and the right is the model output after the sigmoid has been applied, so it now lies in the range [0,1] and represents the probability, that y=1. The black dots show the training data. We'll compute the likelihood and the negative log likelihood."
], ],
"metadata": { "metadata": {
"id": "MvVX6tl9AEXF" "id": "MvVX6tl9AEXF"
@@ -208,7 +207,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Return probability under Bernoulli distribution for input x\n", "# Return probability under Bernoulli distribution for observed class y\n",
"def bernoulli_distribution(y, lambda_param):\n", "def bernoulli_distribution(y, lambda_param):\n",
" # TODO-- write in the equation for the Bernoulli distribution\n", " # TODO-- write in the equation for the Bernoulli distribution\n",
" # Equation 5.17 from the notes (you will need np.power)\n", " # Equation 5.17 from the notes (you will need np.power)\n",
@@ -269,7 +268,7 @@
"source": [ "source": [
"# Let's test this\n", "# Let's test this\n",
"beta_0, omega_0, beta_1, omega_1 = get_parameters()\n", "beta_0, omega_0, beta_1, omega_1 = get_parameters()\n",
"# Use our neural network to predict the mean of the Gaussian\n", "# Use our neural network to predict the Bernoulli parameter lambda\n",
"model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", "model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
"lambda_train = sigmoid(model_out)\n", "lambda_train = sigmoid(model_out)\n",
"# Compute the likelihood\n", "# Compute the likelihood\n",
@@ -336,7 +335,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"Now let's investigate finding the maximum likelihood / minimum negative log likelihood solution. For simplicity, we'll assume that all the parameters are fixed except one and look at how the likelihood and log likelihood change as we manipulate the last parameter. We'll start with overall y_offset, beta_1 (formerly phi_0)" "Now let's investigate finding the maximum likelihood / minimum negative log likelihood solution. For simplicity, we'll assume that all the parameters are fixed except one and look at how the likelihood and negative log likelihood change as we manipulate the last parameter. We'll start with overall y_offset, beta_1 (formerly phi_0)"
], ],
"metadata": { "metadata": {
"id": "OgcRojvPWh4V" "id": "OgcRojvPWh4V"
@@ -359,7 +358,7 @@
" # Run the network with new parameters\n", " # Run the network with new parameters\n",
" model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", " model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
" lambda_train = sigmoid(model_out)\n", " lambda_train = sigmoid(model_out)\n",
" # Compute and store the three values\n", " # Compute and store the two values\n",
" likelihoods[count] = compute_likelihood(y_train,lambda_train)\n", " likelihoods[count] = compute_likelihood(y_train,lambda_train)\n",
" nlls[count] = compute_negative_log_likelihood(y_train, lambda_train)\n", " nlls[count] = compute_negative_log_likelihood(y_train, lambda_train)\n",
" # Draw the model for every 20th parameter setting\n", " # Draw the model for every 20th parameter setting\n",
@@ -378,7 +377,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Now let's plot the likelihood, negative log likelihood, and least squares as a function the value of the offset beta1\n", "# Now let's plot the likelihood and negative log likelihood as a function of the value of the offset beta1\n",
"fig, ax = plt.subplots()\n", "fig, ax = plt.subplots()\n",
"fig.tight_layout(pad=5.0)\n", "fig.tight_layout(pad=5.0)\n",
"likelihood_color = 'tab:red'\n", "likelihood_color = 'tab:red'\n",
@@ -430,7 +429,7 @@
"source": [ "source": [
"They both give the same answer. But you can see from the likelihood above that the likelihood is very small unless the parameters are almost correct. So in practice, we would work with the negative log likelihood.<br><br>\n", "They both give the same answer. But you can see from the likelihood above that the likelihood is very small unless the parameters are almost correct. So in practice, we would work with the negative log likelihood.<br><br>\n",
"\n", "\n",
"Again, to fit the full neural model we would vary all of the 10 parameters of the network in the $\\boldsymbol\\beta_{0},\\boldsymbol\\omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\omega_{1}$ until we find the combination that have the maximum likelihood / minimum negative log likelihood.<br><br>\n", "Again, to fit the full neural model we would vary all of the 10 parameters of the network in the $\\boldsymbol\\beta_{0},\\boldsymbol\\Omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\Omega_{1}$ until we find the combination that have the maximum likelihood / minimum negative log likelihood.<br><br>\n",
"\n" "\n"
], ],
"metadata": { "metadata": {

View File

@@ -1,18 +1,16 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "id": "view-in-github",
"id": "view-in-github" "colab_type": "text"
}, },
"source": [ "source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap05/5_3_Multiclass_Cross_entropy_Loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" "<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap05/5_3_Multiclass_Cross_entropy_Loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "jSlFkICHwHQF" "id": "jSlFkICHwHQF"
@@ -142,7 +140,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "PsgLZwsPxauP" "id": "PsgLZwsPxauP"
@@ -209,13 +206,12 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "MvVX6tl9AEXF" "id": "MvVX6tl9AEXF"
}, },
"source": [ "source": [
"The left is model output and the right is the model output after the softmax has been applied, so it now lies in the range [0,1] and represents the probability, that y=0 (red), 1 (green) and 2 (blue) The dots at the bottom show the training data with the same color scheme. So we want the red curve to be high where there are red dots, the green curve to be high where there are green dots, and the blue curve to be high where there are blue dots We'll compute the the likelihood and the negative log likelihood." "The left is model output and the right is the model output after the softmax has been applied, so it now lies in the range [0,1] and represents the probability, that y=0 (red), 1 (green) and 2 (blue). The dots at the bottom show the training data with the same color scheme. So we want the red curve to be high where there are red dots, the green curve to be high where there are green dots, and the blue curve to be high where there are blue dots We'll compute the the likelihood and the negative log likelihood."
] ]
}, },
{ {
@@ -226,7 +222,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Return probability under Categorical distribution for input x\n", "# Return probability under categorical distribution for observed class y\n",
"# Just take value from row k of lambda param where y =k,\n", "# Just take value from row k of lambda param where y =k,\n",
"def categorical_distribution(y, lambda_param):\n", "def categorical_distribution(y, lambda_param):\n",
" return np.array([lambda_param[row, i] for i, row in enumerate (y)])" " return np.array([lambda_param[row, i] for i, row in enumerate (y)])"
@@ -248,7 +244,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "R5z_0dzQMF35" "id": "R5z_0dzQMF35"
@@ -286,7 +281,7 @@
"source": [ "source": [
"# Let's test this\n", "# Let's test this\n",
"beta_0, omega_0, beta_1, omega_1 = get_parameters()\n", "beta_0, omega_0, beta_1, omega_1 = get_parameters()\n",
"# Use our neural network to predict the mean of the Gaussian\n", "# Use our neural network to predict the parameters of the categorical distribution\n",
"model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", "model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
"lambda_train = softmax(model_out)\n", "lambda_train = softmax(model_out)\n",
"# Compute the likelihood\n", "# Compute the likelihood\n",
@@ -296,7 +291,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "HzphKgPfOvlk" "id": "HzphKgPfOvlk"
@@ -318,7 +312,7 @@
"source": [ "source": [
"# Return the negative log likelihood of the data under the model\n", "# Return the negative log likelihood of the data under the model\n",
"def compute_negative_log_likelihood(y_train, lambda_param):\n", "def compute_negative_log_likelihood(y_train, lambda_param):\n",
" # TODO -- compute the likelihood of the data -- don't use the likelihood function above -- compute the negative sum of the log probabilities\n", " # TODO -- compute the negative log likelihood of the data -- don't use the likelihood function above -- compute the negative sum of the log probabilities\n",
" # You will need np.sum(), np.log()\n", " # You will need np.sum(), np.log()\n",
" # Replace the line below\n", " # Replace the line below\n",
" nll = 0\n", " nll = 0\n",
@@ -336,24 +330,23 @@
"source": [ "source": [
"# Let's test this\n", "# Let's test this\n",
"beta_0, omega_0, beta_1, omega_1 = get_parameters()\n", "beta_0, omega_0, beta_1, omega_1 = get_parameters()\n",
"# Use our neural network to predict the mean of the Gaussian\n", "# Use our neural network to predict the parameters of the categorical distribution\n",
"model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", "model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
"# Pass the outputs through the softmax function\n", "# Pass the outputs through the softmax function\n",
"lambda_train = softmax(model_out)\n", "lambda_train = softmax(model_out)\n",
"# Compute the log likelihood\n", "# Compute the negative log likelihood\n",
"nll = compute_negative_log_likelihood(y_train, lambda_train)\n", "nll = compute_negative_log_likelihood(y_train, lambda_train)\n",
"# Let's double check we get the right answer before proceeding\n", "# Let's double check we get the right answer before proceeding\n",
"print(\"Correct answer = %9.9f, Your answer = %9.9f\"%(17.015457867,nll))" "print(\"Correct answer = %9.9f, Your answer = %9.9f\"%(17.015457867,nll))"
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "OgcRojvPWh4V" "id": "OgcRojvPWh4V"
}, },
"source": [ "source": [
"Now let's investigate finding the maximum likelihood / minimum log likelihood solution. For simplicity, we'll assume that all the parameters are fixed except one and look at how the likelihood and log likelihood change as we manipulate the last parameter. We'll start with overall y_offset, $\\beta_1$ (formerly $\\phi_0$)" "Now let's investigate finding the maximum likelihood / minimum negative log likelihood solution. For simplicity, we'll assume that all the parameters are fixed except one and look at how the likelihood and negative log likelihood change as we manipulate the last parameter. We'll start with overall y_offset, $\\beta_1$ (formerly $\\phi_0$)"
] ]
}, },
{ {
@@ -378,7 +371,7 @@
" # Run the network with new parameters\n", " # Run the network with new parameters\n",
" model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n", " model_out = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
" lambda_train = softmax(model_out)\n", " lambda_train = softmax(model_out)\n",
" # Compute and store the three values\n", " # Compute and store the two values\n",
" likelihoods[count] = compute_likelihood(y_train,lambda_train)\n", " likelihoods[count] = compute_likelihood(y_train,lambda_train)\n",
" nlls[count] = compute_negative_log_likelihood(y_train, lambda_train)\n", " nlls[count] = compute_negative_log_likelihood(y_train, lambda_train)\n",
" # Draw the model for every 20th parameter setting\n", " # Draw the model for every 20th parameter setting\n",
@@ -397,7 +390,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Now let's plot the likelihood, negative log likelihood, and least squares as a function the value of the offset beta1\n", "# Now let's plot the likelihood and negative log likelihood as a function of the value of the offset beta1\n",
"fig, ax = plt.subplots()\n", "fig, ax = plt.subplots()\n",
"fig.tight_layout(pad=5.0)\n", "fig.tight_layout(pad=5.0)\n",
"likelihood_color = 'tab:red'\n", "likelihood_color = 'tab:red'\n",
@@ -440,7 +433,6 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "771G8N1Vk5A2" "id": "771G8N1Vk5A2"
@@ -448,16 +440,15 @@
"source": [ "source": [
"They both give the same answer. But you can see from the likelihood above that the likelihood is very small unless the parameters are almost correct. So in practice, we would work with the negative log likelihood.<br><br>\n", "They both give the same answer. But you can see from the likelihood above that the likelihood is very small unless the parameters are almost correct. So in practice, we would work with the negative log likelihood.<br><br>\n",
"\n", "\n",
"Again, to fit the full neural model we would vary all of the 16 parameters of the network in the $\\boldsymbol\\beta_{0},\\boldsymbol\\omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\omega_{1}$ until we find the combination that have the maximum likelihood / minimum negative log likelihood.<br><br>\n", "Again, to fit the full neural model we would vary all of the 16 parameters of the network in the $\\boldsymbol\\beta_{0},\\boldsymbol\\Omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\Omega_{1}$ until we find the combination that have the maximum likelihood / minimum negative log likelihood.<br><br>\n",
"\n" "\n"
] ]
} }
], ],
"metadata": { "metadata": {
"colab": { "colab": {
"authorship_tag": "ABX9TyOPv/l+ToaApJV7Nz+8AtpV", "provenance": [],
"include_colab_link": true, "include_colab_link": true
"provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",

View File

@@ -113,7 +113,7 @@
" b = 0.33\n", " b = 0.33\n",
" c = 0.66\n", " c = 0.66\n",
" d = 1.0\n", " d = 1.0\n",
" n_iter =0;\n", " n_iter = 0\n",
"\n", "\n",
" # While we haven't found the minimum closely enough\n", " # While we haven't found the minimum closely enough\n",
" while np.abs(b-c) > thresh and n_iter < max_iter:\n", " while np.abs(b-c) > thresh and n_iter < max_iter:\n",
@@ -131,8 +131,7 @@
"\n", "\n",
" print('Iter %d, a=%3.3f, b=%3.3f, c=%3.3f, d=%3.3f'%(n_iter, a,b,c,d))\n", " print('Iter %d, a=%3.3f, b=%3.3f, c=%3.3f, d=%3.3f'%(n_iter, a,b,c,d))\n",
"\n", "\n",
" # Rule #1 If the HEIGHT at point A is less the HEIGHT at points B, C, and D then halve values of B, C, and D\n", " # Rule #1 If the HEIGHT at point A is less than the HEIGHT at points B, C, and D then halve values of B, C, and D\n",
" # i.e. bring them closer to the original point\n",
" # i.e. bring them closer to the original point\n", " # i.e. bring them closer to the original point\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", " # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",
" if (0):\n", " if (0):\n",
@@ -140,7 +139,7 @@
"\n", "\n",
"\n", "\n",
" # Rule #2 If the HEIGHT at point b is less than the HEIGHT at point c then\n", " # Rule #2 If the HEIGHT at point b is less than the HEIGHT at point c then\n",
" # then point d becomes point c, and\n", " # point d becomes point c, and\n",
" # point b becomes 1/3 between a and new d\n", " # point b becomes 1/3 between a and new d\n",
" # point c becomes 2/3 between a and new d\n", " # point c becomes 2/3 between a and new d\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", " # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",
@@ -148,7 +147,7 @@
" continue;\n", " continue;\n",
"\n", "\n",
" # Rule #3 If the HEIGHT at point c is less than the HEIGHT at point b then\n", " # Rule #3 If the HEIGHT at point c is less than the HEIGHT at point b then\n",
" # then point a becomes point b, and\n", " # point a becomes point b, and\n",
" # point b becomes 1/3 between new a and d\n", " # point b becomes 1/3 between new a and d\n",
" # point c becomes 2/3 between new a and d\n", " # point c becomes 2/3 between new a and d\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", " # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",

View File

@@ -117,7 +117,7 @@
"id": "QU5mdGvpTtEG" "id": "QU5mdGvpTtEG"
}, },
"source": [ "source": [
"Now lets create compute the sum of squares loss for the training data" "Now let's compute the sum of squares loss for the training data"
] ]
}, },
{ {
@@ -317,7 +317,7 @@
" b = 0.33 * max_dist\n", " b = 0.33 * max_dist\n",
" c = 0.66 * max_dist\n", " c = 0.66 * max_dist\n",
" d = 1.0 * max_dist\n", " d = 1.0 * max_dist\n",
" n_iter =0;\n", " n_iter = 0\n",
"\n", "\n",
" # While we haven't found the minimum closely enough\n", " # While we haven't found the minimum closely enough\n",
" while np.abs(b-c) > thresh and n_iter < max_iter:\n", " while np.abs(b-c) > thresh and n_iter < max_iter:\n",
@@ -341,7 +341,7 @@
" continue;\n", " continue;\n",
"\n", "\n",
" # Rule #2 If point b is less than point c then\n", " # Rule #2 If point b is less than point c then\n",
" # then point d becomes point c, and\n", " # point d becomes point c, and\n",
" # point b becomes 1/3 between a and new d\n", " # point b becomes 1/3 between a and new d\n",
" # point c becomes 2/3 between a and new d\n", " # point c becomes 2/3 between a and new d\n",
" if lossb < lossc:\n", " if lossb < lossc:\n",
@@ -351,7 +351,7 @@
" continue\n", " continue\n",
"\n", "\n",
" # Rule #2 If point c is less than point b then\n", " # Rule #2 If point c is less than point b then\n",
" # then point a becomes point b, and\n", " # point a becomes point b, and\n",
" # point b becomes 1/3 between new a and d\n", " # point b becomes 1/3 between new a and d\n",
" # point c becomes 2/3 between new a and d\n", " # point c becomes 2/3 between new a and d\n",
" a = b\n", " a = b\n",

View File

@@ -53,7 +53,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Let's create our training data 30 pairs {x_i, y_i}\n", "# Let's create our training data of 30 pairs {x_i, y_i}\n",
"# We'll try to fit the Gabor model to these data\n", "# We'll try to fit the Gabor model to these data\n",
"data = np.array([[-1.920e+00,-1.422e+01,1.490e+00,-1.940e+00,-2.389e+00,-5.090e+00,\n", "data = np.array([[-1.920e+00,-1.422e+01,1.490e+00,-1.940e+00,-2.389e+00,-5.090e+00,\n",
" -8.861e+00,3.578e+00,-6.010e+00,-6.995e+00,3.634e+00,8.743e-01,\n", " -8.861e+00,3.578e+00,-6.010e+00,-6.995e+00,3.634e+00,8.743e-01,\n",
@@ -128,7 +128,7 @@
"id": "QU5mdGvpTtEG" "id": "QU5mdGvpTtEG"
}, },
"source": [ "source": [
"Now lets create compute the sum of squares loss for the training data" "Now let's compute the sum of squares loss for the training data"
] ]
}, },
{ {
@@ -198,7 +198,7 @@
" b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", " b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
" my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n", " my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
"\n", "\n",
" # Make grid of intercept/slope values to plot\n", " # Make grid of offset/frequency values to plot\n",
" offsets_mesh, freqs_mesh = np.meshgrid(np.arange(-10,10.0,0.1), np.arange(2.5,22.5,0.1))\n", " offsets_mesh, freqs_mesh = np.meshgrid(np.arange(-10,10.0,0.1), np.arange(2.5,22.5,0.1))\n",
" loss_mesh = np.zeros_like(freqs_mesh)\n", " loss_mesh = np.zeros_like(freqs_mesh)\n",
" # Compute loss for every set of parameters\n", " # Compute loss for every set of parameters\n",
@@ -343,7 +343,7 @@
" b = 0.33 * max_dist\n", " b = 0.33 * max_dist\n",
" c = 0.66 * max_dist\n", " c = 0.66 * max_dist\n",
" d = 1.0 * max_dist\n", " d = 1.0 * max_dist\n",
" n_iter =0;\n", " n_iter = 0\n",
"\n", "\n",
" # While we haven't found the minimum closely enough\n", " # While we haven't found the minimum closely enough\n",
" while np.abs(b-c) > thresh and n_iter < max_iter:\n", " while np.abs(b-c) > thresh and n_iter < max_iter:\n",
@@ -367,7 +367,7 @@
" continue;\n", " continue;\n",
"\n", "\n",
" # Rule #2 If point b is less than point c then\n", " # Rule #2 If point b is less than point c then\n",
" # then point d becomes point c, and\n", " # point d becomes point c, and\n",
" # point b becomes 1/3 between a and new d\n", " # point b becomes 1/3 between a and new d\n",
" # point c becomes 2/3 between a and new d\n", " # point c becomes 2/3 between a and new d\n",
" if lossb < lossc:\n", " if lossb < lossc:\n",
@@ -377,7 +377,7 @@
" continue\n", " continue\n",
"\n", "\n",
" # Rule #2 If point c is less than point b then\n", " # Rule #2 If point c is less than point b then\n",
" # then point a becomes point b, and\n", " # point a becomes point b, and\n",
" # point b becomes 1/3 between new a and d\n", " # point b becomes 1/3 between new a and d\n",
" # point c becomes 2/3 between new a and d\n", " # point c becomes 2/3 between new a and d\n",
" a = b\n", " a = b\n",

View File

@@ -61,7 +61,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Let's create our training data 30 pairs {x_i, y_i}\n", "# Let's create our training data of 30 pairs {x_i, y_i}\n",
"# We'll try to fit the Gabor model to these data\n", "# We'll try to fit the Gabor model to these data\n",
"data = np.array([[-1.920e+00,-1.422e+01,1.490e+00,-1.940e+00,-2.389e+00,-5.090e+00,\n", "data = np.array([[-1.920e+00,-1.422e+01,1.490e+00,-1.940e+00,-2.389e+00,-5.090e+00,\n",
" -8.861e+00,3.578e+00,-6.010e+00,-6.995e+00,3.634e+00,8.743e-01,\n", " -8.861e+00,3.578e+00,-6.010e+00,-6.995e+00,3.634e+00,8.743e-01,\n",
@@ -137,7 +137,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"Now lets compute the sum of squares loss for the training data and plot the loss function" "Now let's compute the sum of squares loss for the training data and plot the loss function"
], ],
"metadata": { "metadata": {
"id": "QU5mdGvpTtEG" "id": "QU5mdGvpTtEG"
@@ -160,7 +160,7 @@
" b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", " b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
" my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n", " my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
"\n", "\n",
" # Make grid of intercept/slope values to plot\n", " # Make grid of offset/frequency values to plot\n",
" offsets_mesh, freqs_mesh = np.meshgrid(np.arange(-10,10.0,0.1), np.arange(2.5,22.5,0.1))\n", " offsets_mesh, freqs_mesh = np.meshgrid(np.arange(-10,10.0,0.1), np.arange(2.5,22.5,0.1))\n",
" loss_mesh = np.zeros_like(freqs_mesh)\n", " loss_mesh = np.zeros_like(freqs_mesh)\n",
" # Compute loss for every set of parameters\n", " # Compute loss for every set of parameters\n",
@@ -365,7 +365,6 @@
"\n", "\n",
" # Update the parameters\n", " # Update the parameters\n",
" phi_all[:,c_step+1:c_step+2] = phi_all[:,c_step:c_step+1] - alpha * momentum\n", " phi_all[:,c_step+1:c_step+2] = phi_all[:,c_step:c_step+1] - alpha * momentum\n",
" # Measure loss and draw model every 8th step\n",
"\n", "\n",
"loss = compute_loss(data[0,:], data[1,:], model, phi_all[:,c_step+1:c_step+2])\n", "loss = compute_loss(data[0,:], data[1,:], model, phi_all[:,c_step+1:c_step+2])\n",
"draw_model(data,model,phi_all[:,c_step+1], \"Iteration %d, loss = %f\"%(c_step+1,loss))\n", "draw_model(data,model,phi_all[:,c_step+1], \"Iteration %d, loss = %f\"%(c_step+1,loss))\n",

View File

@@ -4,7 +4,6 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyNFsCOnucz1nQt7PBEnKeTV",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -110,7 +109,7 @@
" ax.plot(opt_path[0,:], opt_path[1,:],'-', color='#a0d9d3ff')\n", " ax.plot(opt_path[0,:], opt_path[1,:],'-', color='#a0d9d3ff')\n",
" ax.plot(opt_path[0,:], opt_path[1,:],'.', color='#a0d9d3ff',markersize=10)\n", " ax.plot(opt_path[0,:], opt_path[1,:],'.', color='#a0d9d3ff',markersize=10)\n",
" ax.set_xlabel(\"$\\phi_{0}$\")\n", " ax.set_xlabel(\"$\\phi_{0}$\")\n",
" ax.set_ylabel(\"$\\phi_1}$\")\n", " ax.set_ylabel(\"$\\phi_{1}$\")\n",
" plt.show()" " plt.show()"
], ],
"metadata": { "metadata": {
@@ -169,7 +168,7 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"Because the function changes much faster in $\\phi_1$ than in $\\phi_0$, there is no great step size to choose. If we set the step size so that it makes sensible progress in the $\\phi_1$, then it takes many iterations to converge. If we set the step size tso that we make sensible progress in the $\\phi_{0}$ direction, then the path oscillates in the $\\phi_1$ direction. \n", "Because the function changes much faster in $\\phi_1$ than in $\\phi_0$, there is no great step size to choose. If we set the step size so that it makes sensible progress in the $\\phi_1$ direction, then it takes many iterations to converge. If we set the step size so that we make sensible progress in the $\\phi_0$ direction, then the path oscillates in the $\\phi_1$ direction. \n",
"\n", "\n",
"This motivates Adam. At the core of Adam is the idea that we should just determine which way is downhill along each axis (i.e. left/right for $\\phi_0$ or up/down for $\\phi_1$) and move a fixed distance in that direction." "This motivates Adam. At the core of Adam is the idea that we should just determine which way is downhill along each axis (i.e. left/right for $\\phi_0$ or up/down for $\\phi_1$) and move a fixed distance in that direction."
], ],

View File

@@ -268,7 +268,7 @@
"mean_model, std_model = get_model_mean_variance(n_data, n_datasets, n_hidden, sigma_func) ;\n", "mean_model, std_model = get_model_mean_variance(n_data, n_datasets, n_hidden, sigma_func) ;\n",
"\n", "\n",
"# Plot the results\n", "# Plot the results\n",
"plot_function(x_func, y_func, x_data,y_data, x_model, mean_model, sigma_model=std_model)" "plot_function(x_func, y_func, x_model=x_model, y_model=mean_model, sigma_model=std_model)"
], ],
"metadata": { "metadata": {
"id": "Wxk64t2SoX9c" "id": "Wxk64t2SoX9c"

View File

@@ -310,7 +310,7 @@
"grad_path_tiny_lr = None ;\n", "grad_path_tiny_lr = None ;\n",
"\n", "\n",
"\n", "\n",
"# TODO: Run the gradient descent on the modified loss\n", "# TODO: Run the gradient descent on the unmodified loss\n",
"# function with 100 steps and a very small learning rate alpha of 0.05\n", "# function with 100 steps and a very small learning rate alpha of 0.05\n",
"# Replace this line:\n", "# Replace this line:\n",
"grad_path_typical_lr = None ;\n", "grad_path_typical_lr = None ;\n",

View File

@@ -4,7 +4,7 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyMrF4rB2hTKq7XzLuYsURdL", "authorship_tag": "ABX9TyMLKg5ZmXqojcVrZD5BGm9g",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -235,7 +235,7 @@
"# Finite difference calculation\n", "# Finite difference calculation\n",
"dydx_fd = (y2-y1)/delta\n", "dydx_fd = (y2-y1)/delta\n",
"\n", "\n",
"print(\"Gradient calculation=%f, Finite difference gradient=%f\"%(dydx,dydx_fd))\n" "print(\"Gradient calculation=%f, Finite difference gradient=%f\"%(dydx.squeeze(),dydx_fd.squeeze()))\n"
], ],
"metadata": { "metadata": {
"id": "KJpQPVd36Haq" "id": "KJpQPVd36Haq"

View File

@@ -4,7 +4,7 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyOdSkjfQnSZXnffGsZVM7r5", "authorship_tag": "ABX9TyO/wJ4N9w01f04mmrs/ZSHY",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -185,10 +185,10 @@
"np.set_printoptions(precision=3)\n", "np.set_printoptions(precision=3)\n",
"output = graph_attention(X, omega, beta, phi, A);\n", "output = graph_attention(X, omega, beta, phi, A);\n",
"print(\"Correct answer is:\")\n", "print(\"Correct answer is:\")\n",
"print(\"[[1.796 1.346 0.569 1.703 1.298 1.224 1.24 1.234]\")\n", "print(\"[[0. 0.028 0.37 0. 0.97 0. 0. 0.698]\")\n",
"print(\" [0.768 0.672 0. 0.529 3.841 4.749 5.376 4.761]\")\n", "print(\" [0. 0. 0. 0. 1.184 0. 2.654 0. ]\")\n",
"print(\" [0.305 0.129 0. 0.341 0.785 1.014 1.113 1.024]\")\n", "print(\" [1.13 0.564 0. 1.298 0.268 0. 0. 0.779]\")\n",
"print(\" [0. 0. 0. 0. 0.35 0.864 1.098 0.871]]]\")\n", "print(\" [0.825 0. 0. 1.175 0. 0. 0. 0. ]]]\")\n",
"\n", "\n",
"\n", "\n",
"print(\"Your answer is:\")\n", "print(\"Your answer is:\")\n",

View File

@@ -4,7 +4,6 @@
"metadata": { "metadata": {
"colab": { "colab": {
"provenance": [], "provenance": [],
"authorship_tag": "ABX9TyNyLnpoXgKN+RGCuTUszCAZ",
"include_colab_link": true "include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
@@ -153,9 +152,9 @@
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# TODO: Now construct the matrix A that has the initial distribution constraints\n", "# TODO: Now construct the matrix A that has the initial distribution constraints\n",
"# so that Ap=b where p is the transport plan P vectorized rows first so p = np.flatten(P)\n", "# so that A @ TPFlat=b where TPFlat is the transport plan TP vectorized rows first so TPFlat = np.flatten(TP)\n",
"# Replace this line:\n", "# Replace this line:\n",
"A = np.zeros((20,100))\n" "A = np.zeros((20,100))"
], ],
"metadata": { "metadata": {
"id": "7KrybL96IuNW" "id": "7KrybL96IuNW"
@@ -197,8 +196,8 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"P = np.array(opt.x).reshape(10,10)\n", "TP = np.array(opt.x).reshape(10,10)\n",
"draw_2D_heatmap(P,'Transport plan $\\mathbf{P}$', my_colormap)" "draw_2D_heatmap(TP,'Transport plan $\\mathbf{P}$', my_colormap)"
], ],
"metadata": { "metadata": {
"id": "nZGfkrbRV_D0" "id": "nZGfkrbRV_D0"
@@ -218,7 +217,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"was = np.sum(P * dist_mat)\n", "was = np.sum(TP * dist_mat)\n",
"print(\"Wasserstein distance = \", was)" "print(\"Wasserstein distance = \", was)"
], ],
"metadata": { "metadata": {

View File

@@ -387,7 +387,7 @@
"def compute_expectation2b(n_samples):\n", "def compute_expectation2b(n_samples):\n",
" # TODO -- complete this function\n", " # TODO -- complete this function\n",
" # 1. Draw n_samples from auxiliary distribution\n", " # 1. Draw n_samples from auxiliary distribution\n",
" # 2. Compute f[y] for those samples\n", " # 2. Compute f2[y] for those samples\n",
" # 3. Scale the results by pr_y / q_y\n", " # 3. Scale the results by pr_y / q_y\n",
" # 4. Compute the mean of these weighted samples\n", " # 4. Compute the mean of these weighted samples\n",
" # Replace this line\n", " # Replace this line\n",

View File

@@ -3,8 +3,8 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "id": "view-in-github",
"id": "view-in-github" "colab_type": "text"
}, },
"source": [ "source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap18/18_1_Diffusion_Encoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" "<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap18/18_1_Diffusion_Encoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@@ -409,7 +409,7 @@
" # 3. Compute pdf of this Gaussian at every x_plot_val\n", " # 3. Compute pdf of this Gaussian at every x_plot_val\n",
" # 4. Weight Gaussian by probability at position x and by 0.01 to componensate for bin size\n", " # 4. Weight Gaussian by probability at position x and by 0.01 to componensate for bin size\n",
" # 5. Accumulate weighted Gaussian in marginal at time t.\n", " # 5. Accumulate weighted Gaussian in marginal at time t.\n",
" # 6. Multiply result by 0.01 to compensate for bin size\n", "\n",
" # Replace this line:\n", " # Replace this line:\n",
" marginal_at_time_t = marginal_at_time_t\n", " marginal_at_time_t = marginal_at_time_t\n",
"\n", "\n",
@@ -454,9 +454,8 @@
], ],
"metadata": { "metadata": {
"colab": { "colab": {
"authorship_tag": "ABX9TyMpC8kgLnXx0XQBtwNAQ4jJ", "provenance": [],
"include_colab_link": true, "include_colab_link": true
"provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",

View File

@@ -1,20 +1,4 @@
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMWjsdr5SDwyzcDftnehlNo",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -28,6 +12,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "t9vk9Elugvmi"
},
"source": [ "source": [
"# **Notebook 19.3: Monte-Carlo methods**\n", "# **Notebook 19.3: Monte-Carlo methods**\n",
"\n", "\n",
@@ -37,42 +24,49 @@
"\n", "\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n", "\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
], "\n",
"metadata": { "Thanks to [Akshil Patel](https://www.akshilpatel.com) and [Jessica Nicholson](https://jessicanicholson1.github.io) for their help in preparing this notebook."
"id": "t9vk9Elugvmi" ]
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image"
],
"metadata": { "metadata": {
"id": "OLComQyvCIJ7" "id": "OLComQyvCIJ7"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"from IPython.display import clear_output\n",
"from time import sleep"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZsvrUszPLyEG"
},
"outputs": [],
"source": [ "source": [
"# Get local copies of components of images\n", "# Get local copies of components of images\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png" "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png"
], ]
"metadata": {
"id": "ZsvrUszPLyEG"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gq1HfJsHN3SB"
},
"outputs": [],
"source": [ "source": [
"# Ugly class that takes care of drawing pictures like in the book.\n", "# Ugly class that takes care of drawing pictures like in the book.\n",
"# You can totally ignore this code!\n", "# You can totally ignore this code!\n",
@@ -257,205 +251,281 @@
"\n", "\n",
"\n", "\n",
" plt.show()" " plt.show()"
], ]
"metadata": {
"id": "Gq1HfJsHN3SB"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eBQ7lTpJQBSe"
},
"outputs": [],
"source": [ "source": [
"# We're going to work on the problem depicted in figure 19.10a\n", "# We're going to work on the problem depicted in figure 19.10a\n",
"n_rows = 4; n_cols = 4\n", "n_rows = 4; n_cols = 4\n",
"layout = np.zeros(n_rows * n_cols)\n", "layout = np.zeros(n_rows * n_cols)\n",
"reward_structure = np.zeros(n_rows * n_cols)\n", "reward_structure = np.zeros(n_rows * n_cols)\n",
"layout[9] = 1 ; reward_structure[9] = -2\n", "layout[9] = 1 ; reward_structure[9] = -2 # Hole\n",
"layout[10] = 1; reward_structure[10] = -2\n", "layout[10] = 1; reward_structure[10] = -2 # Hole\n",
"layout[14] = 1; reward_structure[14] = -2\n", "layout[14] = 1; reward_structure[14] = -2 # Hole\n",
"layout[15] = 2; reward_structure[15] = 3\n", "layout[15] = 2; reward_structure[15] = 3 # Fish\n",
"initial_state = 0\n", "initial_state = 0\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)" "mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)"
], ]
"metadata": {
"id": "eBQ7lTpJQBSe"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater."
],
"metadata": { "metadata": {
"id": "6Vku6v_se2IG" "id": "6Vku6v_se2IG"
} },
"source": [
"For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater."
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "Fhc6DzZNOjiC"
},
"source": [ "source": [
"Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full where $a$ is the actions. Here $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ right. However, the ice is slippery, so we don't always go the direction we want to.\n", "Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full where $a$ is the actions. Here $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ right. However, the ice is slippery, so we don't always go the direction we want to.\n",
"\n", "\n",
"Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better" "Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better"
], ]
"metadata": {
"id": "Fhc6DzZNOjiC"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l7rT78BbOgTi"
},
"outputs": [],
"source": [ "source": [
"transition_probabilities_given_action0 = np.array(\\\n", "transition_probabilities_given_action0 = np.array(\\\n",
"[[0.00 , 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.90, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.33, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.90, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.34, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.34, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.75, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.75 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n", "\n",
"\n", "\n",
"transition_probabilities_given_action1 = np.array(\\\n", "transition_probabilities_given_action1 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.75 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.85, 0.90, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.25 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.75, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00]])\n",
"])\n", "\n",
"\n", "\n",
"transition_probabilities_given_action2 = np.array(\\\n", "transition_probabilities_given_action2 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.25 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.75 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.75, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.33, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.90, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n",
"\n", "\n",
"transition_probabilities_given_action3 = np.array(\\\n", "transition_probabilities_given_action3 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.90, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.75, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.90, 0.85, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.75 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n", "\n",
"\n",
"\n", "\n",
"# Store all of these in a three dimension array\n", "# Store all of these in a three dimension array\n",
"# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3] is stored at position [2,1,3]\n", "# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3] is stored at position [2,1,3]\n",
"transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n", "transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n",
" np.expand_dims(transition_probabilities_given_action1,2),\n", " np.expand_dims(transition_probabilities_given_action1,2),\n",
" np.expand_dims(transition_probabilities_given_action2,2),\n", " np.expand_dims(transition_probabilities_given_action2,2),\n",
" np.expand_dims(transition_probabilities_given_action3,2)),axis=2)" " np.expand_dims(transition_probabilities_given_action3,2)),axis=2)\n",
], "\n",
"metadata": { "print('Grid Size:', len(transition_probabilities_given_action[0]))\n",
"id": "l7rT78BbOgTi" "print()\n",
"print('Transition Probabilities for when next state = 2:')\n",
"print(transition_probabilities_given_action[2])\n",
"print()\n",
"print('Transitions Probabilities for when next state = 2 and current state = 1')\n",
"print(transition_probabilities_given_action[2][1])\n",
"print()\n",
"print('Transitions Probabilities for when next state = 2 and current state = 1 and action = 3 (Left):')\n",
"print(transition_probabilities_given_action[2][1][3])"
]
}, },
"execution_count": null, {
"outputs": [] "cell_type": "markdown",
"metadata": {
"id": "BHWjp6Qq4tBF"
},
"source": [
"## Implementation Details\n",
"\n",
"We provide the following methods:\n",
"\n",
"- **`markov_decision_process_step_stochastic`** - this function selects an action based on the stochastic policy for the current state, updates the state based on the transition probabilities associated with the chosen action, and returns the new state, the reward obtained for the new state, the chosen action, and whether the episode terminates.\n",
"\n",
"- **`get_one_episode`** - this function simulates an episode of agent-environment interaction. It returns the states, rewards, and actions seen in that episode, which we can then use to update the agent.\n",
"\n",
"- **`calculate_returns`** - this function calls on the **`calculate_return`** function that computes the discounted sum of rewards from a specific step, in a sequence of rewards.\n",
"\n",
"You have to implement the following methods:\n",
"\n",
"- **`deterministic_policy_to_epsilon_greedy`** - given a deterministic policy, where one action is chosen per state, this function computes the $\\epsilon$-greedy version of that policy, where each of the four actions has some nonzero probability of being selected per state. In each state, the probability of selecting each of the actions should sum to 1.\n",
"\n",
"- **`update_policy_mc`** - this function updates the action-value function using the Monte Carlo method. We use the rollout trajectories collected using `get_one_episode` to calculate the returns. Then update the action values towards the Monte Carlo estimate of the return for each state."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "akjrncMF-FkU"
},
"outputs": [],
"source": [ "source": [
"# This takes a single step from an MDP\n", "# This takes a single step from an MDP\n",
"def markov_decision_process_step_stochastic(state, transition_probabilities_given_action, reward_structure, stochastic_policy):\n", "def markov_decision_process_step_stochastic(state, transition_probabilities_given_action, reward_structure, terminal_states, stochastic_policy):\n",
" # Pick action\n", " # Pick action\n",
" action = np.random.choice(a=np.arange(0,4,1),p=stochastic_policy[:,state])\n", " action = np.random.choice(a=np.arange(0,4,1),p=stochastic_policy[:,state])\n",
"\n",
" # Update the state\n", " # Update the state\n",
" new_state = np.random.choice(a=np.arange(0,transition_probabilities_given_action.shape[0]),p = transition_probabilities_given_action[:,state,action])\n", " new_state = np.random.choice(a=np.arange(0,transition_probabilities_given_action.shape[0]),p = transition_probabilities_given_action[:,state,action])\n",
" # Return the reward\n", " # Return the reward\n",
" reward = reward_structure[new_state]\n", " reward = reward_structure[new_state]\n",
" is_terminal = new_state in [terminal_states]\n",
"\n", "\n",
" return new_state, reward, action" " return new_state, reward, action, is_terminal"
], ]
"metadata": {
"id": "akjrncMF-FkU"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"# Run one episode and return actions, rewards, returns\n",
"def get_one_episode(initial_state, transition_probabilities_given_action, reward_structure, stochastic_policy):\n",
"\n",
" max_steps = 1000\n",
" states = np.zeros(max_steps, dtype='uint8') ;\n",
" rewards = np.zeros(max_steps) ;\n",
" actions = np.zeros(max_steps, dtype='uint8') ;\n",
"\n",
" t = 0\n",
" states[t] = initial_state\n",
" # While haven't reached maximum number of steps\n",
" while t< max_steps:\n",
" # Keep stepping through MDP\n",
" states[t+1],rewards[t+1],actions[t] = markov_decision_process_step_stochastic(states[t], transition_probabilities_given_action, reward_structure, stochastic_policy)\n",
" # If we reach te:rminal state then quit\n",
" if states[t]==15:\n",
" break;\n",
" t+=1\n",
"\n",
" states = states[:t+1]\n",
" rewards = rewards[:t+1]\n",
" actions = actions[:t+1]\n",
"\n",
" return states, rewards, actions"
],
"metadata": { "metadata": {
"id": "bFYvF9nAloIA" "id": "bFYvF9nAloIA"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"# Run one episode and return actions, rewards, returns\n",
"def get_one_episode(initial_state, transition_probabilities_given_action, reward_structure, terminal_states, stochastic_policy):\n",
"\n",
" states = []\n",
" rewards = []\n",
" actions = []\n",
"\n",
" states.append(initial_state)\n",
" state = initial_state\n",
"\n",
" is_terminal = False\n",
" # While we haven't reached a terminal state\n",
" while not is_terminal:\n",
" # Keep stepping through MDP\n",
" state, reward, action, is_terminal = markov_decision_process_step_stochastic(state,\n",
" transition_probabilities_given_action,\n",
" reward_structure,\n",
" terminal_states,\n",
" stochastic_policy)\n",
" states.append(state)\n",
" rewards.append(reward)\n",
" actions.append(action)\n",
"\n",
" states = np.array(states, dtype=\"uint8\")\n",
" rewards = np.array(rewards)\n",
" actions = np.array(actions, dtype=\"uint8\")\n",
"\n",
" # If the episode was terminated early, we need to compute the return differently using r_{t+1} + gamma*V(s_{t+1})\n",
" return states, rewards, actions"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qJhOrIId4tBF"
},
"outputs": [],
"source": [
"def visualize_one_episode(states, actions):\n",
" # Define actions for visualization\n",
" acts = ['up', 'right', 'down', 'left']\n",
"\n",
" # Iterate over the states and actions\n",
" for i in range(len(states)):\n",
"\n",
" if i == 0:\n",
" print('Starting State:', states[i])\n",
"\n",
" elif i == len(states)-1:\n",
" print('Episode Done:', states[i])\n",
"\n",
" else:\n",
" print('State', states[i-1])\n",
" a = actions[i]\n",
" print('Action:', acts[a])\n",
" print('Next State:', states[i])\n",
"\n",
" # Visualize the current state using the MDP drawer\n",
" mdp_drawer.draw(layout, state=states[i], rewards=reward_structure, draw_state_index=True)\n",
" clear_output(True)\n",
"\n",
" # Pause for a short duration to allow observation\n",
" sleep(1.5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_AKwdtQQHzIK"
},
"outputs": [],
"source": [ "source": [
"# Convert deterministic policy (1x16) to an epsilon greedy stochastic policy (4x16)\n", "# Convert deterministic policy (1x16) to an epsilon greedy stochastic policy (4x16)\n",
"def deterministic_policy_to_epsilon_greedy(policy, epsilon=0.1):\n", "def deterministic_policy_to_epsilon_greedy(policy, epsilon=0.2):\n",
" # TODO -- write this function\n", " # TODO -- write this function\n",
" # You should wind up with a 4x16 matrix, with epsilon/3 in every position except the real policy\n", " # You should wind up with a 4x16 matrix, with epsilon/3 in every position except the real policy\n",
" # The columns should sum to one\n", " # The columns should sum to one\n",
@@ -464,27 +534,27 @@
"\n", "\n",
"\n", "\n",
" return stochastic_policy" " return stochastic_policy"
], ]
"metadata": {
"id": "_AKwdtQQHzIK"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"Let's try generating an episode"
],
"metadata": { "metadata": {
"id": "OhVXw2Favo-w" "id": "OhVXw2Favo-w"
} },
"source": [
"Let's try generating an episode"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5zQ1Oh9Zvnwt"
},
"outputs": [],
"source": [ "source": [
"# Set seed so random numbers always the same\n", "# Set seed so random numbers always the same\n",
"np.random.seed(0)\n", "np.random.seed(6)\n",
"# Print in compact form\n", "# Print in compact form\n",
"np.set_printoptions(precision=3)\n", "np.set_printoptions(precision=3)\n",
"\n", "\n",
@@ -494,32 +564,55 @@
"# Convert deterministic policy to stochastic\n", "# Convert deterministic policy to stochastic\n",
"stochastic_policy = deterministic_policy_to_epsilon_greedy(policy)\n", "stochastic_policy = deterministic_policy_to_epsilon_greedy(policy)\n",
"\n", "\n",
"print(\"Initial policy:\")\n", "print(\"Initial Penguin Policy:\")\n",
"print(policy)\n", "print(policy)\n",
"print()\n",
"print('Stochastic Penguin Policy:')\n",
"print(stochastic_policy)\n",
"print()\n",
"\n", "\n",
"initial_state = 5\n", "initial_state = 5\n",
"states, rewards, actions = get_one_episode(initial_state,transition_probabilities_given_action, reward_structure, stochastic_policy)" "terminal_states=[15]\n",
], "states, rewards, actions = get_one_episode(initial_state,transition_probabilities_given_action, reward_structure, terminal_states, stochastic_policy)\n",
"metadata": { "\n",
"id": "5zQ1Oh9Zvnwt" "print('Initial Penguin Position:')\n",
}, "mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)\n",
"execution_count": null, "\n",
"outputs": [] "print('Total steps to termination:', len(states))\n",
}, "print('Final Reward:', np.sum(rewards))"
{ ]
"cell_type": "markdown",
"source": [
"We'll need to calculate the returns (discounted cumulative reward) for each state action pair"
],
"metadata": {
"id": "nl6rtNffwhcU"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KJH-UGKk4tBF"
},
"outputs": [],
"source": [
"#this visualizes the complete episode\n",
"visualize_one_episode(states, actions)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nl6rtNffwhcU"
},
"source": [
"We'll need to calculate the returns (discounted cumulative reward) for each state action pair"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FxrItqGPLTq7"
},
"outputs": [],
"source": [ "source": [
"def calculate_returns(rewards, gamma):\n", "def calculate_returns(rewards, gamma):\n",
" returns = np.zeros_like(rewards)\n", " returns = np.zeros(len(rewards))\n",
" for c_return in range(len(returns)):\n", " for c_return in range(len(returns)):\n",
" returns[c_return] = calculate_return(rewards[c_return:], gamma)\n", " returns[c_return] = calculate_return(rewards[c_return:], gamma)\n",
" return returns\n", " return returns\n",
@@ -529,26 +622,26 @@
" for i in range(len(rewards)):\n", " for i in range(len(rewards)):\n",
" return_val += rewards[i] * np.power(gamma, i)\n", " return_val += rewards[i] * np.power(gamma, i)\n",
" return return_val" " return return_val"
], ]
"metadata": {
"id": "FxrItqGPLTq7"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"This routine does the main work of the Monte Carlo method. We repeatedly rollout episods, calculate the returns. Then we figure out the average return for each state action pair, and choose the next policy as the action that has greatest state action value at each state."
],
"metadata": { "metadata": {
"id": "DX1KfHRhzUOU" "id": "DX1KfHRhzUOU"
} },
"source": [
"This routine does the main work of the on-policy Monte Carlo method. We repeatedly rollout episods, calculate the returns. Then we figure out the average return for each state action pair, and choose the next policy as the action that has greatest state action value at each state."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hCghcKlOJXSM"
},
"outputs": [],
"source": [ "source": [
"def update_policy_mc(initial_state, transition_probabilities_given_action, reward_structure, stochastic_policy, gamma, n_rollouts=1):\n", "def update_policy_mc(initial_state, transition_probabilities_given_action, reward_structure, terminal_states, stochastic_policy, gamma, n_rollouts=1):\n",
" # Create two matrices to store total returns for each action/state pair and the\n", " # Create two matrices to store total returns for each action/state pair and the\n",
" # number of times we observed that action/state pair\n", " # number of times we observed that action/state pair\n",
" n_state = transition_probabilities_given_action.shape[0]\n", " n_state = transition_probabilities_given_action.shape[0]\n",
@@ -574,18 +667,18 @@
" state_action_values = state_action_returns_total/( state_action_count+0.00001)\n", " state_action_values = state_action_returns_total/( state_action_count+0.00001)\n",
" policy = np.argmax(state_action_values, axis=0).astype(int)\n", " policy = np.argmax(state_action_values, axis=0).astype(int)\n",
" return policy, state_action_values\n" " return policy, state_action_values\n"
], ]
"metadata": {
"id": "hCghcKlOJXSM"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8jWhDlkaKj7Q"
},
"outputs": [],
"source": [ "source": [
"# Set seed so random numbers always the same\n", "# Set seed so random numbers always the same\n",
"np.random.seed(3)\n", "np.random.seed(0)\n",
"# Print in compact form\n", "# Print in compact form\n",
"np.set_printoptions(precision=3)\n", "np.set_printoptions(precision=3)\n",
"\n", "\n",
@@ -597,32 +690,60 @@
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, rewards = reward_structure)\n", "mdp_drawer.draw(layout, policy = policy, rewards = reward_structure)\n",
"\n", "\n",
"\n", "terminal_states = [15]\n",
"n_policy_update = 5\n", "# Track all the policies so we can visualize them later\n",
"all_policies = []\n",
"n_policy_update = 2000\n",
"for c_policy_update in range(n_policy_update):\n", "for c_policy_update in range(n_policy_update):\n",
" # Convert policy to stochastic\n", " # Convert policy to stochastic\n",
" stochastic_policy = deterministic_policy_to_epsilon_greedy(policy)\n", " stochastic_policy = deterministic_policy_to_epsilon_greedy(policy)\n",
" # Update policy by Monte Carlo method\n", " # Update policy by Monte Carlo method\n",
" policy, state_action_values = update_policy_mc(initial_state, transition_probabilities_given_action, reward_structure, stochastic_policy, gamma, n_rollouts=1000)\n", " policy, state_action_values = update_policy_mc(initial_state, transition_probabilities_given_action, reward_structure, terminal_states, stochastic_policy, gamma, n_rollouts=100)\n",
" all_policies.append(policy)\n",
"\n",
" # Print out 10 snapshots of progress\n",
" if (c_policy_update % (n_policy_update//10) == 0) or c_policy_update == n_policy_update - 1:\n",
" print(\"Updated policy\")\n", " print(\"Updated policy\")\n",
" print(policy)\n", " print(policy)\n",
" mdp_drawer = DrawMDP(n_rows, n_cols)\n", " mdp_drawer = DrawMDP(n_rows, n_cols)\n",
" mdp_drawer.draw(layout, policy = policy, rewards = reward_structure, state_action_values=state_action_values)\n" " mdp_drawer.draw(layout, policy = policy, rewards = reward_structure, state_action_values=state_action_values)\n",
], "\n",
"metadata": { "\n"
"id": "8jWhDlkaKj7Q" ]
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"You can see that the results are quite noisy, but there is a definite improvement from the initial policy."
],
"metadata": { "metadata": {
"id": "j7Ny47kTEMzH" "id": "j7Ny47kTEMzH"
} },
} "source": [
"You can see a definite improvement to the policy"
] ]
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
} }

View File

@@ -1,20 +1,4 @@
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNEAhORON7DFN1dZMhDK/PO",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
@@ -28,6 +12,9 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "t9vk9Elugvmi"
},
"source": [ "source": [
"# **Notebook 19.4: Temporal difference methods**\n", "# **Notebook 19.4: Temporal difference methods**\n",
"\n", "\n",
@@ -35,42 +22,49 @@
"\n", "\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n", "\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n",
], "\n",
"metadata": { "Thanks to [Akshil Patel](https://www.akshilpatel.com) and [Jessica Nicholson](https://jessicanicholson1.github.io) for their help in preparing this notebook."
"id": "t9vk9Elugvmi" ]
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image"
],
"metadata": { "metadata": {
"id": "OLComQyvCIJ7" "id": "OLComQyvCIJ7"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"from IPython.display import clear_output\n",
"from time import sleep\n",
"from copy import deepcopy"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZsvrUszPLyEG"
},
"outputs": [],
"source": [ "source": [
"# Get local copies of components of images\n", "# Get local copies of components of images\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png" "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png"
], ]
"metadata": {
"id": "ZsvrUszPLyEG"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gq1HfJsHN3SB"
},
"outputs": [],
"source": [ "source": [
"# Ugly class that takes care of drawing pictures like in the book.\n", "# Ugly class that takes care of drawing pictures like in the book.\n",
"# You can totally ignore this code!\n", "# You can totally ignore this code!\n",
@@ -253,269 +247,516 @@
" self.draw_text(\"%2.2f\"%(state_action_values[3, c_cell]), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'lc','black')\n", " self.draw_text(\"%2.2f\"%(state_action_values[3, c_cell]), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'lc','black')\n",
"\n", "\n",
" plt.show()" " plt.show()"
], ]
"metadata": {
"id": "Gq1HfJsHN3SB"
}, },
"execution_count": null, {
"outputs": [] "cell_type": "markdown",
"metadata": {
"id": "JU8gX59o76xM"
},
"source": [
"# Penguin Ice Environment\n",
"\n",
"In this implementation we have designed an icy gridworld that a penguin has to traverse to reach the fish found in the bottom right corner.\n",
"\n",
"## Environment Description\n",
"\n",
"Consider having to cross an icy surface to reach the yummy fish. In order to achieve this task as quickly as possible, the penguin needs to waddle along as fast as it can whilst simultaneously avoiding falling into the holes.\n",
"\n",
"In this icy environment the penguin is at one of the discrete cells in the gridworld. The agent starts each episode on a randomly chosen cell. The environment state dynamics are captured by the transition probabilities $Pr(s_{t+1} |s_t, a_t)$ where $s_t$ is the current state, $a_t$ is the action chosen, and $s_{t+1}$ is the next state at decision stage t. At each decision stage, the penguin can move in one of four directions: $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ left.\n",
"\n",
"However, the ice is slippery, so we don't always go the direction we want to: every time the agent chooses an action, with 0.25 probability, the environment changes the action taken to a differenct action, which is uniformly sampled from the other available actions.\n",
"\n",
"The rewards are deterministic; the penguin will receive a reward of +3 if it reaches the fish, -2 if it slips into a hole and 0 otherwise.\n",
"\n",
"Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eBQ7lTpJQBSe"
},
"outputs": [],
"source": [ "source": [
"# We're going to work on the problem depicted in figure 19.10a\n", "# We're going to work on the problem depicted in figure 19.10a\n",
"n_rows = 4; n_cols = 4\n", "n_rows = 4; n_cols = 4\n",
"layout = np.zeros(n_rows * n_cols)\n", "layout = np.zeros(n_rows * n_cols)\n",
"reward_structure = np.zeros(n_rows * n_cols)\n", "reward_structure = np.zeros(n_rows * n_cols)\n",
"layout[9] = 1 ; reward_structure[9] = -2\n", "layout[9] = 1 ; reward_structure[9] = -2 # Hole\n",
"layout[10] = 1; reward_structure[10] = -2\n", "layout[10] = 1; reward_structure[10] = -2 # Hole\n",
"layout[14] = 1; reward_structure[14] = -2\n", "layout[14] = 1; reward_structure[14] = -2 # Hole\n",
"layout[15] = 2; reward_structure[15] = 3\n", "layout[15] = 2; reward_structure[15] = 3 # Fish\n",
"initial_state = 0\n", "initial_state = 0\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)" "mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)"
], ]
"metadata": {
"id": "eBQ7lTpJQBSe"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater."
],
"metadata": { "metadata": {
"id": "6Vku6v_se2IG" "id": "6Vku6v_se2IG"
} },
"source": [
"For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater."
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "Fhc6DzZNOjiC"
},
"source": [ "source": [
"Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full where $a$ is the actions. Here $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ right. However, the ice is slippery, so we don't always go the direction we want to.\n", "Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full where $a$ is the actions. Here $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ right. However, the ice is slippery, so we don't always go the direction we want to.\n",
"\n", "\n",
"Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better" "Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better"
], ]
"metadata": {
"id": "Fhc6DzZNOjiC"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wROjgnqh76xN"
},
"outputs": [],
"source": [ "source": [
"transition_probabilities_given_action0 = np.array(\\\n", "transition_probabilities_given_action0 = np.array(\\\n",
"[[0.00 , 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.90, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.33, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.90, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.34, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.34, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.75, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.75 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n", "\n",
"\n", "\n",
"transition_probabilities_given_action1 = np.array(\\\n", "transition_probabilities_given_action1 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.75 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.85, 0.90, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.25 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.75, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00]])\n",
"])\n", "\n",
"\n", "\n",
"transition_probabilities_given_action2 = np.array(\\\n", "transition_probabilities_given_action2 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.25 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.75 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.75, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.33, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.90, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00, 0.50 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n",
"\n", "\n",
"transition_probabilities_given_action3 = np.array(\\\n", "transition_probabilities_given_action3 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", "[[0.90, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.50, 0.00, 0.75, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.90, 0.85, 0.00, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.85, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.75 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n",
"])\n", "\n",
"\n",
"\n", "\n",
"# Store all of these in a three dimension array\n", "# Store all of these in a three dimension array\n",
"# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3] is stored at position [2,1,3]\n", "# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3] is stored at position [2,1,3]\n",
"transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n", "transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n",
" np.expand_dims(transition_probabilities_given_action1,2),\n", " np.expand_dims(transition_probabilities_given_action1,2),\n",
" np.expand_dims(transition_probabilities_given_action2,2),\n", " np.expand_dims(transition_probabilities_given_action2,2),\n",
" np.expand_dims(transition_probabilities_given_action3,2)),axis=2)" " np.expand_dims(transition_probabilities_given_action3,2)),axis=2)\n",
], "\n",
"metadata": { "print('Grid Size:', len(transition_probabilities_given_action[0]))\n",
"id": "l7rT78BbOgTi" "print()\n",
"print('Transition Probabilities for when next state = 2:')\n",
"print(transition_probabilities_given_action[2])\n",
"print()\n",
"print('Transitions Probabilities for when next state = 2 and current state = 1')\n",
"print(transition_probabilities_given_action[2][1])\n",
"print()\n",
"print('Transitions Probabilities for when next state = 2 and current state = 1 and action = 3 (Left):')\n",
"print(transition_probabilities_given_action[2][1][3])"
]
}, },
"execution_count": null, {
"outputs": [] "cell_type": "markdown",
"metadata": {
"id": "eblSQ6xZ76xN"
},
"source": [
"## Implementation Details\n",
"\n",
"We provide the following methods:\n",
"- **`markov_decision_process_step`** - this function simulates $Pr(s_{t+1} | s_{t}, a_{t})$. It randomly selects an action, updates the state based on the transition probabilities associated with the chosen action, and returns the new state, the reward obtained for leaving the current state, and the chosen action. The randomness in action selection and state transitions reflects a random exploration process and the stochastic nature of the MDP, respectively.\n",
"\n",
"- **`get_policy`** - this function computes a policy that acts greedily with respect to the state-action values. The policy is computed for all states and the action that maximizes the state-action value is chosen for each state. When there are multiple optimal actions, one is chosen at random.\n",
"\n",
"\n",
"You have to implement the following method:\n",
"\n",
"- **`q_learning_step`** - this function implements a single step of the Q-learning algorithm for reinforcement learning as shown below. The update follows the Q-learning formula and is controlled by parameters such as the learning rate (alpha) and the discount factor $(\\gamma)$. The function returns the updated state-action values matrix.\n",
"\n",
"$Q(s, a) \\leftarrow (1 - \\alpha) \\cdot Q(s, a) + \\alpha \\cdot \\left(r + \\gamma \\cdot \\max_{a'} Q(s', a')\\right)$"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cKLn4Iam76xN"
},
"outputs": [],
"source": [ "source": [
"def q_learning_step(state_action_values, reward, state, new_state, action, gamma, alpha = 0.1):\n", "def get_policy(state_action_values):\n",
" policy = np.zeros(state_action_values.shape[1]) # One action for each state\n",
" for state in range(state_action_values.shape[1]):\n",
" # Break ties for maximising actions randomly\n",
" policy[state] = np.random.choice(np.flatnonzero(state_action_values[:, state] == max(state_action_values[:, state])))\n",
" return policy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "akjrncMF-FkU"
},
"outputs": [],
"source": [
"def markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states, action=None):\n",
" # Pick action\n",
" if action is None:\n",
" action = np.random.randint(4)\n",
" # Update the state\n",
" new_state = np.random.choice(a=range(transition_probabilities_given_action.shape[0]), p = transition_probabilities_given_action[:, state,action])\n",
"\n",
" # Return the reward -- here the reward is for arriving at the state\n",
" reward = reward_structure[new_state]\n",
" is_terminal = new_state in [terminal_states]\n",
"\n",
" return new_state, reward, action, is_terminal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5pO6-9ACWhiV"
},
"outputs": [],
"source": [
"def q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma, alpha = 0.1):\n",
" # TODO -- write this function\n", " # TODO -- write this function\n",
" # Replace this line\n", " # Replace this line\n",
" state_action_values_after = np.copy(state_action_values)\n", " state_action_values_after = np.copy(state_action_values)\n",
"\n", "\n",
" return state_action_values_after" " return state_action_values_after"
], ]
"metadata": {
"id": "5pO6-9ACWhiV"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"metadata": {
"id": "u4OHTTk176xO"
},
"source": [ "source": [
"# This takes a single step from an MDP which just has a completely random policy\n", "Lets run this for a single Q-learning step"
"def markov_decision_process_step(state, transition_probabilities_given_action, reward_structure):\n", ]
" # Pick action\n",
" action = np.random.randint(4)\n",
" # Update the state\n",
" new_state = np.random.choice(a=np.arange(0,transition_probabilities_given_action.shape[0]),p = transition_probabilities_given_action[:,state,action])\n",
" # Return the reward -- here the reward is for leaving the state\n",
" reward = reward_structure[state]\n",
"\n",
" return new_state, reward, action"
],
"metadata": {
"id": "akjrncMF-FkU"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fu5_VjvbSwfJ"
},
"outputs": [],
"source": [ "source": [
"# Initialize the state-action values to random numbers\n", "# Initialize the state-action values to random numbers\n",
"np.random.seed(0)\n", "np.random.seed(0)\n",
"n_state = transition_probabilities_given_action.shape[0]\n", "n_state = transition_probabilities_given_action.shape[0]\n",
"n_action = transition_probabilities_given_action.shape[2]\n", "n_action = transition_probabilities_given_action.shape[2]\n",
"terminal_states=[15]\n",
"state_action_values = np.random.normal(size=(n_action, n_state))\n", "state_action_values = np.random.normal(size=(n_action, n_state))\n",
"# Hard code value of termination state of finding fish to 0\n",
"state_action_values[:, terminal_states] = 0\n",
"gamma = 0.9\n", "gamma = 0.9\n",
"\n", "\n",
"policy = np.argmax(state_action_values, axis=0).astype(int)\n", "policy = get_policy(state_action_values)\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n",
"\n", "\n",
"# Now let's simulate a single Q-learning step\n", "# Now let's simulate a single Q-learning step\n",
"initial_state = 9\n", "initial_state = 9\n",
"print(\"Initial state = \", initial_state)\n", "print(\"Initial state =\",initial_state)\n",
"new_state, reward, action = markov_decision_process_step(initial_state, transition_probabilities_given_action, reward_structure)\n", "new_state, reward, action, is_terminal = markov_decision_process_step(initial_state, transition_probabilities_given_action, reward_structure, terminal_states)\n",
"print(\"Action = \", action)\n", "print(\"Action =\",action)\n",
"print(\"New state = \", new_state)\n", "print(\"New state =\",new_state)\n",
"print(\"Reward = \", reward)\n", "print(\"Reward =\", reward)\n",
"\n", "\n",
"state_action_values_after = q_learning_step(state_action_values, reward, initial_state, new_state, action, gamma)\n", "state_action_values_after = q_learning_step(state_action_values, reward, initial_state, new_state, action, is_terminal, gamma)\n",
"print(\"Your value:\",state_action_values_after[action, initial_state])\n", "print(\"Your value:\",state_action_values_after[action, initial_state])\n",
"print(\"True value: 0.27650262412468796\")\n", "print(\"True value: 0.3024718977397814\")\n",
"\n", "\n",
"policy = np.argmax(state_action_values, axis=0).astype(int)\n", "policy = get_policy(state_action_values)\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values_after, rewards = reward_structure)\n" "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values_after, rewards = reward_structure)\n"
], ]
"metadata": {
"id": "Fu5_VjvbSwfJ"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"Now let's run this for a while and watch the policy improve"
],
"metadata": { "metadata": {
"id": "Ogh0qucmb68J" "id": "Ogh0qucmb68J"
} },
"source": [
"Now let's run this for a while (20000) steps and watch the policy improve"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N6gFYifh76xO"
},
"outputs": [],
"source": [ "source": [
"# Initialize the state-action values to random numbers\n", "# Initialize the state-action values to random numbers\n",
"np.random.seed(0)\n", "np.random.seed(0)\n",
"n_state = transition_probabilities_given_action.shape[0]\n", "n_state = transition_probabilities_given_action.shape[0]\n",
"n_action = transition_probabilities_given_action.shape[2]\n", "n_action = transition_probabilities_given_action.shape[2]\n",
"state_action_values = np.random.normal(size=(n_action, n_state))\n", "state_action_values = np.random.normal(size=(n_action, n_state))\n",
"# Hard code termination state of finding fish\n", "\n",
"state_action_values[:,n_state-1] = 3.0\n", "# Hard code value of termination state of finding fish to 0\n",
"terminal_states = [15]\n",
"state_action_values[:, terminal_states] = 0\n",
"gamma = 0.9\n", "gamma = 0.9\n",
"\n", "\n",
"# Draw the initial setup\n", "# Draw the initial setup\n",
"policy = np.argmax(state_action_values, axis=0).astype(int)\n", "print('Initial Policy:')\n",
"policy = get_policy(state_action_values)\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n",
"\n", "\n",
"\n", "state = np.random.randint(n_state-1)\n",
"state= np.random.randint(n_state-1)\n",
"\n", "\n",
"# Run for a number of iterations\n", "# Run for a number of iterations\n",
"for c_iter in range(10000):\n", "for c_iter in range(20000):\n",
" new_state, reward, action = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure)\n", " new_state, reward, action, is_terminal = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states)\n",
" state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, gamma)\n", " state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma)\n",
"\n",
" # If in termination state, reset state randomly\n", " # If in termination state, reset state randomly\n",
" if new_state==15:\n", " if is_terminal:\n",
" state= np.random.randint(n_state-1)\n", " state = np.random.randint(n_state-1)\n",
" else:\n", " else:\n",
" state = new_state\n", " state = new_state\n",
" # Update the policy\n",
" state_action_values = np.copy(state_action_values_after)\n",
" policy = np.argmax(state_action_values, axis=0).astype(int)\n",
"\n", "\n",
" # Update the policy\n",
" state_action_values = deepcopy(state_action_values_after)\n",
" policy = get_policy(state_action_values_after)\n",
"\n",
"print('Final Optimal Policy:')\n",
"# Draw the final situation\n", "# Draw the final situation\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)" "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "djPTKuDk76xO"
},
"source": [
"Finally, lets run this for a **single** episode and visualize the penguin's actions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pWObQf2h76xO"
},
"outputs": [],
"source": [
"def get_one_episode(n_state, state_action_values, terminal_states, gamma):\n",
"\n",
" state = np.random.randint(n_state-1)\n",
"\n",
" # Create lists to store all the states seen and actions taken throughout the single episode\n",
" all_states = []\n",
" all_actions = []\n",
"\n",
" # Initalize episode termination flag\n",
" done = False\n",
" # Initialize counter for steps in the episode\n",
" steps = 0\n",
"\n",
" all_states.append(state)\n",
"\n",
" while not done:\n",
" steps += 1\n",
"\n",
" new_state, reward, action, is_terminal = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states)\n",
" all_states.append(new_state)\n",
" all_actions.append(action)\n",
"\n",
" state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma)\n",
"\n",
" # If in termination state, reset state randomly\n",
" if is_terminal:\n",
" state = np.random.randint(n_state-1)\n",
" print(f'Episode Terminated at {steps} Steps')\n",
" # Set episode termination flag\n",
" done = True\n",
" else:\n",
" state = new_state\n",
"\n",
" # Update the policy\n",
" state_action_values = deepcopy(state_action_values_after)\n",
" policy = get_policy(state_action_values_after)\n",
"\n",
" return all_states, all_actions, policy, state_action_values\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P7cbCGT176xO"
},
"outputs": [],
"source": [
"def visualize_one_episode(states, actions):\n",
" # Define actions for visualization\n",
" acts = ['up', 'right', 'down', 'left']\n",
"\n",
" # Iterate over the states and actions\n",
" for i in range(len(states)):\n",
"\n",
" if i == 0:\n",
" print('Starting State:', states[i])\n",
"\n",
" elif i == len(states)-1:\n",
" print('Episode Done:', states[i])\n",
"\n",
" else:\n",
" print('State', states[i-1])\n",
" a = actions[i]\n",
" print('Action:', acts[a])\n",
" print('Next State:', states[i])\n",
"\n",
" # Visualize the current state using the MDP drawer\n",
" mdp_drawer.draw(layout, state=states[i], rewards=reward_structure, draw_state_index=True)\n",
" clear_output(True)\n",
"\n",
" # Pause for a short duration to allow observation\n",
" sleep(1.5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cr98F8PT76xP"
},
"outputs": [],
"source": [
"# Initialize the state-action values to random numbers\n",
"np.random.seed(2)\n",
"n_state = transition_probabilities_given_action.shape[0]\n",
"n_action = transition_probabilities_given_action.shape[2]\n",
"state_action_values = np.random.normal(size=(n_action, n_state))\n",
"\n",
"# Hard code value of termination state of finding fish to 0\n",
"terminal_states = [15]\n",
"state_action_values[:, terminal_states] = 0\n",
"gamma = 0.9\n",
"\n",
"# Draw the initial setup\n",
"print('Initial Policy:')\n",
"policy = get_policy(state_action_values)\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n",
"\n",
"states, actions, policy, state_action_values = get_one_episode(n_state, state_action_values, terminal_states, gamma)\n",
"\n",
"print()\n",
"print('Final Optimal Policy:')\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5zBu1g3776xP"
},
"outputs": [],
"source": [
"visualize_one_episode(states, actions)"
]
}
], ],
"metadata": { "metadata": {
"id": "qQFhwVqPcCFH" "colab": {
"provenance": [],
"include_colab_link": true
}, },
"execution_count": null, "kernelspec": {
"outputs": [] "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
} }
] },
"nbformat": 4,
"nbformat_minor": 0
} }

Binary file not shown.

Binary file not shown.

View File

@@ -14,9 +14,9 @@
<br>Published by MIT Press Dec 5th 2023.<br> <br>Published by MIT Press Dec 5th 2023.<br>
<ul> <ul>
<li> <li>
<p style="font-size: larger; margin-bottom: 0">Download draft PDF Chapters 1-21 <a <p style="font-size: larger; margin-bottom: 0">Download full PDF <a
href="https://github.com/udlbook/udlbook/releases/download/v1.19/UnderstandingDeepLearning_16_12_23_C.pdf">here</a> href="https://github.com/udlbook/udlbook/releases/download/v2.0.1/UnderstandingDeepLearning_02_15_24_C.pdf">here</a>
</p>2024-01-16. CC-BY-NC-ND license<br> </p>2024-02-15. CC-BY-NC-ND license<br>
<img src="https://img.shields.io/github/downloads/udlbook/udlbook/total" alt="download stats shield"> <img src="https://img.shields.io/github/downloads/udlbook/udlbook/total" alt="download stats shield">
</li> </li>
<li> Order your copy from <a href="https://mitpress.mit.edu/9780262048644/understanding-deep-learning/">here </a></li> <li> Order your copy from <a href="https://mitpress.mit.edu/9780262048644/understanding-deep-learning/">here </a></li>