Add files via upload
This commit is contained in:
@@ -1,33 +1,22 @@
|
|||||||
{
|
{
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"provenance": [],
|
|
||||||
"authorship_tag": "ABX9TyOSEQVqxE5KrXmsZVh9M3gq",
|
|
||||||
"include_colab_link": true
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"name": "python3",
|
|
||||||
"display_name": "Python 3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "view-in-github",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "view-in-github"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_1_Latent_Variable_Models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_1_Latent_Variable_Models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "t9vk9Elugvmi"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# **Notebook 17.1: Latent variable models**\n",
|
"# **Notebook 17.1: Latent variable models**\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -36,72 +25,76 @@
|
|||||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "t9vk9Elugvmi"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "OLComQyvCIJ7"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"import scipy\n",
|
"import scipy\n",
|
||||||
"from matplotlib.colors import ListedColormap\n",
|
"from matplotlib.colors import ListedColormap\n",
|
||||||
"from matplotlib import cm"
|
"from matplotlib import cm"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "OLComQyvCIJ7"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "IyVn-Gi-p7wf"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"We'll assume that our base distribution over the latent variables is a 1D standard normal so that\n",
|
"We'll assume that our base distribution over the latent variables is a 1D standard normal so that\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"Pr(z) = \\mbox{Norm}_{z}[0,1]\n",
|
"Pr(z) = \\text{Norm}_{z}[0,1]\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"As in figure 17.2, we'll assume that the output is two dimensional, so we need to define a function that maps from the 1D latent variable to two dimensions. Usually, we would use a neural network, but in this case, we'll just define an arbitrary relationship.\n",
|
"As in figure 17.2, we'll assume that the output is two dimensional, so we need to define a function that maps from the 1D latent variable to two dimensions. Usually, we would use a neural network, but in this case, we'll just define an arbitrary relationship.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{eqnarray}\n",
|
"\\begin{align}\n",
|
||||||
"x_{1} &=& 0.5\\cdot\\exp\\Bigl[\\sin\\bigl[2+ 3.675 z \\bigr]\\Bigr]\\\\\n",
|
"x_{1} &= 0.5\\cdot\\exp\\Bigl[\\sin\\bigl[2+ 3.675 z \\bigr]\\Bigr]\\\\\n",
|
||||||
"x_{2} &=& \\cos\\bigl[2+ 2.85 z \\bigr]\n",
|
"x_{2} &= \\cos\\bigl[2+ 2.85 z \\bigr]\n",
|
||||||
"\\end{eqnarray}"
|
"\\end{align}"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "IyVn-Gi-p7wf"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "ZIfQwhd-AV6L"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# The function that maps z to x1 and x2\n",
|
"# The function that maps z to x1 and x2\n",
|
||||||
"def f(z):\n",
|
"def f(z):\n",
|
||||||
" x_1 = np.exp(np.sin(2+z*3.675)) * 0.5\n",
|
" x_1 = np.exp(np.sin(2+z*3.675)) * 0.5\n",
|
||||||
" x_2 = np.cos(2+z*2.85)\n",
|
" x_2 = np.cos(2+z*2.85)\n",
|
||||||
" return x_1, x_2"
|
" return x_1, x_2"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "ZIfQwhd-AV6L"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Let's plot the 3D relation between the two observed variables $x_{1}$ and $x_{2}$ and the latent variables $z$ as in figure 17.2 of the book. We'll use the opacity to represent the prior probability $Pr(z)$."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "KB9FU34onW1j"
|
"id": "KB9FU34onW1j"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Let's plot the 3D relation between the two observed variables $x_{1}$ and $x_{2}$ and the latent variables $z$ as in figure 17.2 of the book. We'll use the opacity to represent the prior probability $Pr(z)$."
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "lW08xqAgnP4q"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def draw_3d_projection(z,pr_z, x1,x2):\n",
|
"def draw_3d_projection(z,pr_z, x1,x2):\n",
|
||||||
" alpha = pr_z / np.max(pr_z)\n",
|
" alpha = pr_z / np.max(pr_z)\n",
|
||||||
@@ -118,28 +111,28 @@
|
|||||||
" ax.set_zlim(-1,1)\n",
|
" ax.set_zlim(-1,1)\n",
|
||||||
" ax.set_box_aspect((3,1,1))\n",
|
" ax.set_box_aspect((3,1,1))\n",
|
||||||
" plt.show()"
|
" plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "lW08xqAgnP4q"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "9DUTauMi6tPk"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Compute the prior\n",
|
"# Compute the prior\n",
|
||||||
"def get_prior(z):\n",
|
"def get_prior(z):\n",
|
||||||
" return scipy.stats.multivariate_normal.pdf(z)"
|
" return scipy.stats.multivariate_normal.pdf(z)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "9DUTauMi6tPk"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "PAzHq461VqvF"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Define the latent variable values\n",
|
"# Define the latent variable values\n",
|
||||||
"z = np.arange(-3.0,3.0,0.01)\n",
|
"z = np.arange(-3.0,3.0,0.01)\n",
|
||||||
@@ -149,40 +142,41 @@
|
|||||||
"x1,x2 = f(z)\n",
|
"x1,x2 = f(z)\n",
|
||||||
"# Plot the function\n",
|
"# Plot the function\n",
|
||||||
"draw_3d_projection(z,pr_z, x1,x2)"
|
"draw_3d_projection(z,pr_z, x1,x2)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "PAzHq461VqvF"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"The likelihood is defined as:\n",
|
|
||||||
"\\begin{eqnarray}\n",
|
|
||||||
" Pr(x_1,x_2|z) &=& \\mbox{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\n",
|
|
||||||
"\\end{eqnarray}\n",
|
|
||||||
"\n",
|
|
||||||
"so we will also need to define the noise level $\\sigma^2$"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "sQg2gKR5zMrF"
|
"id": "sQg2gKR5zMrF"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"The likelihood is defined as:\n",
|
||||||
|
"\\begin{align}\n",
|
||||||
|
" Pr(x_1,x_2|z) &= \\text{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\n",
|
||||||
|
"\\end{align}\n",
|
||||||
|
"\n",
|
||||||
|
"so we will also need to define the noise level $\\sigma^2$"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"execution_count": null,
|
||||||
"sigma_sq = 0.04"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "In_Vg4_0nva3"
|
"id": "In_Vg4_0nva3"
|
||||||
},
|
},
|
||||||
"execution_count": null,
|
"outputs": [],
|
||||||
"outputs": []
|
"source": [
|
||||||
|
"sigma_sq = 0.04"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "6P6d-AgAqxXZ"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Draws a heatmap to represent a probability distribution, possibly with samples overlaid\n",
|
"# Draws a heatmap to represent a probability distribution, possibly with samples overlaid\n",
|
||||||
"def plot_heatmap(x1_mesh,x2_mesh,y_mesh, x1_samples=None, x2_samples=None, title=None):\n",
|
"def plot_heatmap(x1_mesh,x2_mesh,y_mesh, x1_samples=None, x2_samples=None, title=None):\n",
|
||||||
@@ -207,15 +201,15 @@
|
|||||||
" ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')\n",
|
" ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')\n",
|
||||||
" plt.show()\n",
|
" plt.show()\n",
|
||||||
"\n"
|
"\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "6P6d-AgAqxXZ"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "diYKb7_ZgjlJ"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Returns the likelihood\n",
|
"# Returns the likelihood\n",
|
||||||
"def get_likelihood(x1_mesh, x2_mesh, z_val):\n",
|
"def get_likelihood(x1_mesh, x2_mesh, z_val):\n",
|
||||||
@@ -226,24 +220,25 @@
|
|||||||
" mn = scipy.stats.multivariate_normal([x1, x2], [[sigma_sq, 0], [0, sigma_sq]])\n",
|
" mn = scipy.stats.multivariate_normal([x1, x2], [[sigma_sq, 0], [0, sigma_sq]])\n",
|
||||||
" pr_x1_x2_given_z_val = mn.pdf(np.dstack((x1_mesh, x2_mesh)))\n",
|
" pr_x1_x2_given_z_val = mn.pdf(np.dstack((x1_mesh, x2_mesh)))\n",
|
||||||
" return pr_x1_x2_given_z_val"
|
" return pr_x1_x2_given_z_val"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "diYKb7_ZgjlJ"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Now let's plot the likelihood $Pr(x_1,x_2|z)$ as in fig 17.3b in the book."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "0X4NwixzqxtZ"
|
"id": "0X4NwixzqxtZ"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Now let's plot the likelihood $Pr(x_1,x_2|z)$ as in fig 17.3b in the book."
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "hWfqK-Oz5_DT"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Choose some z value\n",
|
"# Choose some z value\n",
|
||||||
"z_val = 1.8\n",
|
"z_val = 1.8\n",
|
||||||
@@ -256,30 +251,31 @@
|
|||||||
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2_given_z_val, title=\"Conditional distribution $Pr(x_1,x_2|z)$\")\n",
|
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2_given_z_val, title=\"Conditional distribution $Pr(x_1,x_2|z)$\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# TODO -- Experiment with different values of z and make sure that you understand what is happening."
|
"# TODO -- Experiment with different values of z and make sure that you understand what is happening."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "hWfqK-Oz5_DT"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "25xqXnmFo-PH"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"The data density is found by marginalizing over the latent variables $z$:\n",
|
"The data density is found by marginalizing over the latent variables $z$:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{eqnarray}\n",
|
"\\begin{align}\n",
|
||||||
" Pr(x_1,x_2) &=& \\int Pr(x_1,x_2, z) dz \\nonumber \\\\\n",
|
" Pr(x_1,x_2) &= \\int Pr(x_1,x_2, z) dz \\nonumber \\\\\n",
|
||||||
" &=& \\int Pr(x_1,x_2 | z) \\cdot Pr(z)dz\\nonumber \\\\\n",
|
" &= \\int Pr(x_1,x_2 | z) \\cdot Pr(z)dz\\nonumber \\\\\n",
|
||||||
" &=& \\int \\mbox{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\\cdot \\mbox{Norm}_{z}\\left[\\mathbf{0},\\mathbf{I}\\right]dz.\n",
|
" &= \\int \\text{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\\cdot \\text{Norm}_{z}\\left[\\mathbf{0},\\mathbf{I}\\right]dz.\n",
|
||||||
"\\end{eqnarray}"
|
"\\end{align}"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "25xqXnmFo-PH"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "H0Ijce9VzeCO"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# TODO Compute the data density\n",
|
"# TODO Compute the data density\n",
|
||||||
"# We can't integrate this function in closed form\n",
|
"# We can't integrate this function in closed form\n",
|
||||||
@@ -293,24 +289,25 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Plot the result\n",
|
"# Plot the result\n",
|
||||||
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, title=\"Data density $Pr(x_1,x_2)$\")\n"
|
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, title=\"Data density $Pr(x_1,x_2)$\")\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "H0Ijce9VzeCO"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Now let's draw some samples from the model"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "W264N7By_h9y"
|
"id": "W264N7By_h9y"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Now let's draw some samples from the model"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "Li3mK_I48k0k"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def draw_samples(n_sample):\n",
|
"def draw_samples(n_sample):\n",
|
||||||
" # TODO Write this routine to draw n_sample samples from the model\n",
|
" # TODO Write this routine to draw n_sample samples from the model\n",
|
||||||
@@ -320,37 +317,38 @@
|
|||||||
" x1_samples=0; x2_samples = 0;\n",
|
" x1_samples=0; x2_samples = 0;\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return x1_samples, x2_samples"
|
" return x1_samples, x2_samples"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "Li3mK_I48k0k"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Let's plot those samples on top of the heat map."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "D7N7oqLe-eJO"
|
"id": "D7N7oqLe-eJO"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Let's plot those samples on top of the heat map."
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "XRmWv99B-BWO"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"x1_samples, x2_samples = draw_samples(500)\n",
|
"x1_samples, x2_samples = draw_samples(500)\n",
|
||||||
"# Plot the result\n",
|
"# Plot the result\n",
|
||||||
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, x1_samples, x2_samples, title=\"Data density $Pr(x_1,x_2)$\")\n"
|
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, x1_samples, x2_samples, title=\"Data density $Pr(x_1,x_2)$\")\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "XRmWv99B-BWO"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "PwOjzPD5_1OF"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Return the posterior distribution\n",
|
"# Return the posterior distribution\n",
|
||||||
"def get_posterior(x1,x2):\n",
|
"def get_posterior(x1,x2):\n",
|
||||||
@@ -364,15 +362,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return z, pr_z_given_x1_x2"
|
" return z, pr_z_given_x1_x2"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "PwOjzPD5_1OF"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "PKFUY42K-Tp7"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"x1 = 0.9; x2 = -0.9\n",
|
"x1 = 0.9; x2 = -0.9\n",
|
||||||
"z, pr_z_given_x1_x2 = get_posterior(x1,x2)\n",
|
"z, pr_z_given_x1_x2 = get_posterior(x1,x2)\n",
|
||||||
@@ -385,12 +383,23 @@
|
|||||||
"ax.set_xlim([-3,3])\n",
|
"ax.set_xlim([-3,3])\n",
|
||||||
"ax.set_ylim([0,1.5 * np.max(pr_z_given_x1_x2)])\n",
|
"ax.set_ylim([0,1.5 * np.max(pr_z_given_x1_x2)])\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "PKFUY42K-Tp7"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"authorship_tag": "ABX9TyOSEQVqxE5KrXmsZVh9M3gq",
|
||||||
|
"include_colab_link": true,
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
@@ -1,33 +1,22 @@
|
|||||||
{
|
{
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"provenance": [],
|
|
||||||
"authorship_tag": "ABX9TyOxO2/0DTH4n4zhC97qbagY",
|
|
||||||
"include_colab_link": true
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"name": "python3",
|
|
||||||
"display_name": "Python 3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "view-in-github",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "view-in-github"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "t9vk9Elugvmi"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# **Notebook 17.2: Reparameterization trick**\n",
|
"# **Notebook 17.2: Reparameterization trick**\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -36,30 +25,31 @@
|
|||||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "t9vk9Elugvmi"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"execution_count": null,
|
||||||
"import numpy as np\n",
|
|
||||||
"import matplotlib.pyplot as plt"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "OLComQyvCIJ7"
|
"id": "OLComQyvCIJ7"
|
||||||
},
|
},
|
||||||
"execution_count": null,
|
"outputs": [],
|
||||||
"outputs": []
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib.pyplot as plt"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "paLz5RukZP1J"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"The reparameterization trick computes the derivative of an expectation of a function $\\mbox{f}[x]$:\n",
|
"The reparameterization trick computes the derivative of an expectation of a function $\\text{f}[x]$:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr],\n",
|
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr],\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over.\n",
|
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over.\n",
|
||||||
@@ -67,21 +57,23 @@
|
|||||||
"Let's consider a simple concrete example, where:\n",
|
"Let's consider a simple concrete example, where:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"Pr(x|\\phi) = \\mbox{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\mbox{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr]\n",
|
"Pr(x|\\phi) = \\text{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\text{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr]\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"and\n",
|
"and\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\mbox{f}[x] = x^2+\\sin[x]\n",
|
"\\text{f}[x] = x^2+\\sin[x]\n",
|
||||||
"\\end{equation}"
|
"\\end{equation}"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "paLz5RukZP1J"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "FdEbMnDBY0i9"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Let's approximate this expectation for a particular value of phi\n",
|
"# Let's approximate this expectation for a particular value of phi\n",
|
||||||
"def compute_expectation(phi, n_samples):\n",
|
"def compute_expectation(phi, n_samples):\n",
|
||||||
@@ -96,15 +88,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return expected_f_given_phi"
|
" return expected_f_given_phi"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "FdEbMnDBY0i9"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "FTh7LJ0llNJZ"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -119,24 +111,25 @@
|
|||||||
"n_samples = 10000000\n",
|
"n_samples = 10000000\n",
|
||||||
"expected_f_given_phi2 = compute_expectation(phi2, n_samples)\n",
|
"expected_f_given_phi2 = compute_expectation(phi2, n_samples)\n",
|
||||||
"print(\"Your value: \", expected_f_given_phi2, \", True value: 0.8176793102849222\")"
|
"print(\"Your value: \", expected_f_given_phi2, \", True value: 0.8176793102849222\")"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "FTh7LJ0llNJZ"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"Let's plot this expectation as a function of phi"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "r5Hl2QkimWx9"
|
"id": "r5Hl2QkimWx9"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"Let's plot this expectation as a function of phi"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "05XxVLJxmkER"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
||||||
"expected_vals = np.zeros_like(phi_vals)\n",
|
"expected_vals = np.zeros_like(phi_vals)\n",
|
||||||
@@ -149,15 +142,14 @@
|
|||||||
"ax.set_xlabel('Parameter $\\phi$')\n",
|
"ax.set_xlabel('Parameter $\\phi$')\n",
|
||||||
"ax.set_ylabel('$\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
"ax.set_ylabel('$\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "05XxVLJxmkER"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "zTCykVeWqj_O"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"It's this curve that we want to find the derivative of (so for example, we could run gradient descent and find the minimum).\n",
|
"It's this curve that we want to find the derivative of (so for example, we could run gradient descent and find the minimum).\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -166,28 +158,30 @@
|
|||||||
"The answer is the reparameterization trick. We note that:\n",
|
"The answer is the reparameterization trick. We note that:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\mbox{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\mbox{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\sigma + \\mu\n",
|
"\\text{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\text{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\sigma + \\mu\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"and so:\n",
|
"and so:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\mbox{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr] = \\mbox{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\exp[\\phi]+ \\phi^3\n",
|
"\\text{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr] = \\text{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\exp[\\phi]+ \\phi^3\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"So, if we draw a sample $\\epsilon^*$ from $\\mbox{Norm}_{\\epsilon}[0, 1]$, then we can compute a sample $x^*$ as:\n",
|
"So, if we draw a sample $\\epsilon^*$ from $\\text{Norm}_{\\epsilon}[0, 1]$, then we can compute a sample $x^*$ as:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{eqnarray*}\n",
|
"\\begin{align}\n",
|
||||||
"x^* &=& \\epsilon^* \\times \\sigma + \\mu \\\\\n",
|
"x^* &= \\epsilon^* \\times \\sigma + \\mu \\\\\n",
|
||||||
"&=& \\epsilon^* \\times \\exp[\\phi]+ \\phi^3\n",
|
"&= \\epsilon^* \\times \\exp[\\phi]+ \\phi^3\n",
|
||||||
"\\end{eqnarray*}"
|
"\\end{align}"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "zTCykVeWqj_O"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "w13HVpi9q8nF"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_df_dx_star(x_star):\n",
|
"def compute_df_dx_star(x_star):\n",
|
||||||
" # TODO Compute this derivative (function defined at the top)\n",
|
" # TODO Compute this derivative (function defined at the top)\n",
|
||||||
@@ -222,15 +216,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return df_dphi"
|
" return df_dphi"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "w13HVpi9q8nF"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "ntQT4An79kAl"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -241,15 +235,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"deriv = compute_derivative_of_expectation(phi1, n_samples)\n",
|
"deriv = compute_derivative_of_expectation(phi1, n_samples)\n",
|
||||||
"print(\"Your value: \", deriv, \", True value: 5.726338035051403\")"
|
"print(\"Your value: \", deriv, \", True value: 5.726338035051403\")"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "ntQT4An79kAl"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "t0Jqd_IN_lMU"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
||||||
"deriv_vals = np.zeros_like(phi_vals)\n",
|
"deriv_vals = np.zeros_like(phi_vals)\n",
|
||||||
@@ -262,37 +256,37 @@
|
|||||||
"ax.set_xlabel('Parameter $\\phi$')\n",
|
"ax.set_xlabel('Parameter $\\phi$')\n",
|
||||||
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "t0Jqd_IN_lMU"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"This should look plausibly like the derivative of the function we plotted above!"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "ASu4yKSwAEYI"
|
"id": "ASu4yKSwAEYI"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"This should look plausibly like the derivative of the function we plotted above!"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "xoFR1wifc8-b"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"The reparameterization trick computes the derivative of an expectation of a function $\\mbox{f}[x]$:\n",
|
"The reparameterization trick computes the derivative of an expectation of a function $\\text{f}[x]$:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr],\n",
|
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr],\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over. This derivative can also be computed as:\n",
|
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over. This derivative can also be computed as:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{eqnarray}\n",
|
"\\begin{align}\n",
|
||||||
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr] &=& \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\left[\\mbox{f}[x]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x|\\boldsymbol\\phi)\\bigr]\\right]\\nonumber \\\\\n",
|
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr] &=& \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\left[\\text{f}[x]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x|\\boldsymbol\\phi)\\bigr]\\right]\\nonumber \\\\\n",
|
||||||
"&\\approx & \\frac{1}{I}\\sum_{i=1}^{I}\\mbox{f}[x_i]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x_i|\\boldsymbol\\phi)\\bigr].\n",
|
"&\\approx & \\frac{1}{I}\\sum_{i=1}^{I}\\text{f}[x_i]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x_i|\\boldsymbol\\phi)\\bigr].\n",
|
||||||
"\\end{eqnarray}\n",
|
"\\end{align}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"This method is known as the REINFORCE algorithm or score function estimator. Problem 17.5 asks you to prove this relation. Let's use this method to compute the gradient and compare.\n",
|
"This method is known as the REINFORCE algorithm or score function estimator. Problem 17.5 asks you to prove this relation. Let's use this method to compute the gradient and compare.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -301,13 +295,15 @@
|
|||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
" Pr(x|\\mu,\\sigma^2) = \\frac{1}{\\sqrt{2\\pi\\sigma^{2}}}\\exp\\left[-\\frac{(x-\\mu)^{2}}{2\\sigma^{2}}\\right].\n",
|
" Pr(x|\\mu,\\sigma^2) = \\frac{1}{\\sqrt{2\\pi\\sigma^{2}}}\\exp\\left[-\\frac{(x-\\mu)^{2}}{2\\sigma^{2}}\\right].\n",
|
||||||
"\\end{equation}\n"
|
"\\end{equation}\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "xoFR1wifc8-b"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "4TUaxiWvASla"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def d_log_pr_x_given_phi(x,phi):\n",
|
"def d_log_pr_x_given_phi(x,phi):\n",
|
||||||
" # TODO -- fill in this function\n",
|
" # TODO -- fill in this function\n",
|
||||||
@@ -333,15 +329,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return deriv"
|
" return deriv"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "4TUaxiWvASla"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "0RSN32Rna_C_"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -352,15 +348,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"deriv = compute_derivative_of_expectation_score_function(phi1, n_samples)\n",
|
"deriv = compute_derivative_of_expectation_score_function(phi1, n_samples)\n",
|
||||||
"print(\"Your value: \", deriv, \", True value: 5.724609927313369\")"
|
"print(\"Your value: \", deriv, \", True value: 5.724609927313369\")"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "0RSN32Rna_C_"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "EM_i5zoyElHR"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
|
||||||
"deriv_vals = np.zeros_like(phi_vals)\n",
|
"deriv_vals = np.zeros_like(phi_vals)\n",
|
||||||
@@ -373,24 +369,25 @@
|
|||||||
"ax.set_xlabel('Parameter $\\phi$')\n",
|
"ax.set_xlabel('Parameter $\\phi$')\n",
|
||||||
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "EM_i5zoyElHR"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"This should look the same as the derivative that we computed with the reparameterization trick. So, is there any advantage to one way or the other? Let's compare the variances of the estimates\n"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "1TWBiUC7bQSw"
|
"id": "1TWBiUC7bQSw"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"This should look the same as the derivative that we computed with the reparameterization trick. So, is there any advantage to one way or the other? Let's compare the variances of the estimates\n"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "vV_Jx5bCbQGs"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"n_estimate = 100\n",
|
"n_estimate = 100\n",
|
||||||
"n_sample = 1000\n",
|
"n_sample = 1000\n",
|
||||||
@@ -403,21 +400,33 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"print(\"Variance of reparameterization estimator\", np.var(reparam_estimates))\n",
|
"print(\"Variance of reparameterization estimator\", np.var(reparam_estimates))\n",
|
||||||
"print(\"Variance of score function estimator\", np.var(score_function_estimates))"
|
"print(\"Variance of score function estimator\", np.var(score_function_estimates))"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "vV_Jx5bCbQGs"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"The variance of the reparameterization estimator should be quite a bit lower than the score function estimator which is why it is preferred in this situation."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "d-0tntSYdKPR"
|
"id": "d-0tntSYdKPR"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"The variance of the reparameterization estimator should be quite a bit lower than the score function estimator which is why it is preferred in this situation."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"authorship_tag": "ABX9TyOxO2/0DTH4n4zhC97qbagY",
|
||||||
|
"include_colab_link": true,
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
@@ -1,33 +1,22 @@
|
|||||||
{
|
{
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"provenance": [],
|
|
||||||
"authorship_tag": "ABX9TyNecz9/CDOggPSmy1LjT/Dv",
|
|
||||||
"include_colab_link": true
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"name": "python3",
|
|
||||||
"display_name": "Python 3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "view-in-github",
|
"colab_type": "text",
|
||||||
"colab_type": "text"
|
"id": "view-in-github"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_3_Importance_Sampling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_3_Importance_Sampling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "t9vk9Elugvmi"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# **Notebook 17.3: Importance sampling**\n",
|
"# **Notebook 17.3: Importance sampling**\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -36,25 +25,26 @@
|
|||||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "t9vk9Elugvmi"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"execution_count": null,
|
||||||
"import numpy as np\n",
|
|
||||||
"import matplotlib.pyplot as plt"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "OLComQyvCIJ7"
|
"id": "OLComQyvCIJ7"
|
||||||
},
|
},
|
||||||
"execution_count": null,
|
"outputs": [],
|
||||||
"outputs": []
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import matplotlib.pyplot as plt"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "f7a6xqKjkmvT"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's approximate the expectation\n",
|
"Let's approximate the expectation\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -65,7 +55,7 @@
|
|||||||
"where\n",
|
"where\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"Pr(y)=\\mbox{Norm}_y[0,1]\n",
|
"Pr(y)=\\text{Norm}_y[0,1]\n",
|
||||||
"\\end{equation}\n",
|
"\\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"by drawing $I$ samples $y_i$ and using the formula:\n",
|
"by drawing $I$ samples $y_i$ and using the formula:\n",
|
||||||
@@ -73,13 +63,15 @@
|
|||||||
"\\begin{equation}\n",
|
"\\begin{equation}\n",
|
||||||
"\\mathbb{E}_{y}\\Bigl[\\exp\\bigl[- (y-1)^4\\bigr]\\Bigr] \\approx \\frac{1}{I} \\sum_{i=1}^I \\exp\\bigl[-(y-1)^4 \\bigr]\n",
|
"\\mathbb{E}_{y}\\Bigl[\\exp\\bigl[- (y-1)^4\\bigr]\\Bigr] \\approx \\frac{1}{I} \\sum_{i=1}^I \\exp\\bigl[-(y-1)^4 \\bigr]\n",
|
||||||
"\\end{equation}"
|
"\\end{equation}"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "f7a6xqKjkmvT"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "VjkzRr8o2ksg"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def f(y):\n",
|
"def f(y):\n",
|
||||||
" return np.exp(-(y-1) *(y-1) *(y-1) * (y-1))\n",
|
" return np.exp(-(y-1) *(y-1) *(y-1) * (y-1))\n",
|
||||||
@@ -95,15 +87,15 @@
|
|||||||
"ax.set_xlabel(\"$y$\")\n",
|
"ax.set_xlabel(\"$y$\")\n",
|
||||||
"ax.legend()\n",
|
"ax.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "VjkzRr8o2ksg"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "LGAKHjUJnWmy"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_expectation(n_samples):\n",
|
"def compute_expectation(n_samples):\n",
|
||||||
" # TODO -- compute this expectation\n",
|
" # TODO -- compute this expectation\n",
|
||||||
@@ -114,15 +106,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return expectation"
|
" return expectation"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "LGAKHjUJnWmy"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "nmvixMqgodIP"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -131,26 +123,27 @@
|
|||||||
"n_samples = 100000000\n",
|
"n_samples = 100000000\n",
|
||||||
"expected_f= compute_expectation(n_samples)\n",
|
"expected_f= compute_expectation(n_samples)\n",
|
||||||
"print(\"Your value: \", expected_f, \", True value: 0.43160702267383166\")"
|
"print(\"Your value: \", expected_f, \", True value: 0.43160702267383166\")"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "nmvixMqgodIP"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "Jr4UPcqmnXCS"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's investigate how the variance of this approximation decreases as we increase the number of samples $N$.\n",
|
"Let's investigate how the variance of this approximation decreases as we increase the number of samples $N$.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n"
|
"\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "Jr4UPcqmnXCS"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "yrDp1ILUo08j"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_mean_variance(n_sample):\n",
|
"def compute_mean_variance(n_sample):\n",
|
||||||
" n_estimate = 10000\n",
|
" n_estimate = 10000\n",
|
||||||
@@ -158,15 +151,15 @@
|
|||||||
" for i in range(n_estimate):\n",
|
" for i in range(n_estimate):\n",
|
||||||
" estimates[i] = compute_expectation(n_sample.astype(int))\n",
|
" estimates[i] = compute_expectation(n_sample.astype(int))\n",
|
||||||
" return np.mean(estimates), np.var(estimates)"
|
" return np.mean(estimates), np.var(estimates)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "yrDp1ILUo08j"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "BcUVsodtqdey"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Compute the mean and variance for 1,2,... 20 samples\n",
|
"# Compute the mean and variance for 1,2,... 20 samples\n",
|
||||||
"n_sample_all = np.array([1.,2,3,4,5,6,7,8,9,10,15,20,25,30,45,50,60,70,80,90,100,150,200,250,300,350,400,450,500])\n",
|
"n_sample_all = np.array([1.,2,3,4,5,6,7,8,9,10,15,20,25,30,45,50,60,70,80,90,100,150,200,250,300,350,400,450,500])\n",
|
||||||
@@ -175,15 +168,15 @@
|
|||||||
"for i in range(len(n_sample_all)):\n",
|
"for i in range(len(n_sample_all)):\n",
|
||||||
" print(\"Computing mean and variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
" print(\"Computing mean and variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
||||||
" mean_all[i],variance_all[i] = compute_mean_variance(n_sample_all[i])"
|
" mean_all[i],variance_all[i] = compute_mean_variance(n_sample_all[i])"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "BcUVsodtqdey"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "feXmyk0krpUi"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"fig,ax = plt.subplots()\n",
|
"fig,ax = plt.subplots()\n",
|
||||||
"ax.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
"ax.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
||||||
@@ -193,24 +186,24 @@
|
|||||||
"ax.plot([0,500],[0.43160702267383166,0.43160702267383166],'k--',label='true value')\n",
|
"ax.plot([0,500],[0.43160702267383166,0.43160702267383166],'k--',label='true value')\n",
|
||||||
"ax.legend()\n",
|
"ax.legend()\n",
|
||||||
"plt.show()\n"
|
"plt.show()\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "feXmyk0krpUi"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"As you might expect, the more samples that we use to compute the approximate estimate, the lower the variance of the estimate."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "XTUpxFlSuOl7"
|
"id": "XTUpxFlSuOl7"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"As you might expect, the more samples that we use to compute the approximate estimate, the lower the variance of the estimate."
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "6hxsl3Pxo1TT"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
" Now consider the function\n",
|
" Now consider the function\n",
|
||||||
" \\begin{equation}\n",
|
" \\begin{equation}\n",
|
||||||
@@ -218,13 +211,15 @@
|
|||||||
" \\end{equation}\n",
|
" \\end{equation}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"which decreases rapidly as we move away from the position $y=3$."
|
"which decreases rapidly as we move away from the position $y=3$."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "6hxsl3Pxo1TT"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "znydVPW7sL4P"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def f2(y):\n",
|
"def f2(y):\n",
|
||||||
" return 20.446*np.exp(- (y-3) *(y-3) *(y-3) * (y-3))\n",
|
" return 20.446*np.exp(- (y-3) *(y-3) *(y-3) * (y-3))\n",
|
||||||
@@ -236,46 +231,47 @@
|
|||||||
"ax.set_xlabel(\"$y$\")\n",
|
"ax.set_xlabel(\"$y$\")\n",
|
||||||
"ax.legend()\n",
|
"ax.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "znydVPW7sL4P"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "G9Xxo0OJsIqD"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Let's again, compute the expectation:\n",
|
"Let's again, compute the expectation:\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\\begin{eqnarray}\n",
|
"\\begin{align}\n",
|
||||||
"\\mathbb{E}_{y}\\left[\\mbox{f}[y]\\right] &=& \\int \\mbox{f}[y] Pr(y) dy\\\\\n",
|
"\\mathbb{E}_{y}\\left[\\text{f}[y]\\right] &=& \\int \\text{f}[y] Pr(y) dy\\\\\n",
|
||||||
"&\\approx& \\frac{1}{I} \\mbox{f}[y]\n",
|
"&\\approx& \\frac{1}{I} \\text{f}[y]\n",
|
||||||
"\\end{eqnarray}\n",
|
"\\end{align}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"where $Pr(y)=\\mbox{Norm}_y[0,1]$ by approximating with samples $y_{i}$.\n"
|
"where $Pr(y)=\\text{Norm}_y[0,1]$ by approximating with samples $y_{i}$.\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "G9Xxo0OJsIqD"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "l8ZtmkA2vH4y"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_expectation2(n_samples):\n",
|
"def compute_expectation2(n_samples):\n",
|
||||||
" y = np.random.normal(size=(n_samples,1))\n",
|
" y = np.random.normal(size=(n_samples,1))\n",
|
||||||
" expectation = np.mean(f2(y))\n",
|
" expectation = np.mean(f2(y))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return expectation"
|
" return expectation"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "l8ZtmkA2vH4y"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "dfUQyJ-svZ6F"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -284,26 +280,27 @@
|
|||||||
"n_samples = 100000000\n",
|
"n_samples = 100000000\n",
|
||||||
"expected_f2= compute_expectation2(n_samples)\n",
|
"expected_f2= compute_expectation2(n_samples)\n",
|
||||||
"print(\"Expected value: \", expected_f2)"
|
"print(\"Expected value: \", expected_f2)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "dfUQyJ-svZ6F"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "2sVDqP0BvxqM"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"I deliberately chose this function, because it's expectation is roughly the same as for the previous function.\n",
|
"I deliberately chose this function, because it's expectation is roughly the same as for the previous function.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Again, let's look at the mean and the variance of the estimates"
|
"Again, let's look at the mean and the variance of the estimates"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "2sVDqP0BvxqM"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "mHnILRkOv0Ir"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_mean_variance2(n_sample):\n",
|
"def compute_mean_variance2(n_sample):\n",
|
||||||
" n_estimate = 10000\n",
|
" n_estimate = 10000\n",
|
||||||
@@ -318,15 +315,15 @@
|
|||||||
"for i in range(len(n_sample_all)):\n",
|
"for i in range(len(n_sample_all)):\n",
|
||||||
" print(\"Computing variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
" print(\"Computing variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
||||||
" mean_all2[i], variance_all2[i] = compute_mean_variance2(n_sample_all[i])"
|
" mean_all2[i], variance_all2[i] = compute_mean_variance2(n_sample_all[i])"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "mHnILRkOv0Ir"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "FkCX-hxxAnsw"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"fig,ax1 = plt.subplots()\n",
|
"fig,ax1 = plt.subplots()\n",
|
||||||
"ax1.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
"ax1.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
||||||
@@ -348,39 +345,41 @@
|
|||||||
"ax2.set_title(\"Second function\")\n",
|
"ax2.set_title(\"Second function\")\n",
|
||||||
"ax2.legend()\n",
|
"ax2.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "FkCX-hxxAnsw"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "EtBP6NeLwZqz"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"You can see that the variance of the estimate of the second function is considerably worse than the estimate of the variance of estimate of the first function\n",
|
"You can see that the variance of the estimate of the second function is considerably worse than the estimate of the variance of estimate of the first function\n",
|
||||||
"\n",
|
"\n",
|
||||||
"TODO: Think about why this is."
|
"TODO: Think about why this is."
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "EtBP6NeLwZqz"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "_wuF-NoQu1--"
|
||||||
|
},
|
||||||
"source": [
|
"source": [
|
||||||
" Now let's repeat this experiment with the second function, but this time use importance sampling with auxiliary distribution:\n",
|
" Now let's repeat this experiment with the second function, but this time use importance sampling with auxiliary distribution:\n",
|
||||||
"\n",
|
"\n",
|
||||||
" \\begin{equation}\n",
|
" \\begin{equation}\n",
|
||||||
" q(y)=\\mbox{Norm}_{y}[3,1]\n",
|
" q(y)=\\text{Norm}_{y}[3,1]\n",
|
||||||
" \\end{equation}\n"
|
" \\end{equation}\n"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "_wuF-NoQu1--"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "jPm0AVYVIDnn"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def q_y(y):\n",
|
"def q_y(y):\n",
|
||||||
" return (1/np.sqrt(2*np.pi)) * np.exp(-0.5 * (y-3) * (y-3))\n",
|
" return (1/np.sqrt(2*np.pi)) * np.exp(-0.5 * (y-3) * (y-3))\n",
|
||||||
@@ -395,15 +394,15 @@
|
|||||||
" expectation = 0\n",
|
" expectation = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return expectation"
|
" return expectation"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "jPm0AVYVIDnn"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "No2ByVvOM2yQ"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Set the seed so the random numbers are all the same\n",
|
"# Set the seed so the random numbers are all the same\n",
|
||||||
"np.random.seed(0)\n",
|
"np.random.seed(0)\n",
|
||||||
@@ -412,15 +411,15 @@
|
|||||||
"n_samples = 100000000\n",
|
"n_samples = 100000000\n",
|
||||||
"expected_f2= compute_expectation2b(n_samples)\n",
|
"expected_f2= compute_expectation2b(n_samples)\n",
|
||||||
"print(\"Your value: \", expected_f2,\", True value: 0.43163734204459125 \")"
|
"print(\"Your value: \", expected_f2,\", True value: 0.43163734204459125 \")"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "No2ByVvOM2yQ"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "6v8Jc7z4M3Mk"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def compute_mean_variance2b(n_sample):\n",
|
"def compute_mean_variance2b(n_sample):\n",
|
||||||
" n_estimate = 10000\n",
|
" n_estimate = 10000\n",
|
||||||
@@ -435,15 +434,15 @@
|
|||||||
"for i in range(len(n_sample_all)):\n",
|
"for i in range(len(n_sample_all)):\n",
|
||||||
" print(\"Computing variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
" print(\"Computing variance for expectation with %d samples\"%(n_sample_all[i]))\n",
|
||||||
" mean_all2b[i], variance_all2b[i] = compute_mean_variance2b(n_sample_all[i])"
|
" mean_all2b[i], variance_all2b[i] = compute_mean_variance2b(n_sample_all[i])"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "6v8Jc7z4M3Mk"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "C0beD4sNNM3L"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"fig,ax1 = plt.subplots()\n",
|
"fig,ax1 = plt.subplots()\n",
|
||||||
"ax1.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
"ax1.semilogx(n_sample_all, mean_all,'r-',label='mean estimate')\n",
|
||||||
@@ -476,21 +475,33 @@
|
|||||||
"ax2.set_title(\"Second function with importance sampling\")\n",
|
"ax2.set_title(\"Second function with importance sampling\")\n",
|
||||||
"ax2.legend()\n",
|
"ax2.legend()\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"id": "C0beD4sNNM3L"
|
|
||||||
},
|
|
||||||
"execution_count": null,
|
|
||||||
"outputs": []
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"attachments": {},
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"You can see that the importance sampling technique has reduced the amount of variance for any given number of samples."
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "y8rgge9MNiOc"
|
"id": "y8rgge9MNiOc"
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"You can see that the importance sampling technique has reduced the amount of variance for any given number of samples."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"authorship_tag": "ABX9TyNecz9/CDOggPSmy1LjT/Dv",
|
||||||
|
"include_colab_link": true,
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user