{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyMMms8JzNSVjQ3tqcXxUWCD", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# **Notebook 18.1: Diffusion Encoder**\n", "\n", "This notebook investigates the diffusion encoder as described in section 18.2 of the book.\n", "\n", "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "\n", "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." ], "metadata": { "id": "t9vk9Elugvmi" } }, { "cell_type": "code", "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib.colors import ListedColormap\n", "from operator import itemgetter" ], "metadata": { "id": "OLComQyvCIJ7" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#Create pretty colormap as in book\n", "my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', 
'72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", "my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n", "r = 
np.floor(my_colormap_vals_dec/(256*256))\n", "g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n", "b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", "my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n", "my_colormap = ListedColormap(my_colormap_vals)" ], "metadata": { "id": "4PM8bf6lO0VE" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Probability distribution for normal\n", "def norm_pdf(x, mu, sigma):\n", " return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)" ], "metadata": { "id": "ONGRaQscfIOo" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# True distribution is a mixture of four Gaussians\n", "class TrueDataDistribution:\n", " # Constructor initializes parameters\n", " def __init__(self):\n", " self.mu = [1.5, -0.216, 0.45, -1.875]\n", " self.sigma = [0.3, 0.15, 0.525, 0.075]\n", " self.w = [0.2, 0.3, 0.35, 0.15]\n", "\n", " # Return PDF\n", " def pdf(self, x):\n", " return(self.w[0] *norm_pdf(x,self.mu[0],self.sigma[0]) + self.w[1] *norm_pdf(x,self.mu[1],self.sigma[1]) + self.w[2] *norm_pdf(x,self.mu[2],self.sigma[2]) + self.w[3] *norm_pdf(x,self.mu[3],self.sigma[3]))\n", "\n", " # Draw samples\n", " def sample(self, n):\n", " hidden = np.random.choice(4, n, p=self.w)\n", " epsilon = np.random.normal(size=(n))\n", " mu_list = list(itemgetter(*hidden)(self.mu))\n", " sigma_list = list(itemgetter(*hidden)(self.sigma))\n", " return mu_list + sigma_list * epsilon" ], "metadata": { "id": "gZvG0MKhfY8Y" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Define ground truth probability distribution that we will model\n", "true_dist = TrueDataDistribution()\n", "# Let's visualize this\n", "x_vals = np.arange(-3,3,0.01)\n", "pr_x_true = true_dist.pdf(x_vals)\n", "fig,ax = plt.subplots()\n", "ax.plot(x_vals, pr_x_true, 'r-')\n", "ax.set_xlabel(\"$x$\")\n", "ax.set_ylabel(\"$Pr(x)$\")\n", 
"ax.set_ylim(0,1.0)\n", "ax.set_xlim(-3,3)\n", "plt.show()" ], "metadata": { "id": "qXmej3TUuQyp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's first implement the forward process" ], "metadata": { "id": "XHdtfRP47YLy" } }, { "cell_type": "code", "source": [ "# Do one step of diffusion (equation 18.1)\n", "def diffuse_one_step(z_t_minus_1, beta_t):\n", " # TODO -- Implement this function\n", " # Use np.random.normal to generate the samples epsilon\n", " # Replace this line\n", " z_t = np.zeros_like(z_t_minus_1)\n", "\n", " return z_t" ], "metadata": { "id": "hkApJ2VJlQuk" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's run the diffusion process for a whole bunch of samples" ], "metadata": { "id": "ECAUfHNi9NVW" } }, { "cell_type": "code", "source": [ "# Generate some samples\n", "n_sample = 10000\n", "np.random.seed(6)\n", "x = true_dist.sample(n_sample)\n", "\n", "# Number of time steps\n", "T = 100\n", "# Noise schedule has same value at every time step\n", "beta = 0.01511\n", "\n", "# We'll store the diffused samples in an array\n", "samples = np.zeros((T+1, n_sample))\n", "samples[0,:] = x\n", "\n", "for t in range(T):\n", " samples[t+1,:] = diffuse_one_step(samples[t,:], beta)" ], "metadata": { "id": "M-TY5w9Q8LYW" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's, plot the evolution of a few paths as in figure 18.2" ], "metadata": { "id": "jYrAW6tN-gJ4" } }, { "cell_type": "code", "source": [ "fig, ax = plt.subplots()\n", "t_vals = np.arange(0,101,1)\n", "ax.plot(samples[:,0],t_vals,'r-')\n", "ax.plot(samples[:,1],t_vals,'g-')\n", "ax.plot(samples[:,2],t_vals,'b-')\n", "ax.plot(samples[:,3],t_vals,'c-')\n", "ax.plot(samples[:,4],t_vals,'m-')\n", "ax.set_xlim([-3,3])\n", "ax.set_ylim([101, 0])\n", "ax.set_xlabel('value')\n", "ax.set_ylabel('z_{t}')\n", "plt.show()" ], "metadata": { "id": "4XU6CDZC_kFo" }, "execution_count": null, 
"outputs": [] }, { "cell_type": "markdown", "source": [ "Notice that the samples have a tendency to move toward the center. Now let's look at the histogram of the samples at each stage" ], "metadata": { "id": "SGTYGGevAktz" } }, { "cell_type": "code", "source": [ "def draw_hist(z_t,title=''):\n", " fig, ax = plt.subplots()\n", " fig.set_size_inches(8,2.5)\n", " plt.hist(z_t , bins=np.arange(-3,3, 0.1), density = True)\n", " ax.set_xlim([-3,3])\n", " ax.set_ylim([0,1.0])\n", " ax.set_title(title)\n", " plt.show()" ], "metadata": { "id": "bn5E5NzL-evM" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "draw_hist(samples[0,:],'Original data')\n", "draw_hist(samples[5,:],'Time step 5')\n", "draw_hist(samples[10,:],'Time step 10')\n", "draw_hist(samples[20,:],'Time step 20')\n", "draw_hist(samples[40,:],'Time step 40')\n", "draw_hist(samples[80,:],'Time step 80')\n", "draw_hist(samples[100,:],'Time step 100')" ], "metadata": { "id": "pn_XD-EhBlwk" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "You can clearly see that as the diffusion process continues, the data becomes more Gaussian." 
], "metadata": { "id": "skuLfGl5Czf4" } }, { "cell_type": "markdown", "source": [ "Now let's investigate the diffusion kernel as in figure 18.3 of the book.\n" ], "metadata": { "id": "s37CBSzzK7wh" } }, { "cell_type": "code", "source": [ "def diffusion_kernel(x, t, beta):\n", " # TODO -- write this function\n", " # Replace this line:\n", " dk_mean = 0.0 ; dk_std = 1.0\n", "\n", " return dk_mean, dk_std" ], "metadata": { "id": "vL62Iym0LEtY" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def draw_prob_dist(x_plot_vals, prob_dist, title=''):\n", " fig, ax = plt.subplots()\n", " fig.set_size_inches(8,2.5)\n", " ax.plot(x_plot_vals, prob_dist, 'b-')\n", " ax.set_xlim([-3,3])\n", " ax.set_ylim([0,1.0])\n", " ax.set_title(title)\n", " plt.show()\n", "\n", "def compute_and_plot_diffusion_kernels(x, T, beta, my_colormap):\n", " x_plot_vals = np.arange(-3,3,0.01)\n", " diffusion_kernels = np.zeros((T+1,len(x_plot_vals)))\n", " dk_mean_all = np.ones((T+1,1))*x\n", " dk_std_all = np.zeros((T+1,1))\n", " for t in range(T):\n", " dk_mean_all[t+1], dk_std_all[t+1] = diffusion_kernel(x,t+1,beta)\n", " diffusion_kernels[t+1,:] = norm_pdf(x_plot_vals, dk_mean_all[t+1], dk_std_all[t+1])\n", "\n", " samples = np.ones((T+1, 5))\n", " samples[0,:] = x\n", "\n", " for t in range(T):\n", " samples[t+1,:] = diffuse_one_step(samples[t,:], beta)\n", "\n", " fig, ax = plt.subplots()\n", " fig.set_size_inches(6,6)\n", "\n", " # Plot the image containing all the kernels\n", " plt.imshow(diffusion_kernels, cmap=my_colormap, interpolation='nearest')\n", "\n", " # Plot +/- 2 standard deviations\n", " ax.plot((dk_mean_all -2 * dk_std_all)/0.01 + len(x_plot_vals)/2, np.arange(0,101,1),'y--')\n", " ax.plot((dk_mean_all +2 * dk_std_all)/0.01 + len(x_plot_vals)/2, np.arange(0,101,1),'y--')\n", "\n", " # Plot a few trajectories\n", " ax.plot(samples[:,0]/0.01 + + len(x_plot_vals)/2, np.arange(0,101,1), 'r-')\n", " ax.plot(samples[:,1]/0.01 + + len(x_plot_vals)/2, 
np.arange(0,101,1), 'g-')\n", " ax.plot(samples[:,2]/0.01 + + len(x_plot_vals)/2, np.arange(0,101,1), 'b-')\n", " ax.plot(samples[:,3]/0.01 + + len(x_plot_vals)/2, np.arange(0,101,1), 'c-')\n", " ax.plot(samples[:,4]/0.01 + + len(x_plot_vals)/2, np.arange(0,101,1), 'm-')\n", "\n", " # Tidy up and plot\n", " ax.set_ylabel(\"$Pr(z_{t}|x)$\")\n", " ax.get_xaxis().set_visible(False)\n", " ax.set_xlim([0,601])\n", " ax.set_aspect(601/T)\n", " plt.show()\n", "\n", "\n", " draw_prob_dist(x_plot_vals, diffusion_kernels[20,:],'$q(z_{20}|x)$')\n", " draw_prob_dist(x_plot_vals, diffusion_kernels[40,:],'$q(z_{40}|x)$')\n", " draw_prob_dist(x_plot_vals, diffusion_kernels[80,:],'$q(z_{80}|x)$')" ], "metadata": { "id": "KtP1KF8wMh8o" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "x = -2\n", "compute_and_plot_diffusion_kernels(x, T, beta, my_colormap)" ], "metadata": { "id": "g8TcI5wtRQsx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "TODO -- Run this for different values of $x$ and check that you understand how the graphs change." ], "metadata": { "id": "-RuN2lR28-hK" } }, { "cell_type": "markdown", "source": [ "Finally, let's estimate the marginal distributions empirically and visualize them as in figure 18.4 of the book. This is only tractable because the data is in one dimension and we know the original distribution.\n", "\n", "The marginal distribution at time t is the sum of the diffusion kernels for each position x, weighted by the probability of seeing that value of x in the true distribution." ], "metadata": { "id": "n-x6Whz2J_zy" } }, { "cell_type": "code", "source": [ "def diffusion_marginal(x_plot_vals, pr_x_true, t, beta):\n", " # If time is zero then marginal is just original distribution\n", " if t == 0:\n", " return pr_x_true\n", "\n", " # The thing we are computing\n", " marginal_at_time_t = np.zeros_like(pr_x_true);\n", "\n", "\n", " # TODO Write this function\n", " # 1. 
For each x (value in x_plot_vals):\n", " # 2. Compute the mean and standard deviation of the diffusion kernel at time t\n", " # 3. Compute pdf of this Gaussian at every x_plot_val\n", " # 4. Weight Gaussian by probability at position x and by 0.01 to compensate for bin size\n", " # 5. Accumulate weighted Gaussian in marginal at time t.\n", " # 6. Note: the 0.01 bin-size factor from step 4 should be applied only once -- do not multiply the result by 0.01 again\n", " # Replace this line:\n", " marginal_at_time_t = marginal_at_time_t\n", "\n", "\n", "\n", " return marginal_at_time_t" ], "metadata": { "id": "YzN5duYpg7C-" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "x_plot_vals = np.arange(-3,3,0.01)\n", "marginal_distributions = np.zeros((T+1,len(x_plot_vals)))\n", "\n", "for t in range(T+1):\n", " marginal_distributions[t,:] = diffusion_marginal(x_plot_vals, pr_x_true , t, beta)\n", "\n", "fig, ax = plt.subplots()\n", "fig.set_size_inches(6,6)\n", "\n", "# Plot the image containing all the marginal distributions\n", "plt.imshow(marginal_distributions, cmap=my_colormap, interpolation='nearest')\n", "\n", "# Tidy up and plot\n", "ax.set_ylabel(\"$Pr(z_{t})$\")\n", "ax.get_xaxis().set_visible(False)\n", "ax.set_xlim([0,601])\n", "ax.set_aspect(601/T)\n", "plt.show()\n", "\n", "\n", "draw_prob_dist(x_plot_vals, marginal_distributions[0,:],'$q(z_{0})$')\n", "draw_prob_dist(x_plot_vals, marginal_distributions[20,:],'$q(z_{20})$')\n", "draw_prob_dist(x_plot_vals, marginal_distributions[60,:],'$q(z_{60})$')" ], "metadata": { "id": "OgEU9sxjRaeO" }, "execution_count": null, "outputs": [] } ] }