Add files via upload

This commit is contained in:
udlbook
2024-01-02 13:09:22 -05:00
committed by GitHub
parent db836826f6
commit c19e2411c5
3 changed files with 471 additions and 442 deletions


@@ -1,33 +1,22 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "t9vk9Elugvmi"
},
"source": [
"# **Notebook 17.2: Reparameterization trick**\n",
"\n",
@@ -36,30 +25,31 @@
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OLComQyvCIJ7"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "paLz5RukZP1J"
},
"source": [
"The reparameterization trick computes the derivative of an expectation of a function $\\text{f}[x]$:\n",
"\n",
"\\begin{equation}\n",
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr],\n",
"\\end{equation}\n",
"\n",
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over.\n",
@@ -67,21 +57,23 @@
"Let's consider a simple concrete example, where:\n",
"\n",
"\\begin{equation}\n",
"Pr(x|\\phi) = \\text{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\text{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr]\n",
"\\end{equation}\n",
"\n",
"and\n",
"\n",
"\\begin{equation}\n",
"\\text{f}[x] = x^2+\\sin[x]\n",
"\\end{equation}"
]
},
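Before the TODO cells below, the expectation above can be approximated by plain Monte Carlo sampling. A minimal sketch, assuming NumPy (the name `mc_expectation` is illustrative, not from the notebook):

```python
import numpy as np

def f(x):
    # f[x] = x^2 + sin(x)
    return x * x + np.sin(x)

def mc_expectation(phi, n_samples, rng):
    # Draw x ~ Norm[phi^3, (exp(phi))^2] and average f over the samples
    x = rng.normal(loc=phi ** 3, scale=np.exp(phi), size=n_samples)
    return np.mean(f(x))

rng = np.random.default_rng(0)
estimate = mc_expectation(0.5, 1_000_000, rng)
```

For this particular f and a Normal distribution the expectation also has a closed form, $\mathbb{E}[x^2+\sin x] = \mu^2 + \sigma^2 + \sin[\mu]\exp[-\sigma^2/2]$, which is handy for sanity-checking the estimate.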
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FdEbMnDBY0i9"
},
"outputs": [],
"source": [
"# Let's approximate this expectation for a particular value of phi\n",
"def compute_expectation(phi, n_samples):\n",
@@ -96,15 +88,15 @@
"\n",
"\n",
" return expected_f_given_phi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FTh7LJ0llNJZ"
},
"outputs": [],
"source": [
"# Set the seed so the random numbers are all the same\n",
"np.random.seed(0)\n",
@@ -119,24 +111,25 @@
"n_samples = 10000000\n",
"expected_f_given_phi2 = compute_expectation(phi2, n_samples)\n",
"print(\"Your value: \", expected_f_given_phi2, \", True value: 0.8176793102849222\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "r5Hl2QkimWx9"
},
"source": [
"Let's plot this expectation as a function of phi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "05XxVLJxmkER"
},
"outputs": [],
"source": [
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
"expected_vals = np.zeros_like(phi_vals)\n",
@@ -149,15 +142,14 @@
"ax.set_xlabel('Parameter $\\phi$')\n",
"ax.set_ylabel('$\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "zTCykVeWqj_O"
},
"source": [
"It's this curve that we want to find the derivative of (so that, for example, we could run gradient descent to find its minimum).\n",
"\n",
@@ -166,28 +158,30 @@
"The answer is the reparameterization trick. We note that:\n",
"\n",
"\\begin{equation}\n",
"\\text{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\text{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\sigma + \\mu\n",
"\\end{equation}\n",
"\n",
"and so:\n",
"\n",
"\\begin{equation}\n",
"\\text{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr] = \\text{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\exp[\\phi]+ \\phi^3\n",
"\\end{equation}\n",
"\n",
"So, if we draw a sample $\\epsilon^*$ from $\\text{Norm}_{\\epsilon}[0, 1]$, then we can compute a sample $x^*$ as:\n",
"\n",
"\\begin{align}\n",
"x^* &= \\epsilon^* \\times \\sigma + \\mu \\\\\n",
"&= \\epsilon^* \\times \\exp[\\phi]+ \\phi^3\n",
"\\end{align}"
]
},
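This sampling path can be sketched directly in NumPy (a hedged sketch, separate from the notebook's TODO functions; `reparam_gradient` is an illustrative name, and it uses the analytic derivative f'(x) = 2x + cos(x)):

```python
import numpy as np

def df_dx(x):
    # f[x] = x^2 + sin(x)  =>  df/dx = 2x + cos(x)
    return 2.0 * x + np.cos(x)

def reparam_gradient(phi, n_samples, rng):
    eps = rng.normal(0.0, 1.0, size=n_samples)     # epsilon* ~ Norm[0, 1]
    x = eps * np.exp(phi) + phi ** 3               # x* = epsilon* x sigma + mu
    dx_dphi = eps * np.exp(phi) + 3.0 * phi ** 2   # d x*/d phi with mu = phi^3, sigma = exp(phi)
    return np.mean(df_dx(x) * dx_dphi)             # chain rule, averaged over samples
```

Because x* is now a deterministic function of phi and the noise epsilon*, the derivative passes through the sample, and only the fixed distribution Norm[0, 1] is ever sampled.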
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w13HVpi9q8nF"
},
"outputs": [],
"source": [
"def compute_df_dx_star(x_star):\n",
" # TODO Compute this derivative (function defined at the top)\n",
@@ -222,15 +216,15 @@
"\n",
"\n",
" return df_dphi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ntQT4An79kAl"
},
"outputs": [],
"source": [
"# Set the seed so the random numbers are all the same\n",
"np.random.seed(0)\n",
@@ -241,15 +235,15 @@
"\n",
"deriv = compute_derivative_of_expectation(phi1, n_samples)\n",
"print(\"Your value: \", deriv, \", True value: 5.726338035051403\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t0Jqd_IN_lMU"
},
"outputs": [],
"source": [
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
"deriv_vals = np.zeros_like(phi_vals)\n",
@@ -262,37 +256,37 @@
"ax.set_xlabel('Parameter $\\phi$')\n",
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "ASu4yKSwAEYI"
},
"source": [
"This should look plausibly like the derivative of the function we plotted above!"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "xoFR1wifc8-b"
},
"source": [
"The reparameterization trick computes the derivative of an expectation of a function $\\text{f}[x]$:\n",
"\n",
"\\begin{equation}\n",
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr],\n",
"\\end{equation}\n",
"\n",
"with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over. This derivative can also be computed as:\n",
"\n",
"\\begin{align}\n",
"\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\text{f}[x]\\bigr] &= \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\left[\\text{f}[x]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x|\\boldsymbol\\phi)\\bigr]\\right]\\nonumber \\\\\n",
"&\\approx \\frac{1}{I}\\sum_{i=1}^{I}\\text{f}[x_i]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x_i|\\boldsymbol\\phi)\\bigr].\n",
"\\end{align}\n",
"\n",
"This method is known as the REINFORCE algorithm or score function estimator. Problem 17.5 asks you to prove this relation. Let's use this method to compute the gradient and compare.\n",
"\n",
@@ -301,13 +295,15 @@
"\\begin{equation}\n",
" Pr(x|\\mu,\\sigma^2) = \\frac{1}{\\sqrt{2\\pi\\sigma^{2}}}\\exp\\left[-\\frac{(x-\\mu)^{2}}{2\\sigma^{2}}\\right].\n",
"\\end{equation}\n"
]
},
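The score-function route can be sketched as follows (NumPy assumed; `score_gradient` is an illustrative name, and the log-density derivatives come from the Normal density above via the chain rule through mu = phi^3 and sigma = exp(phi)):

```python
import numpy as np

def f(x):
    # f[x] = x^2 + sin(x)
    return x * x + np.sin(x)

def score_gradient(phi, n_samples, rng):
    mu, sigma = phi ** 3, np.exp(phi)
    x = rng.normal(mu, sigma, size=n_samples)      # x_i ~ Pr(x|phi)
    # d log Pr / d mu    = (x - mu) / sigma^2
    # d log Pr / d sigma = ((x - mu)^2 - sigma^2) / sigma^3
    dlog_dphi = ((x - mu) / sigma ** 2) * (3.0 * phi ** 2) \
              + (((x - mu) ** 2 - sigma ** 2) / sigma ** 3) * np.exp(phi)
    return np.mean(f(x) * dlog_dphi)               # REINFORCE / score function estimate
```

Note that here the derivative never passes through f; only the log-density is differentiated, which is why this estimator also works for non-differentiable f.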
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4TUaxiWvASla"
},
"outputs": [],
"source": [
"def d_log_pr_x_given_phi(x,phi):\n",
" # TODO -- fill in this function\n",
@@ -333,15 +329,15 @@
"\n",
"\n",
" return deriv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0RSN32Rna_C_"
},
"outputs": [],
"source": [
"# Set the seed so the random numbers are all the same\n",
"np.random.seed(0)\n",
@@ -352,15 +348,15 @@
"\n",
"deriv = compute_derivative_of_expectation_score_function(phi1, n_samples)\n",
"print(\"Your value: \", deriv, \", True value: 5.724609927313369\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EM_i5zoyElHR"
},
"outputs": [],
"source": [
"phi_vals = np.arange(-1.5,1.5, 0.05)\n",
"deriv_vals = np.zeros_like(phi_vals)\n",
@@ -373,24 +369,25 @@
"ax.set_xlabel('Parameter $\\phi$')\n",
"ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "1TWBiUC7bQSw"
},
"source": [
"This should look the same as the derivative that we computed with the reparameterization trick. So, is there any advantage to one way or the other? Let's compare the variances of the two estimators.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vV_Jx5bCbQGs"
},
"outputs": [],
"source": [
"n_estimate = 100\n",
"n_sample = 1000\n",
@@ -403,21 +400,33 @@
"\n",
"print(\"Variance of reparameterization estimator\", np.var(reparam_estimates))\n",
"print(\"Variance of score function estimator\", np.var(score_function_estimates))"
]
},
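For reference, the variance comparison in the cell above can be sketched in a self-contained form (function names are illustrative; both functions estimate the same gradient at the same phi):

```python
import numpy as np

def f(x):
    # f[x] = x^2 + sin(x)
    return x * x + np.sin(x)

def reparam_est(phi, n, rng):
    # Reparameterization: x = eps * exp(phi) + phi^3, differentiate through the sample
    eps = rng.normal(size=n)
    x = eps * np.exp(phi) + phi ** 3
    return np.mean((2 * x + np.cos(x)) * (eps * np.exp(phi) + 3 * phi ** 2))

def score_est(phi, n, rng):
    # Score function: weight f(x) by d log Pr(x|phi) / d phi
    mu, sigma = phi ** 3, np.exp(phi)
    x = rng.normal(mu, sigma, size=n)
    dlog = ((x - mu) / sigma ** 2) * (3 * phi ** 2) \
         + (((x - mu) ** 2 - sigma ** 2) / sigma ** 3) * np.exp(phi)
    return np.mean(f(x) * dlog)

rng = np.random.default_rng(0)
reparam_vals = [reparam_est(0.5, 1000, rng) for _ in range(100)]
score_vals = [score_est(0.5, 1000, rng) for _ in range(100)]
# The reparameterization estimator typically shows much smaller spread here
print(np.var(reparam_vals), np.var(score_vals))
```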
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "d-0tntSYdKPR"
},
"source": [
"The variance of the reparameterization estimator should be quite a bit lower than that of the score function estimator, which is why it is preferred in this situation."
]
}
],
"metadata": {
"colab": {
"authorship_tag": "ABX9TyOxO2/0DTH4n4zhC97qbagY",
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}