diff --git a/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb b/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb new file mode 100644 index 0000000..21b48c6 --- /dev/null +++ b/Notebooks/Chap17/17_2_Reparameterization_Trick.ipynb @@ -0,0 +1,423 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOxO2/0DTH4n4zhC97qbagY", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Notebook 17.2: Reparameterization trick**\n", + "\n", + "This notebook investigates the reparameterization trick as described in section 17.7 of the book.\n", + "\n", + "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", + "\n", + "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." + ], + "metadata": { + "id": "t9vk9Elugvmi" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ], + "metadata": { + "id": "OLComQyvCIJ7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The reparameterization trick computes the derivative of an expectation of a function $\\mbox{f}[x]$:\n", + "\n", + "\\begin{equation}\n", + "\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr],\n", + "\\end{equation}\n", + "\n", + "with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over.\n", + "\n", + "Let's consider a simple concrete example, where:\n", + "\n", + "\\begin{equation}\n", + "Pr(x|\\phi) = \\mbox{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\mbox{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr]\n", + "\\end{equation}\n", + "\n", + "and\n", + "\n", + "\\begin{equation}\n", + "\\mbox{f}[x] = x^2+\\sin[x]\n", + "\\end{equation}" + ], + "metadata": { + "id": "paLz5RukZP1J" + } + }, + { + "cell_type": "code", + "source": [ + "# Let's approximate this expecctation for a particular value of phi\n", + "def compute_expectation(phi, n_samples):\n", + " # TODO complete this function\n", + " # 1. Compute the mean of the normal distribution, mu\n", + " # 2. Compute the standard deviation of the normal distribution, sigma\n", + " # 3. Draw n_samples samples using np.random.normal(mu, sigma, size=(n_samples, 1))\n", + " # 4. Compute f[x] for each of these samples\n", + " # 4. Approximate the expectation by taking the average of the values of f[x]\n", + " # Replace this line\n", + " expected_f_given_phi = 0\n", + "\n", + "\n", + " return expected_f_given_phi" + ], + "metadata": { + "id": "FdEbMnDBY0i9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Set the seed so the random numbers are all the same\n", + "np.random.seed(0)\n", + "\n", + "# Compute the expectation for two values of phi\n", + "phi1 = 0.5\n", + "n_samples = 10000000\n", + "expected_f_given_phi1 = compute_expectation(phi1, n_samples)\n", + "print(\"Your value: \", expected_f_given_phi1, \", True value: 2.7650801613563116\")\n", + "\n", + "phi2 = -0.1\n", + "n_samples = 10000000\n", + "expected_f_given_phi2 = compute_expectation(phi2, n_samples)\n", + "print(\"Your value: \", expected_f_given_phi2, \", True value: 0.8176793102849222\")" + ], + "metadata": { + "id": "FTh7LJ0llNJZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Le't plot this expectation as a function of phi" + ], + "metadata": { + "id": "r5Hl2QkimWx9" + } + }, + { + "cell_type": "code", + "source": [ + "phi_vals = np.arange(-1.5,1.5, 0.05)\n", + "expected_vals = np.zeros_like(phi_vals)\n", + "n_samples = 1000000\n", + "for i in range(len(phi_vals)):\n", + " expected_vals[i] = compute_expectation(phi_vals[i], n_samples)\n", + "\n", + "fig,ax = plt.subplots()\n", + "ax.plot(phi_vals, expected_vals,'r-')\n", + "ax.set_xlabel('Parameter $\\phi$')\n", + "ax.set_ylabel('$\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n", + "plt.show()" + ], + "metadata": { + "id": "05XxVLJxmkER" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "It's this curve that we want to find the derivative of (so for example, we could run gradient descent and find the minimum.\n", + "\n", + "This is tricky though -- if you look at the computation that you performed, then there is a sampling step in the procedure (step 3). How do we compute the derivative of this?\n", + "\n", + "The answer is the reparameterization trick. We note that:\n", + "\n", + "\\begin{equation}\n", + "\\mbox{Norm}_{x}\\Bigl[\\mu, \\sigma^2\\Bigr]=\\mbox{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\sigma + \\mu\n", + "\\end{equation}\n", + "\n", + "and so:\n", + "\n", + "\\begin{equation}\n", + "\\mbox{Norm}_{x}\\Bigl[\\phi^3,(\\exp[\\phi])^2\\Bigr] = \\mbox{Norm}_{x}\\Bigl[0, 1\\Bigr] \\times \\exp[\\phi]+ \\phi^3\n", + "\\end{equation}\n", + "\n", + "So, if we draw a sample $\\epsilon^*$ from $\\mbox{Norm}_{\\epsilon}[0, 1]$, then we can compute a sample $x^*$ as:\n", + "\n", + "\\begin{eqnarray*}\n", + "x^* &=& \\epsilon^* \\times \\sigma + \\mu \\\\\n", + "&=& \\epsilon^* \\times \\exp[\\phi]+ \\phi^3\n", + "\\end{eqnarray*}" + ], + "metadata": { + "id": "zTCykVeWqj_O" + } + }, + { + "cell_type": "code", + "source": [ + "def compute_df_dx_star(x_star):\n", + " # TODO Compute this derivative (function defined at the top)\n", + " # Replace this line:\n", + " deriv = 0;\n", + "\n", + "\n", + "\n", + " return deriv\n", + "\n", + "def compute_dx_star_dphi(epsilon_star, phi):\n", + " # TODO Compute this derivative\n", + " # Replace this line:\n", + " deriv = 0;\n", + "\n", + "\n", + "\n", + " return deriv\n", + "\n", + "def compute_derivative_of_expectation(phi, n_samples):\n", + " # Generate the random values of epsilon\n", + " epsilon_star= np.random.normal(size=(n_samples,1))\n", + " # TODO -- write\n", + " # 1. Compute dx*/dphi using the function defined above\n", + " # 2. Compute x*\n", + " # 3. Compute df/dx* using the function you wrote above\n", + " # 4. Compute df/dphi = df/x* * dx*dphi\n", + " # 5. Average the samples of df/dphi to get the expectation.\n", + " # Replace this line:\n", + " df_dphi = 0\n", + "\n", + "\n", + "\n", + " return df_dphi" + ], + "metadata": { + "id": "w13HVpi9q8nF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Set the seed so the random numbers are all the same\n", + "np.random.seed(0)\n", + "\n", + "# Compute the expectation for two values of phi\n", + "phi1 = 0.5\n", + "n_samples = 10000000\n", + "\n", + "deriv = compute_derivative_of_expectation(phi1, n_samples)\n", + "print(\"Your value: \", deriv, \", True value: 5.726338035051403\")" + ], + "metadata": { + "id": "ntQT4An79kAl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "phi_vals = np.arange(-1.5,1.5, 0.05)\n", + "deriv_vals = np.zeros_like(phi_vals)\n", + "n_samples = 1000000\n", + "for i in range(len(phi_vals)):\n", + " deriv_vals[i] = compute_derivative_of_expectation(phi_vals[i], n_samples)\n", + "\n", + "fig,ax = plt.subplots()\n", + "ax.plot(phi_vals, deriv_vals,'r-')\n", + "ax.set_xlabel('Parameter $\\phi$')\n", + "ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n", + "plt.show()" + ], + "metadata": { + "id": "t0Jqd_IN_lMU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "This should look plausibly like the derivative of the function we plotted above!" + ], + "metadata": { + "id": "ASu4yKSwAEYI" + } + }, + { + "cell_type": "markdown", + "source": [ + "The reparameterization trick computes the derivative of an expectation of a function $\\mbox{f}[x]$:\n", + "\n", + "\\begin{equation}\n", + "\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr],\n", + "\\end{equation}\n", + "\n", + "with respect to the parameters $\\boldsymbol\\phi$ of the distribution $Pr(x|\\boldsymbol\\phi)$ that the expectation is over. This derivative can also be computed as:\n", + "\n", + "\\begin{eqnarray}\n", + "\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\bigl[\\mbox{f}[x]\\bigr] &=& \\mathbb{E}_{Pr(x|\\boldsymbol\\phi)}\\left[\\mbox{f}[x]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x|\\boldsymbol\\phi)\\bigr]\\right]\\nonumber \\\\\n", + "&\\approx & \\frac{1}{I}\\sum_{i=1}^{I}\\mbox{f}[x_i]\\frac{\\partial}{\\partial \\boldsymbol\\phi} \\log\\bigl[ Pr(x_i|\\boldsymbol\\phi)\\bigr].\n", + "\\end{eqnarray}\n", + "\n", + "This method is known as the REINFORCE algorithm or score function estimator. Problem 17.5 asks you to prove this relation. Let's use this method to compute the gradient and compare.\n", + "\n", + "Recall that the expression for a univariate Gaussian is:\n", + "\n", + "\\begin{equation}\n", + " Pr(x|\\mu,\\sigma^2) = \\frac{1}{\\sqrt{2\\pi\\sigma^{2}}}\\exp\\left[-\\frac{(x-\\mu)^{2}}{2\\sigma^{2}}\\right].\n", + "\\end{equation}\n" + ], + "metadata": { + "id": "xoFR1wifc8-b" + } + }, + { + "cell_type": "code", + "source": [ + "def d_log_pr_x_given_phi(x,phi):\n", + " # TODO -- fill in this function\n", + " # Compute the derivative of log[Pr(x|phi)]\n", + " # Replace this line:\n", + " deriv =0;\n", + "\n", + "\n", + " return deriv\n", + "\n", + "\n", + "def compute_derivative_of_expectation_score_function(phi, n_samples):\n", + " # TODO -- Compute this function\n", + " # 1. Calculate mu from phi\n", + " # 2. Calculate sigma from phi\n", + " # 3. Generate n_sample random samples of x using np.random.normal\n", + " # 4. Calculate f[x] for all of the samples\n", + " # 5. Multiply f[x] by d_log_pr_x_given_phi\n", + " # 6. Take the average of the samples\n", + " # Replace this line:\n", + " deriv = 0\n", + "\n", + "\n", + "\n", + " return deriv" + ], + "metadata": { + "id": "4TUaxiWvASla" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Set the seed so the random numbers are all the same\n", + "np.random.seed(0)\n", + "\n", + "# Compute the expectation for two values of phi\n", + "phi1 = 0.5\n", + "n_samples = 100000000\n", + "\n", + "deriv = compute_derivative_of_expectation_score_function(phi1, n_samples)\n", + "print(\"Your value: \", deriv, \", True value: 5.724609927313369\")" + ], + "metadata": { + "id": "0RSN32Rna_C_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "phi_vals = np.arange(-1.5,1.5, 0.05)\n", + "deriv_vals = np.zeros_like(phi_vals)\n", + "n_samples = 1000000\n", + "for i in range(len(phi_vals)):\n", + " deriv_vals[i] = compute_derivative_of_expectation_score_function(phi_vals[i], n_samples)\n", + "\n", + "fig,ax = plt.subplots()\n", + "ax.plot(phi_vals, deriv_vals,'r-')\n", + "ax.set_xlabel('Parameter $\\phi$')\n", + "ax.set_ylabel('$\\partial/\\partial\\phi\\mathbb{E}_{Pr(x|\\phi)}[f[x]]$')\n", + "plt.show()" + ], + "metadata": { + "id": "EM_i5zoyElHR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "This should look the same as the derivative that we computed with the reparameterization trick. So, is there any advantage to one way or the other? Let's compare the variances of the estimates\n" + ], + "metadata": { + "id": "1TWBiUC7bQSw" + } + }, + { + "cell_type": "code", + "source": [ + "n_estimate = 100\n", + "n_sample = 1000\n", + "phi = 0.3\n", + "reparam_estimates = np.zeros((n_estimate,1))\n", + "score_function_estimates = np.zeros((n_estimate,1))\n", + "for i in range(n_estimate):\n", + " reparam_estimates[i]= compute_derivative_of_expectation(phi, n_samples)\n", + " score_function_estimates[i] = compute_derivative_of_expectation_score_function(phi, n_samples)\n", + "\n", + "print(\"Variance of reparameterization estimator\", np.var(reparam_estimates))\n", + "print(\"Variance of score function estimator\", np.var(score_function_estimates))" + ], + "metadata": { + "id": "vV_Jx5bCbQGs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The variance of the reparameterization estimator should be quite a bit lower than the score function estimator which is why it is preferred in this situation." + ], + "metadata": { + "id": "d-0tntSYdKPR" + } + } + ] +} \ No newline at end of file