From 22d5bc320f23537c9b89f0ddcd338c51e0b4ebef Mon Sep 17 00:00:00 2001 From: udlbook <110402648+udlbook@users.noreply.github.com> Date: Mon, 4 Mar 2024 10:06:34 -0500 Subject: [PATCH] Created using Colaboratory --- .../19_4_Temporal_Difference_Methods.ipynb | 621 ++++++++++++------ 1 file changed, 431 insertions(+), 190 deletions(-) diff --git a/Notebooks/Chap19/19_4_Temporal_Difference_Methods.ipynb b/Notebooks/Chap19/19_4_Temporal_Difference_Methods.ipynb index 7135aec..d85443d 100644 --- a/Notebooks/Chap19/19_4_Temporal_Difference_Methods.ipynb +++ b/Notebooks/Chap19/19_4_Temporal_Difference_Methods.ipynb @@ -1,20 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "authorship_tag": "ABX9TyNEAhORON7DFN1dZMhDK/PO", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", @@ -28,6 +12,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "t9vk9Elugvmi" + }, "source": [ "# **Notebook 19.4: Temporal difference methods**\n", "\n", @@ -35,42 +22,49 @@ "\n", "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "\n", - "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." - ], - "metadata": { - "id": "t9vk9Elugvmi" - } + "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n", + "\n", + "Thanks to [Akshil Patel](https://www.akshilpatel.com) and [Jessica Nicholson](https://jessicanicholson1.github.io) for their help in preparing this notebook." + ] }, { "cell_type": "code", - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from PIL import Image" - ], + "execution_count": null, "metadata": { "id": "OLComQyvCIJ7" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "from IPython.display import clear_output\n", + "from time import sleep\n", + "from copy import deepcopy" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZsvrUszPLyEG" + }, + "outputs": [], "source": [ "# Get local copies of components of images\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Empty.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Hole.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Fish.png\n", "!wget https://raw.githubusercontent.com/udlbook/udlbook/main/Notebooks/Chap19/Penguin.png" - ], - "metadata": { - "id": "ZsvrUszPLyEG" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Gq1HfJsHN3SB" + }, + "outputs": [], "source": [ "# Ugly class that takes care of drawing pictures like in the book.\n", "# You can totally ignore this code!\n", @@ -253,269 +247,516 @@ " self.draw_text(\"%2.2f\"%(state_action_values[3, c_cell]), np.floor(c_cell/self.n_col), c_cell-np.floor(c_cell/self.n_col)*self.n_col,'lc','black')\n", "\n", " plt.show()" - ], + ] + }, + { + "cell_type": "markdown", "metadata": { - "id": "Gq1HfJsHN3SB" + "id": "JU8gX59o76xM" }, - "execution_count": null, - "outputs": [] + "source": [ + "# Penguin Ice Environment\n", + "\n", + "In this implementation we have designed an icy gridworld that a penguin has to traverse to reach the fish found in the bottom right corner.\n", + "\n", + "## Environment Description\n", + "\n", + "Consider having to cross an icy surface to reach the yummy fish. In order to achieve this task as quickly as possible, the penguin needs to waddle along as fast as it can whilst simultaneously avoiding falling into the holes.\n", + "\n", + "In this icy environment the penguin is at one of the discrete cells in the gridworld. The agent starts each episode on a randomly chosen cell. The environment state dynamics are captured by the transition probabilities $Pr(s_{t+1} |s_t, a_t)$ where $s_t$ is the current state, $a_t$ is the action chosen, and $s_{t+1}$ is the next state at decision stage t. At each decision stage, the penguin can move in one of four directions: $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ left.\n", + "\n", + "However, the ice is slippery, so we don't always go the direction we want to: every time the agent chooses an action, with 0.25 probability, the environment changes the action taken to a differenct action, which is uniformly sampled from the other available actions.\n", + "\n", + "The rewards are deterministic; the penguin will receive a reward of +3 if it reaches the fish, -2 if it slips into a hole and 0 otherwise.\n", + "\n", + "Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eBQ7lTpJQBSe" + }, + "outputs": [], "source": [ "# We're going to work on the problem depicted in figure 19.10a\n", "n_rows = 4; n_cols = 4\n", "layout = np.zeros(n_rows * n_cols)\n", "reward_structure = np.zeros(n_rows * n_cols)\n", - "layout[9] = 1 ; reward_structure[9] = -2\n", - "layout[10] = 1; reward_structure[10] = -2\n", - "layout[14] = 1; reward_structure[14] = -2\n", - "layout[15] = 2; reward_structure[15] = 3\n", + "layout[9] = 1 ; reward_structure[9] = -2 # Hole\n", + "layout[10] = 1; reward_structure[10] = -2 # Hole\n", + "layout[14] = 1; reward_structure[14] = -2 # Hole\n", + "layout[15] = 2; reward_structure[15] = 3 # Fish\n", "initial_state = 0\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer.draw(layout, state = initial_state, rewards=reward_structure, draw_state_index = True)" - ], - "metadata": { - "id": "eBQ7lTpJQBSe" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater." - ], "metadata": { "id": "6Vku6v_se2IG" - } + }, + "source": [ + "For clarity, the black numbers are the state number and the red numbers are the reward for being in that state. Note that the states are indexed from 0 rather than 1 as in the book to make the code neater." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Fhc6DzZNOjiC" + }, "source": [ "Now let's define the state transition function $Pr(s_{t+1}|s_{t},a)$ in full where $a$ is the actions. Here $a=0$ means try to go upward, $a=1$, right, $a=2$ down and $a=3$ right. However, the ice is slippery, so we don't always go the direction we want to.\n", "\n", "Note that as for the states, we've indexed the actions from zero (unlike in the book) so they map to the indices of arrays better" - ], - "metadata": { - "id": "Fhc6DzZNOjiC" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wROjgnqh76xN" + }, + "outputs": [], "source": [ "transition_probabilities_given_action0 = np.array(\\\n", - "[[0.00 , 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.50 , 0.00, 0.33, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.33, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.50 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.34, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.34, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.75, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.25, 0.00, 0.00, 0.50, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.75 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.25 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", - "])\n", + "[[0.90, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.85, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.90, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n", + "\n", "\n", "transition_probabilities_given_action1 = np.array(\\\n", - "[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.75 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.50, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.25 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.25, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.25, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00, 0.00, 0.25, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.75, 0.00, 0.25, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.50, 0.00 ],\n", - "])\n", + "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.85, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.85, 0.90, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.10, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.85, 0.00]])\n", + "\n", "\n", "transition_probabilities_given_action2 = np.array(\\\n", - "[[0.00 , 0.25, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.25 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.25, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.75 , 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.75, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.17, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.17, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.25, 0.00, 0.00, 0.33, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.16, 0.00, 0.00, 0.00, 0.00, 0.50 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.33, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.33, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00, 0.50 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.34, 0.00 ],\n", - "])\n", + "[[0.10, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.00, 0.90, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.85, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00]])\n", "\n", "transition_probabilities_given_action3 = np.array(\\\n", - "[[0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.50 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.50, 0.00, 0.75, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.50, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.50 , 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.25, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.33, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.17, 0.00, 0.50, 0.00, 0.00, 0.25, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.17, 0.00, 0.00, 0.00, 0.00, 0.25 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.50, 0.00, 0.50, 0.00 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.16, 0.00, 0.00, 0.25, 0.00, 0.75 ],\n", - " [0.00 , 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.00, 0.00, 0.25, 0.00 ],\n", - "])\n", + "[[0.90, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.05, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.10, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.85, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00, 0.85, 0.00, 0.00, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.00, 0.90, 0.85, 0.00, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.85, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.05, 0.00],\n", + " [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.00, 0.05, 0.00]])\n", + "\n", + "\n", "\n", "# Store all of these in a three dimension array\n", "# Pr(s_{t+1}=2|s_{t}=1, a_{t}=3] is stored at position [2,1,3]\n", "transition_probabilities_given_action = np.concatenate((np.expand_dims(transition_probabilities_given_action0,2),\n", " np.expand_dims(transition_probabilities_given_action1,2),\n", " np.expand_dims(transition_probabilities_given_action2,2),\n", - " np.expand_dims(transition_probabilities_given_action3,2)),axis=2)" - ], + " np.expand_dims(transition_probabilities_given_action3,2)),axis=2)\n", + "\n", + "print('Grid Size:', len(transition_probabilities_given_action[0]))\n", + "print()\n", + "print('Transition Probabilities for when next state = 2:')\n", + "print(transition_probabilities_given_action[2])\n", + "print()\n", + "print('Transitions Probabilities for when next state = 2 and current state = 1')\n", + "print(transition_probabilities_given_action[2][1])\n", + "print()\n", + "print('Transitions Probabilities for when next state = 2 and current state = 1 and action = 3 (Left):')\n", + "print(transition_probabilities_given_action[2][1][3])" + ] + }, + { + "cell_type": "markdown", "metadata": { - "id": "l7rT78BbOgTi" + "id": "eblSQ6xZ76xN" }, - "execution_count": null, - "outputs": [] + "source": [ + "## Implementation Details\n", + "\n", + "We provide the following methods:\n", + "- **`markov_decision_process_step`** - this function simulates $Pr(s_{t+1} | s_{t}, a_{t})$. It randomly selects an action, updates the state based on the transition probabilities associated with the chosen action, and returns the new state, the reward obtained for leaving the current state, and the chosen action. The randomness in action selection and state transitions reflects a random exploration process and the stochastic nature of the MDP, respectively.\n", + "\n", + "- **`get_policy`** - this function computes a policy that acts greedily with respect to the state-action values. The policy is computed for all states and the action that maximizes the state-action value is chosen for each state. When there are multiple optimal actions, one is chosen at random.\n", + "\n", + "\n", + "You have to implement the following method:\n", + "\n", + "- **`q_learning_step`** - this function implements a single step of the Q-learning algorithm for reinforcement learning as shown below. The update follows the Q-learning formula and is controlled by parameters such as the learning rate (alpha) and the discount factor $(\\gamma)$. The function returns the updated state-action values matrix.\n", + "\n", + "$Q(s, a) \\leftarrow (1 - \\alpha) \\cdot Q(s, a) + \\alpha \\cdot \\left(r + \\gamma \\cdot \\max_{a'} Q(s', a')\\right)$" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cKLn4Iam76xN" + }, + "outputs": [], "source": [ - "def q_learning_step(state_action_values, reward, state, new_state, action, gamma, alpha = 0.1):\n", + "def get_policy(state_action_values):\n", + " policy = np.zeros(state_action_values.shape[1]) # One action for each state\n", + " for state in range(state_action_values.shape[1]):\n", + " # Break ties for maximising actions randomly\n", + " policy[state] = np.random.choice(np.flatnonzero(state_action_values[:, state] == max(state_action_values[:, state])))\n", + " return policy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "akjrncMF-FkU" + }, + "outputs": [], + "source": [ + "def markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states, action=None):\n", + " # Pick action\n", + " if action is None:\n", + " action = np.random.randint(4)\n", + " # Update the state\n", + " new_state = np.random.choice(a=range(transition_probabilities_given_action.shape[0]), p = transition_probabilities_given_action[:, state,action])\n", + "\n", + " # Return the reward -- here the reward is for arriving at the state\n", + " reward = reward_structure[new_state]\n", + " is_terminal = new_state in [terminal_states]\n", + "\n", + " return new_state, reward, action, is_terminal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5pO6-9ACWhiV" + }, + "outputs": [], + "source": [ + "def q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma, alpha = 0.1):\n", " # TODO -- write this function\n", " # Replace this line\n", " state_action_values_after = np.copy(state_action_values)\n", "\n", " return state_action_values_after" - ], - "metadata": { - "id": "5pO6-9ACWhiV" - }, - "execution_count": null, - "outputs": [] + ] }, { - "cell_type": "code", + "cell_type": "markdown", + "metadata": { + "id": "u4OHTTk176xO" + }, "source": [ - "# This takes a single step from an MDP which just has a completely random policy\n", - "def markov_decision_process_step(state, transition_probabilities_given_action, reward_structure):\n", - " # Pick action\n", - " action = np.random.randint(4)\n", - " # Update the state\n", - " new_state = np.random.choice(a=np.arange(0,transition_probabilities_given_action.shape[0]),p = transition_probabilities_given_action[:,state,action])\n", - " # Return the reward -- here the reward is for leaving the state\n", - " reward = reward_structure[state]\n", - "\n", - " return new_state, reward, action" - ], - "metadata": { - "id": "akjrncMF-FkU" - }, - "execution_count": null, - "outputs": [] + "Lets run this for a single Q-learning step" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Fu5_VjvbSwfJ" + }, + "outputs": [], "source": [ "# Initialize the state-action values to random numbers\n", "np.random.seed(0)\n", "n_state = transition_probabilities_given_action.shape[0]\n", "n_action = transition_probabilities_given_action.shape[2]\n", + "terminal_states=[15]\n", "state_action_values = np.random.normal(size=(n_action, n_state))\n", + "# Hard code value of termination state of finding fish to 0\n", + "state_action_values[:, terminal_states] = 0\n", "gamma = 0.9\n", "\n", - "policy = np.argmax(state_action_values, axis=0).astype(int)\n", + "policy = get_policy(state_action_values)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n", "\n", "# Now let's simulate a single Q-learning step\n", "initial_state = 9\n", - "print(\"Initial state = \", initial_state)\n", - "new_state, reward, action = markov_decision_process_step(initial_state, transition_probabilities_given_action, reward_structure)\n", - "print(\"Action = \", action)\n", - "print(\"New state = \", new_state)\n", - "print(\"Reward = \", reward)\n", + "print(\"Initial state =\",initial_state)\n", + "new_state, reward, action, is_terminal = markov_decision_process_step(initial_state, transition_probabilities_given_action, reward_structure, terminal_states)\n", + "print(\"Action =\",action)\n", + "print(\"New state =\",new_state)\n", + "print(\"Reward =\", reward)\n", "\n", - "state_action_values_after = q_learning_step(state_action_values, reward, initial_state, new_state, action, gamma)\n", + "state_action_values_after = q_learning_step(state_action_values, reward, initial_state, new_state, action, is_terminal, gamma)\n", "print(\"Your value:\",state_action_values_after[action, initial_state])\n", - "print(\"True value: 0.27650262412468796\")\n", + "print(\"True value: 0.3024718977397814\")\n", "\n", - "policy = np.argmax(state_action_values, axis=0).astype(int)\n", + "policy = get_policy(state_action_values)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values_after, rewards = reward_structure)\n" - ], - "metadata": { - "id": "Fu5_VjvbSwfJ" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "Now let's run this for a while and watch the policy improve" - ], "metadata": { "id": "Ogh0qucmb68J" - } + }, + "source": [ + "Now let's run this for a while (20000) steps and watch the policy improve" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "N6gFYifh76xO" + }, + "outputs": [], "source": [ "# Initialize the state-action values to random numbers\n", "np.random.seed(0)\n", - "n_state = transition_probabilities_given_action.shape[0]\n", + "n_state = transition_probabilities_given_action.shape[0]\n", "n_action = transition_probabilities_given_action.shape[2]\n", "state_action_values = np.random.normal(size=(n_action, n_state))\n", - "# Hard code termination state of finding fish\n", - "state_action_values[:,n_state-1] = 3.0\n", + "\n", + "# Hard code value of termination state of finding fish to 0\n", + "terminal_states = [15]\n", + "state_action_values[:, terminal_states] = 0\n", "gamma = 0.9\n", "\n", "# Draw the initial setup\n", - "policy = np.argmax(state_action_values, axis=0).astype(int)\n", + "print('Initial Policy:')\n", + "policy = get_policy(state_action_values)\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n", "\n", - "\n", - "state= np.random.randint(n_state-1)\n", + "state = np.random.randint(n_state-1)\n", "\n", "# Run for a number of iterations\n", - "for c_iter in range(10000):\n", - " new_state, reward, action = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure)\n", - " state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, gamma)\n", + "for c_iter in range(20000):\n", + " new_state, reward, action, is_terminal = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states)\n", + " state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma)\n", + "\n", " # If in termination state, reset state randomly\n", - " if new_state==15:\n", - " state= np.random.randint(n_state-1)\n", + " if is_terminal:\n", + " state = np.random.randint(n_state-1)\n", " else:\n", " state = new_state\n", - " # Update the policy\n", - " state_action_values = np.copy(state_action_values_after)\n", - " policy = np.argmax(state_action_values, axis=0).astype(int)\n", "\n", + " # Update the policy\n", + " state_action_values = deepcopy(state_action_values_after)\n", + " policy = get_policy(state_action_values_after)\n", + "\n", + "print('Final Optimal Policy:')\n", "# Draw the final situation\n", "mdp_drawer = DrawMDP(n_rows, n_cols)\n", "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)" - ], + ] + }, + { + "cell_type": "markdown", "metadata": { - "id": "qQFhwVqPcCFH" + "id": "djPTKuDk76xO" }, + "source": [ + "Finally, lets run this for a **single** episode and visualize the penguin's actions" + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": { + "id": "pWObQf2h76xO" + }, + "outputs": [], + "source": [ + "def get_one_episode(n_state, state_action_values, terminal_states, gamma):\n", + "\n", + " state = np.random.randint(n_state-1)\n", + "\n", + " # Create lists to store all the states seen and actions taken throughout the single episode\n", + " all_states = []\n", + " all_actions = []\n", + "\n", + " # Initalize episode termination flag\n", + " done = False\n", + " # Initialize counter for steps in the episode\n", + " steps = 0\n", + "\n", + " all_states.append(state)\n", + "\n", + " while not done:\n", + " steps += 1\n", + "\n", + " new_state, reward, action, is_terminal = markov_decision_process_step(state, transition_probabilities_given_action, reward_structure, terminal_states)\n", + " all_states.append(new_state)\n", + " all_actions.append(action)\n", + "\n", + " state_action_values_after = q_learning_step(state_action_values, reward, state, new_state, action, is_terminal, gamma)\n", + "\n", + " # If in termination state, reset state randomly\n", + " if is_terminal:\n", + " state = np.random.randint(n_state-1)\n", + " print(f'Episode Terminated at {steps} Steps')\n", + " # Set episode termination flag\n", + " done = True\n", + " else:\n", + " state = new_state\n", + "\n", + " # Update the policy\n", + " state_action_values = deepcopy(state_action_values_after)\n", + " policy = get_policy(state_action_values_after)\n", + "\n", + " return all_states, all_actions, policy, state_action_values\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P7cbCGT176xO" + }, + "outputs": [], + "source": [ + "def visualize_one_episode(states, actions):\n", + " # Define actions for visualization\n", + " acts = ['up', 'right', 'down', 'left']\n", + "\n", + " # Iterate over the states and actions\n", + " for i in range(len(states)):\n", + "\n", + " if i == 0:\n", + " print('Starting State:', states[i])\n", + "\n", + " elif i == len(states)-1:\n", + " print('Episode Done:', states[i])\n", + "\n", + " else:\n", + " print('State', states[i-1])\n", + " a = actions[i]\n", + " print('Action:', acts[a])\n", + " print('Next State:', states[i])\n", + "\n", + " # Visualize the current state using the MDP drawer\n", + " mdp_drawer.draw(layout, state=states[i], rewards=reward_structure, draw_state_index=True)\n", + " clear_output(True)\n", + "\n", + " # Pause for a short duration to allow observation\n", + " sleep(1.5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cr98F8PT76xP" + }, + "outputs": [], + "source": [ + "# Initialize the state-action values to random numbers\n", + "np.random.seed(2)\n", + "n_state = transition_probabilities_given_action.shape[0]\n", + "n_action = transition_probabilities_given_action.shape[2]\n", + "state_action_values = np.random.normal(size=(n_action, n_state))\n", + "\n", + "# Hard code value of termination state of finding fish to 0\n", + "terminal_states = [15]\n", + "state_action_values[:, terminal_states] = 0\n", + "gamma = 0.9\n", + "\n", + "# Draw the initial setup\n", + "print('Initial Policy:')\n", + "policy = get_policy(state_action_values)\n", + "mdp_drawer = DrawMDP(n_rows, n_cols)\n", + "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n", + "\n", + "states, actions, policy, state_action_values = get_one_episode(n_state, state_action_values, terminal_states, gamma)\n", + "\n", + "print()\n", + "print('Final Optimal Policy:')\n", + "mdp_drawer = DrawMDP(n_rows, n_cols)\n", + "mdp_drawer.draw(layout, policy = policy, state_action_values = state_action_values, rewards = reward_structure)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5zBu1g3776xP" + }, + "outputs": [], + "source": [ + "visualize_one_episode(states, actions)" + ] } - ] + ], + "metadata": { + "colab": { + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file