udlbook/Notebooks/Chap19/19_2_Dynamic_Programming.ipynb
Mark Gotham 9649ce382b "TO DO" > "TODO"
In [commit 6072ad4](6072ad4), @KajvanRijn kindly changed all "TO DO" to "TODO" in the code blocks. That's useful. In addition, it should be changed (as here) in the instructions. Then there's no doubt or issue for anyone searching all instances.
2025-02-11 15:11:06 +00:00

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOlD6kmCxX3SKKuh3oJikKA",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap19/19_2_Dynamic_Programming.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 19.2: Dynamic programming**\n",
"\n",
"This notebook investigates the dynamic programming approach to tabular reinforcement learning as described in figure 19.10 of the book.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
],
"metadata": {
"id": "t9vk9Elugvmi"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image"
],
"metadata": {
"id": "OLComQyvCIJ7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Get local copies of components of images\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n",
"!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png"
],
"metadata": {
"id": "ZsvrUszPLyEG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Ugly class that takes care of drawing pictures like in the book.\n",
"# You can totally ignore this code!\n",
"class DrawMDP:\n",
"    # Constructor initializes parameters\n",
"    def __init__(self, n_row, n_col):\n",
"        self.empty_image = np.asarray(Image.open('Empty.png'))\n",
"        self.hole_image = np.asarray(Image.open('Hole.png'))\n",
"        self.fish_image = np.asarray(Image.open('Fish.png'))\n",
"        self.penguin_image = np.asarray(Image.open('Penguin.png'))\n",
"        self.fig,self.ax = plt.subplots()\n",
"        self.n_row = n_row\n",
"        self.n_col = n_col\n",
"\n",
"        my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497',\n",
"                               'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
"        my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n",
"        r = np.floor(my_colormap_vals_dec/(256*256))\n",
"        g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n",
"        b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
"        self.colormap = np.vstack((r,g,b)).transpose()/255.0\n",
"\n",
"\n",
"    def draw_text(self, text, row, col, position, color):\n",
"        if position == 'bc':\n",
"            self.ax.text( 83*col+41,83 * (row+1) -10, text, horizontalalignment=\"center\", color=color, fontweight='bold')\n",
"        if position == 'tl':\n",
"            self.ax.text( 83*col+5,83 * row +5, text, verticalalignment = 'top', horizontalalignment=\"left\", color=color, fontweight='bold')\n",
"        if position == 'tr':\n",
"            self.ax.text( 83*(col+1)-5, 83 * row +5, text, verticalalignment = 'top', horizontalalignment=\"right\", color=color, fontweight='bold')\n",
"\n",
"    # Draws a set of states\n",
"    def draw_path(self, path, color1, color2):\n",
"        for i in range(len(path)-1):\n",
"            row_start = np.floor(path[i]/self.n_col)\n",
"            row_end = np.floor(path[i+1]/self.n_col)\n",
"            col_start = path[i] - row_start * self.n_col\n",
"            col_end = path[i+1] - row_end * self.n_col\n",
"\n",
"            color_index = int(np.floor(255 * i/(len(path)-1.)))\n",
"            self.ax.plot([col_start * 83+41 + i, col_end * 83+41 + i ],[row_start * 83+41 + i, row_end * 83+41 + i ], color=(self.colormap[color_index,0],self.colormap[color_index,1],self.colormap[color_index,2]))\n",
"\n",
"\n",
"    # Draw deterministic policy\n",
"    def draw_deterministic_policy(self,i, action):\n",
"        row = np.floor(i/self.n_col)\n",
"        col = i - row * self.n_col\n",
"        center_x = 83 * col + 41\n",
"        center_y = 83 * row + 41\n",
"        arrow_base_width = 10\n",
"        arrow_height = 15\n",
"        # Draw arrow pointing upward\n",
"        if action ==0:\n",
"            triangle_indices = np.array([[center_x, center_y-arrow_height/2],\n",
"                                         [center_x - arrow_base_width/2, center_y+arrow_height/2],\n",
"                                         [center_x + arrow_base_width/2, center_y+arrow_height/2]])\n",
"        # Draw arrow pointing right\n",
"        if action ==1:\n",
"            triangle_indices = np.array([[center_x + arrow_height/2, center_y],\n",
"                                         [center_x - arrow_height/2, center_y-arrow_base_width/2],\n",
"                                         [center_x - arrow_height/2, center_y+arrow_base_width/2]])\n",
"        # Draw arrow pointing downward\n",
"        if action ==2:\n",
"            triangle_indices = np.array([[center_x, center_y+arrow_height/2],\n",
"                                         [center_x - arrow_base_width/2, center_y-arrow_height/2],\n",
"                                         [center_x + arrow_base_width/2, center_y-arrow_height/2]])\n",
"        # Draw arrow pointing left\n",
"        if action ==3:\n",
"            triangle_indices = np.array([[center_x - arrow_height/2, center_y],\n",
"                                         [center_x + arrow_height/2, center_y-arrow_base_width/2],\n",
"                                         [center_x + arrow_height/2, center_y+arrow_base_width/2]])\n",
"        self.ax.fill(triangle_indices[:,0], triangle_indices[:,1],facecolor='cyan', edgecolor='darkcyan', linewidth=1)\n",
"\n",
"    # Draw stochastic policy\n",
"    def draw_stochastic_policy(self,i, action_probs):\n",
"        row = np.floor(i/self.n_col)\n",
"        col = i - row * self.n_col\n",
"        offset = 20\n",
"        # Draw arrow pointing upward\n",
"        center_x = 83 * col + 41\n",
"        center_y = 83 * row + 41 - offset\n",
"        arrow_base_width = 15 * action_probs[0]\n",
"        arrow_height = 20 * action_probs[0]\n",
"        triangle_indices = np.array([[center_x, center_y-arrow_height/2],\n",
"                                     [center_x - arrow_base_width/2, center_y+arrow_height/2],\n",
"                                     [center_x + arrow_base_width/2, center_y+arrow_height/2]])\n",
"        self.ax.fill(triangle_indices[:,0], triangle_indices[:,1],facecolor='cyan', edgecolor='darkcyan', linewidth=1)\n",
"\n",
"        # Draw arrow pointing right\n",
"        center_x = 83 * col + 41 + offset\n",
"        center_y = 83 * row + 41\n",
"        arrow_base_width = 15 * action_probs[1]\n",
"        arrow_height = 20 * action_probs[1]\n",
"        triangle_indices = np.array([[center_x + arrow_height/2, center_y],\n",
"                                     [center_x - arrow_height/2, center_y-arrow_base_width/2],\n",
"                                     [center_x - arrow_height/2, center_y+arrow_base_width/2]])\n",
"        self.ax.fill(triangle_indices[:,0], triangle_indices[:,1],facecolor='cyan', edgecolor='darkcyan', linewidth=1)\n",
"\n",
"        # Draw arrow pointing downward\n",
"        center_x = 83 * col + 41\n",
"        center_y = 83 * row + 41 +offset\n",
"        arrow_base_width = 15 * action_probs[2]\n",
"        arrow_height = 20 * action_probs[2]\n",
"        triangle_indices = np.array([[center_x, center_y+arrow_height/2],\n",
"                                     [center_x - arrow_base_width/2, center_y-arrow_height/2],\n",
"                                     [center_x + arrow_base_width/2, center_y-arrow_height/2]])\n",
"        self.ax.fill(triangle_indices[:,0], triangle_indices[:,1],facecolor='cyan', edgecolor='darkcyan', linewidth=1)\n",
"\n",
"        # Draw arrow pointing left\n",
"        center_x = 83 * col + 41 -offset\n",
"        center_y = 83 * row + 41\n",
"        arrow_base_width = 15 * action_probs[3]\n",
"        arrow_height = 20 * action_probs[3]\n",
"        triangle_indices = np.array([[center_x - arrow_height/2, center_y],\n",
"                                     [center_x + arrow_height/2, center_y-arrow_base_width/2],\n",
"                                     [center_x + arrow_height/2, center_y+arrow_base_width/2]])\n",
"        self.ax.fill(triangle_indices[:,0], triangle_indices[:,1],facecolor='cyan', edgecolor='darkcyan', linewidth=1)\n",
"\n",
"\n",
"    def draw(self, layout, state=None, draw_state_index= False, rewards=None, policy=None, state_values=None, action_values=None,path1=None, path2 = None):\n",
"        # Construct the image\n",
"        image_out = np.zeros((self.n_row * 83, self.n_col * 83, 4),dtype='uint8')\n",
"        for c_row in range (self.n_row):\n",
"            for c_col in range(self.n_col):\n",
"                if layout[c_row * self.n_col + c_col]==0:\n",
"                    image_out[c_row*83:c_row*83+83, c_col*83:c_col*83+83,:] = self.empty_image\n",
"                elif layout[c_row * self.n_col + c_col]==1:\n",
"                    image_out[c_row*83:c_row*83+83, c_col*83:c_col*83+83,:] = self.hole_image\n",
"                else:\n",
"                    image_out[c_row*83:c_row*83+83, c_col*83:c_col*83+83,:] = self.fish_image\n",
"                if state is not None and state == c_row * self.n_col + c_col:\n",
"                    image_out[c_row*83:c_row*83+83, c_col*83:c_col*83+83,:] = self.penguin_image\n",
"\n",
"        # Draw the image\n",
"        plt.imshow(image_out)\n",
"        self.ax.get_xaxis().set_visible(False)\n",
"        self.ax.get_yaxis().set_visible(False)\n",
"        self.ax.spines['top'].set_visible(False)\n",
"        self.ax.spines['right'].set_visible(False)\n",
"        self.ax.spines['bottom'].set_visible(False)\n",
"        self.ax.spines['left'].set_visible(False)\n",
"\n",
"        if draw_state_index:\n",
"            for c_cell in range(layout.size):\n",
"                self.draw_text(\"%d\"%(c_cell), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'tl','k')\n",
"\n",
"        # Draw the policy as triangles\n",
"        if policy is not None:\n",
"            # If the policy is deterministic\n",
"            if len(policy) == len(layout):\n",
"                for i in range(len(layout)):\n",
"                    self.draw_deterministic_policy(i, policy[i])\n",
"            # Else it is stochastic\n",
"            else:\n",
"                for i in range(len(layout)):\n",
"                    self.draw_stochastic_policy(i,policy[:,i])\n",
"\n",
"\n",
"        if path1 is not None:\n",
"            self.draw_path(path1, np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 1.0]))\n",
"\n",
"        if rewards is not None:\n",
"            for c_cell in range(layout.size):\n",
"                self.draw_text(\"%d\"%(rewards[c_cell]), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'tr','r')\n",
"\n",
"        if state_values is not None:\n",
"            for c_cell in range(layout.size):\n",
"                self.draw_text(\"%2.2f\"%(state_values[c_cell]), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'bc','hotpink')\n",
"\n",
"\n",
"        plt.show()"
],
"metadata": {
"id": "Gq1HfJsHN3SB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# We're going to work on the problem depicted in figure 19.10a\n",
"n_rows = 4; n_cols = 4\n",
"layout = np.zeros(n_rows * n_cols)\n",
"rewards = np.zeros(n_rows * n_cols)\n",
"layout[9] = 1 ; rewards[9] = -2\n",
"layout[10] = 1; rewards[10] = -2\n",
"layout[14] = 1; rewards[14] = -2\n",
"layout[15] = 2; rewards[15] = 3\n",
"initial_state = 0\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, state = initial_state, rewards=rewards, draw_state_index = True)"
],
"metadata": {
"id": "eBQ7lTpJQBSe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"For clarity, the black numbers are the state indices and the red numbers are the rewards for being in each state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater."
],
"metadata": {
"id": "6Vku6v_se2IG"
}
},
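{
"cell_type": "markdown",
"source": [
"As a small illustration of the zero-based indexing (a sketch added here, not part of the book's code): assuming the row-major numbering that the drawing class uses, where the row is `floor(state / n_cols)`, a state index can be converted to a grid position with `divmod`. The helper name `state_to_row_col` is hypothetical."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical helper (not from the book): convert a zero-based state index\n",
"# into (row, column), assuming row-major numbering as in the drawing code\n",
"def state_to_row_col(state, n_cols):\n",
"    return divmod(state, n_cols)\n",
"\n",
"# The fish is at state 15, which on the 4x4 grid is the bottom-right cell\n",
"print(state_to_row_col(15, 4))  # (3, 3)\n",
"# One of the holes is at state 9\n",
"print(state_to_row_col(9, 4))   # (2, 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},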
{
"cell_type": "markdown",
"source": [
"Define a step from the Markov process"
],
"metadata": {
"id": "axllRDDuDDLS"
}
},
{
"cell_type": "markdown",
"source": [
"Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full, where $a$ is the action. Here $a=0$ means try to go up, $a=1$ right, $a=2$ down, and $a=3$ left. However, the ice is slippery, so we don't always go in the direction we want to.\n",
"\n",
"Note that, as for the states, we've indexed the actions from zero (unlike in the book) so that they map better to array indices."
],
"metadata": {
"id": "Fhc6DzZNOjiC"
}
},
{
"cell_type": "code",
"source": [
"transition_probabilities_given_action0 = np.array(\\\n",
"[[0.00 , 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.50 , 0.00, 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.33, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.34, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.34, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.75, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.75 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n",
"])\n",
"\n",
"transition_probabilities_given_action1 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.75 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.50, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.25 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.25, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.75, 0.00, 0.25, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00 ],\n",
"])\n",
"\n",
"transition_probabilities_given_action2 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.25 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.25, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.75 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.75, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.33, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00, 0.50 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00 ],\n",
"])\n",
"\n",
"transition_probabilities_given_action3 = np.array(\\\n",
"[[0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.50 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.50, 0.00, 0.75, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.50 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50, 0.00 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.75 ],\n",
" [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n",
"])\n",
"\n",
"# Store all of these in a three-dimensional array\n",
"# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3) is stored at position [2,1,3]\n",
"transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n",
" np.expand_dims(transition_probabilities_given_action1,2),\n",
" np.expand_dims(transition_probabilities_given_action2,2),\n",
" np.expand_dims(transition_probabilities_given_action3,2)),axis=2)"
],
"metadata": {
"id": "l7rT78BbOgTi"
},
"execution_count": null,
"outputs": []
},
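{
"cell_type": "markdown",
"source": [
"The indexing convention can be checked directly. The cell below is a sketch rather than part of the original exercise: it looks up one entry of the array defined above and sums over the next-state axis. If each column is a probability distribution, these sums should be close to 1 for every state-action pair (the entries are rounded to two decimals). The `action_names` list is introduced here only for readability."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumes the previous cell has been run, so that\n",
"# transition_probabilities_given_action (shape 16 x 16 x 4) exists\n",
"action_names = ['up', 'right', 'down', 'left']  # labels added for illustration\n",
"\n",
"# Probability of moving from state 0 to state 1 when trying to go right (a=1)\n",
"print(transition_probabilities_given_action[1, 0, 1])  # 0.75\n",
"\n",
"# Sum over next states (axis 0) for every (state, action) pair\n",
"column_sums = transition_probabilities_given_action.sum(axis=0)\n",
"for a, name in enumerate(action_names):\n",
"    print(name, np.round(column_sums[:, a], 2))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},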
{
"cell_type": "code",
"source": [
"# Update the state values for the current policy, by making the values at adjacent\n",
"# states compatible with the Bellman equation (equation 19.11)\n",
"def policy_evaluation(policy, state_values, rewards, transition_probabilities_given_action, gamma):\n",
"\n",
"    n_state = len(state_values)\n",
"    state_values_new = np.zeros_like(state_values)\n",
"\n",
"    for state in range(n_state):\n",
"        # Special case -- bottom right is terminating state, always just rewards 3.0\n",
"        if state == 15:\n",
"            state_values_new[state] = 3.0\n",
"            break\n",
"\n",
"        # TODO -- Write the update for the remaining states (from equation 19.11)\n",
"        # Replace this line\n",
"        state_values_new[state] = 0\n",
"\n",
"    return state_values_new\n",
"\n",
"# Greedily choose the action that maximizes the value for each state.\n",
"def policy_improvement(state_values, rewards, transition_probabilities_given_action, gamma):\n",
"    policy = np.zeros_like(state_values, dtype='uint8')\n",
"    for state in range(15):\n",
"        # TODO -- Write this function (from equation 19.12)\n",
"        # Replace this line\n",
"        policy[state] = 1\n",
"\n",
"\n",
"    return policy\n",
"\n",
"\n",
"# Main routine -- alternately call policy evaluation and policy improvement\n",
"def dynamic_programming(policy, state_values, rewards, transition_probabilities_given_action, gamma, n_iter, verbose = False):\n",
"\n",
"    for c_iter in range(n_iter):\n",
"        print(\"Iteration %d\"%(c_iter))\n",
"\n",
"        state_values = policy_evaluation(policy, state_values, rewards, transition_probabilities_given_action, gamma)\n",
"\n",
"        if verbose:\n",
"            print(\"Updated state values\")\n",
"            print(\"Policy: \", policy)\n",
"            print(\"State values:\", state_values)\n",
"            mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"            mdp_drawer.draw(layout, policy = policy, state_values=state_values)\n",
"\n",
"        policy = policy_improvement(state_values, rewards, transition_probabilities_given_action, gamma)\n",
"\n",
"        if verbose:\n",
"            print(\"Updated policy values\")\n",
"            print(\"Policy:\", policy)\n",
"            print(\"State_values\", state_values)\n",
"            mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"            mdp_drawer.draw(layout, policy = policy, state_values=state_values)\n",
"\n",
"    return policy, state_values\n"
],
"metadata": {
"id": "bFYvF9nAloIA"
},
"execution_count": null,
"outputs": []
},
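{
"cell_type": "markdown",
"source": [
"To make the two steps concrete, here is a minimal sketch on a made-up three-state chain (not the penguin problem, and not the intended solution to the TODO items). It assumes the same conventions as the cells above: the reward depends only on the current state, the transition array is indexed as `[next_state, state, action]`, and the last state is terminal with its value fixed to its reward. All names prefixed with `toy_` are hypothetical."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Toy problem: states 0, 1, 2 in a row; state 2 is terminal with reward 1\n",
"# Actions: 0 = stay, 1 = move right (deterministic, for simplicity)\n",
"toy_rewards = np.array([0.0, 0.0, 1.0])\n",
"toy_transitions = np.zeros((3, 3, 2))\n",
"for s in range(3):\n",
"    toy_transitions[s, s, 0] = 1.0              # stay where you are\n",
"    toy_transitions[min(s + 1, 2), s, 1] = 1.0  # step right (clipped at the end)\n",
"\n",
"def toy_policy_evaluation(policy, values, rewards, transitions, gamma):\n",
"    # One sweep of v[s] = r[s] + gamma * sum over s' of Pr(s'|s, policy[s]) * v[s']\n",
"    new_values = np.zeros_like(values)\n",
"    terminal = len(values) - 1\n",
"    new_values[terminal] = rewards[terminal]\n",
"    for s in range(terminal):\n",
"        new_values[s] = rewards[s] + gamma * np.sum(transitions[:, s, policy[s]] * values)\n",
"    return new_values\n",
"\n",
"def toy_policy_improvement(values, rewards, transitions, gamma):\n",
"    # Greedy step: pick the action with the largest one-step lookahead value\n",
"    policy = np.zeros(len(values), dtype='uint8')\n",
"    for s in range(len(values) - 1):\n",
"        lookahead = rewards[s] + gamma * (transitions[:, s, :].T @ values)\n",
"        policy[s] = np.argmax(lookahead)\n",
"    return policy\n",
"\n",
"toy_policy = np.zeros(3, dtype='uint8')  # start by always staying put\n",
"toy_values = np.zeros(3)\n",
"for _ in range(3):\n",
"    toy_values = toy_policy_evaluation(toy_policy, toy_values, toy_rewards, toy_transitions, 0.9)\n",
"    toy_policy = toy_policy_improvement(toy_values, toy_rewards, toy_transitions, 0.9)\n",
"\n",
"# After three iterations the policy is 'move right' (action 1) in states 0 and 1\n",
"print(toy_policy, toy_values)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},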
{
"cell_type": "code",
"source": [
"# Set the seed so the random numbers are always the same\n",
"np.random.seed(0)\n",
"\n",
"# Let's start by setting the policy randomly\n",
"policy = np.random.choice(size= n_rows * n_cols, a=np.arange(0,4,1))\n",
"state_values = np.zeros(n_rows* n_cols)\n",
"\n",
"# Let's draw the policy first\n",
"print(\"Initial state\")\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, rewards = rewards, state_values=state_values, draw_state_index = True)\n",
"\n",
"n_iter = 2\n",
"gamma = 0.9\n",
"policy, state_values = dynamic_programming(policy, state_values, rewards, transition_probabilities_given_action, gamma, n_iter, verbose=True)\n",
"\n",
"print(\"Your state values=\", state_values)\n",
"print(\"True values= [ 0. 0. 0. 0. 0. -0.288 -0.288 0. -0.45 -2.288 -2.594 0.9 0. -0.9 -1.1 3. ] \", )"
],
"metadata": {
"id": "8jWhDlkaKj7Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's run it for a series of iterations without drawing."
],
"metadata": {
"id": "wdcecFKlx97N"
}
},
{
"cell_type": "code",
"source": [
"# Set the seed so the random numbers are always the same\n",
"np.random.seed(0)\n",
"\n",
"# Let's start by setting the policy randomly\n",
"policy = np.random.choice(size= n_rows * n_cols, a=np.arange(0,4,1))\n",
"state_values = np.zeros(n_rows* n_cols)\n",
"\n",
"# Let's draw the policy first\n",
"print(\"Initial state\")\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, rewards = rewards, state_values=state_values, draw_state_index = True)\n",
"\n",
"n_iter = 20\n",
"gamma = 0.9\n",
"policy, state_values = dynamic_programming(policy, state_values, rewards, transition_probabilities_given_action, gamma, n_iter, verbose=False)\n",
"mdp_drawer = DrawMDP(n_rows, n_cols)\n",
"mdp_drawer.draw(layout, policy = policy, rewards = rewards, state_values=state_values, draw_state_index = True)\n",
"\n"
],
"metadata": {
"id": "rtsLUwi6ZEWL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"You should see that if we start at state 13, the actions have been selected to go all the way around the holes in the ice (keeping a wide berth to avoid slipping into them) and eventually converge on the fish."
],
"metadata": {
"id": "tvXOs9VhyWnh"
}
}
]
}