diff --git a/Notebooks/Chap18/18_2_1D_Diffusion_Model.ipynb b/Notebooks/Chap18/18_2_1D_Diffusion_Model.ipynb index a706a21..42b7c65 100644 --- a/Notebooks/Chap18/18_2_1D_Diffusion_Model.ipynb +++ b/Notebooks/Chap18/18_2_1D_Diffusion_Model.ipynb @@ -1,33 +1,22 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "authorship_tag": "ABX9TyM4DdZDGoP1xGst+Nn+rwvt", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "t9vk9Elugvmi" + }, "source": [ "# **Notebook 18.2: 1D Diffusion Model**\n", "\n", @@ -36,13 +25,15 @@ "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "\n", "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." - ], - "metadata": { - "id": "t9vk9Elugvmi" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OLComQyvCIJ7" + }, + "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", @@ -50,15 +41,15 @@ "from operator import itemgetter\n", "from scipy import stats\n", "from IPython.display import display, clear_output" - ], - "metadata": { - "id": "OLComQyvCIJ7" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4PM8bf6lO0VE" + }, + "outputs": [], "source": [ "#Create pretty colormap as in book\n", "my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", @@ -68,28 +59,28 @@ "b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", "my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n", "my_colormap = ListedColormap(my_colormap_vals)" - ], - "metadata": { - "id": "4PM8bf6lO0VE" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ONGRaQscfIOo" + }, + "outputs": [], "source": [ "# Probability distribution for normal\n", "def norm_pdf(x, mu, sigma):\n", " return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)" - ], - "metadata": { - "id": "ONGRaQscfIOo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZvG0MKhfY8Y" + }, + "outputs": [], "source": [ "# True distribution is a mixture of four Gaussians\n", "class TrueDataDistribution:\n", @@ -110,15 +101,15 @@ " mu_list = list(itemgetter(*hidden)(self.mu))\n", " sigma_list = list(itemgetter(*hidden)(self.sigma))\n", " return mu_list + sigma_list * epsilon" - ], - "metadata": { - "id": "gZvG0MKhfY8Y" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iJu_uBiaeUVv" + }, + "outputs": [], "source": [ "# Define ground truth probability distribution that we will model\n", "true_dist = TrueDataDistribution()\n", @@ -133,25 +124,26 @@ "ax.set_ylim(0,1.0)\n", "ax.set_xlim(-3,3)\n", "plt.show()" - ], - "metadata": { - "id": "iJu_uBiaeUVv" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "DRHUG_41i4t_" + }, "source": [ "To train the model to describe this distribution, we'll need to generate pairs of samples drawn from $Pr(z_t|x)$ (diffusion kernel) and $q(z_{t-1}|z_{t},x)$ (equation 18.15).\n", "\n" - ], - "metadata": { - "id": "DRHUG_41i4t_" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x6B8t72Ukscd" + }, + "outputs": [], "source": [ "# The diffusion kernel returns the parameters of Pr(z_{t}|x)\n", "def diffusion_kernel(x, t, beta):\n", @@ -180,24 +172,25 @@ " z_tminus1 = np.random.normal(size=x_train.shape) * cd_std + cd_mean\n", "\n", " return z_t, z_tminus1" - ], - "metadata": { - "id": "x6B8t72Ukscd" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "We also need models $\\mbox{f}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the mean of the distribution at time $z_{t-1}$. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." - ], "metadata": { "id": "aSG_4uA8_zZ-" - } + }, + "source": [ + "We also need models $\\text{f}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the mean of the distribution at time $z_{t-1}$. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZHViC0pL_yy5" + }, + "outputs": [], "source": [ "# This code is really ugly! Don't look too closely at it!\n", "# All you need to know is that it is a model that trains from pairs zt, zt_minus1\n", @@ -223,15 +216,15 @@ " bin_index = np.floor((zt+self.max_val)/self.inc)\n", " bin_index = np.clip(bin_index,0, len(self.model)-1).astype('uint32')\n", " return zt + self.model[bin_index]" - ], - "metadata": { - "id": "ZHViC0pL_yy5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CzVFybWoBygu" + }, + "outputs": [], "source": [ "# Sample data from distribution (this would usually be our collected training set)\n", "n_sample = 100000\n", @@ -249,24 +242,25 @@ " all_models.append(NonParametricModel())\n", " # The model at index t maps data from z_{t+1} to z_{t}\n", " all_models[t].train(zt,zt_minus1)" - ], - "metadata": { - "id": "CzVFybWoBygu" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories. See equations 18.16." - ], "metadata": { "id": "ZPc9SEvtl14U" - } + }, + "source": [ + "Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories. See equations 18.16." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "A-ZMFOvACIOw" + }, + "outputs": [], "source": [ "def sample(model, T, sigma_t, n_samples):\n", " # Create the output array\n", @@ -295,24 +289,25 @@ " samples[t-1,:] = samples[t-1,:]\n", "\n", " return samples" - ], - "metadata": { - "id": "A-ZMFOvACIOw" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now let's run the diffusion process for a whole bunch of samples" - ], "metadata": { "id": "ECAUfHNi9NVW" - } + }, + "source": [ + "Now let's run the diffusion process for a whole bunch of samples" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M-TY5w9Q8LYW" + }, + "outputs": [], "source": [ "sigma_t=0.12288\n", "n_samples = 100000\n", @@ -329,24 +324,25 @@ "plt.hist(sampled_data, bins=bins, density =True)\n", "ax.set_ylim(0, 0.8)\n", "plt.show()" - ], - "metadata": { - "id": "M-TY5w9Q8LYW" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." - ], "metadata": { "id": "jYrAW6tN-gJ4" - } + }, + "source": [ + "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4XU6CDZC_kFo" + }, + "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "t_vals = np.arange(0,101,1)\n", @@ -360,21 +356,33 @@ "ax.set_xlabel('value')\n", "ax.set_ylabel('z_{t}')\n", "plt.show()" - ], - "metadata": { - "id": "4XU6CDZC_kFo" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Notice that the samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" - ], "metadata": { "id": "SGTYGGevAktz" - } + }, + "source": [ + "Notice that the samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" + ] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyM4DdZDGoP1xGst+Nn+rwvt", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/Notebooks/Chap18/18_3_Reparameterized_Model.ipynb b/Notebooks/Chap18/18_3_Reparameterized_Model.ipynb index b141db0..eac07b7 100644 --- a/Notebooks/Chap18/18_3_Reparameterized_Model.ipynb +++ b/Notebooks/Chap18/18_3_Reparameterized_Model.ipynb @@ -1,33 +1,22 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "authorship_tag": "ABX9TyNd+D0/IVWXtU2GKsofyk2d", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "t9vk9Elugvmi" + }, "source": [ "# **Notebook 18.3: 1D Reparameterized model**\n", "\n", @@ -36,13 +25,15 @@ "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "\n", "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." - ], - "metadata": { - "id": "t9vk9Elugvmi" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OLComQyvCIJ7" + }, + "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", @@ -50,15 +41,15 @@ "from operator import itemgetter\n", "from scipy import stats\n", "from IPython.display import display, clear_output" - ], - "metadata": { - "id": "OLComQyvCIJ7" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4PM8bf6lO0VE" + }, + "outputs": [], "source": [ "#Create pretty colormap as in book\n", "my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", @@ -68,28 +59,28 @@ "b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", "my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n", "my_colormap = ListedColormap(my_colormap_vals)" - ], - "metadata": { - "id": "4PM8bf6lO0VE" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ONGRaQscfIOo" + }, + "outputs": [], "source": [ "# Probability distribution for normal\n", "def norm_pdf(x, mu, sigma):\n", " return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)" - ], - "metadata": { - "id": "ONGRaQscfIOo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZvG0MKhfY8Y" + }, + "outputs": [], "source": [ "# True distribution is a mixture of four Gaussians\n", "class TrueDataDistribution:\n", @@ -110,15 +101,15 @@ " mu_list = list(itemgetter(*hidden)(self.mu))\n", " sigma_list = list(itemgetter(*hidden)(self.sigma))\n", " return mu_list + sigma_list * epsilon" - ], - "metadata": { - "id": "gZvG0MKhfY8Y" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iJu_uBiaeUVv" + }, + "outputs": [], "source": [ "# Define ground truth probability distribution that we will model\n", "true_dist = TrueDataDistribution()\n", @@ -133,25 +124,26 @@ "ax.set_ylim(0,1.0)\n", "ax.set_xlim(-3,3)\n", "plt.show()" - ], - "metadata": { - "id": "iJu_uBiaeUVv" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "DRHUG_41i4t_" + }, "source": [ "To train the model to describe this distribution, we'll need to generate pairs of samples drawn from $Pr(z_t|x)$ (diffusion kernel) and $q(z_{t-1}|z_{t},x)$ (equation 18.15).\n", "\n" - ], - "metadata": { - "id": "DRHUG_41i4t_" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x6B8t72Ukscd" + }, + "outputs": [], "source": [ "# Return z_t (the argument of g_{t}[] in the loss function in algorithm 18.1) and epsilon\n", "def get_data_pairs(x_train,t,beta):\n", @@ -161,24 +153,25 @@ " z_t = np.ones_like(x_train)\n", "\n", " return z_t, epsilon" - ], - "metadata": { - "id": "x6B8t72Ukscd" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "We also need models $\\mbox{g}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the noise $\\epsilon$ that was added. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." - ], "metadata": { "id": "aSG_4uA8_zZ-" - } + }, + "source": [ + "We also need models $\\text{g}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the noise $\\epsilon$ that was added. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZHViC0pL_yy5" + }, + "outputs": [], "source": [ "# This code is really ugly! Don't look too closely at it!\n", "# All you need to know is that it is a model that trains from pairs zt, zt_minus1\n", @@ -204,15 +197,15 @@ " bin_index = np.floor((zt+self.max_val)/self.inc)\n", " bin_index = np.clip(bin_index,0, len(self.model)-1).astype('uint32')\n", " return self.model[bin_index]" - ], - "metadata": { - "id": "ZHViC0pL_yy5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CzVFybWoBygu" + }, + "outputs": [], "source": [ "# Sample data from distribution (this would usually be our collected training set)\n", "n_sample = 100000\n", @@ -230,24 +223,25 @@ " all_models.append(NonParametricModel())\n", " # The model at index t maps data from z_{t+1} to epsilon\n", " all_models[t].train(zt,epsilon)" - ], - "metadata": { - "id": "CzVFybWoBygu" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories. See algorithm 18.2" - ], "metadata": { "id": "ZPc9SEvtl14U" - } + }, + "source": [ + "Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories. See algorithm 18.2" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "A-ZMFOvACIOw" + }, + "outputs": [], "source": [ "def sample(model, T, sigma_t, n_samples):\n", " # Create the output array\n", @@ -277,24 +271,25 @@ " samples[t-1,:] = samples[t-1,:]\n", "\n", " return samples" - ], - "metadata": { - "id": "A-ZMFOvACIOw" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now let's run the diffusion process for a whole bunch of samples" - ], "metadata": { "id": "ECAUfHNi9NVW" - } + }, + "source": [ + "Now let's run the diffusion process for a whole bunch of samples" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M-TY5w9Q8LYW" + }, + "outputs": [], "source": [ "sigma_t=0.12288\n", "n_samples = 100000\n", @@ -311,24 +306,25 @@ "plt.hist(sampled_data, bins=bins, density =True)\n", "ax.set_ylim(0, 0.8)\n", "plt.show()" - ], - "metadata": { - "id": "M-TY5w9Q8LYW" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." - ], "metadata": { "id": "jYrAW6tN-gJ4" - } + }, + "source": [ + "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4XU6CDZC_kFo" + }, + "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "t_vals = np.arange(0,101,1)\n", @@ -342,21 +338,33 @@ "ax.set_xlabel('value')\n", "ax.set_ylabel('z_{t}')\n", "plt.show()" - ], - "metadata": { - "id": "4XU6CDZC_kFo" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Notice that the samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" - ], "metadata": { "id": "SGTYGGevAktz" - } + }, + "source": [ + "Notice that the samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" + ] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyNd+D0/IVWXtU2GKsofyk2d", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/Notebooks/Chap18/18_4_Families_of_Diffusion_Models.ipynb b/Notebooks/Chap18/18_4_Families_of_Diffusion_Models.ipynb index 68f9fdb..b01d7ff 100644 --- a/Notebooks/Chap18/18_4_Families_of_Diffusion_Models.ipynb +++ b/Notebooks/Chap18/18_4_Families_of_Diffusion_Models.ipynb @@ -1,33 +1,22 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "authorship_tag": "ABX9TyNFSvISBXo/Z1l+onknF2Gw", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "t9vk9Elugvmi" + }, "source": [ "# **Notebook 18.4: Families of diffusion models**\n", "\n", @@ -36,13 +25,15 @@ "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", "\n", "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." - ], - "metadata": { - "id": "t9vk9Elugvmi" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OLComQyvCIJ7" + }, + "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", @@ -50,15 +41,15 @@ "from operator import itemgetter\n", "from scipy import stats\n", "from IPython.display import display, clear_output" - ], - "metadata": { - "id": "OLComQyvCIJ7" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4PM8bf6lO0VE" + }, + "outputs": [], "source": [ "#Create pretty colormap as in book\n", "my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", @@ -68,28 +59,28 @@ "b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", "my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n", "my_colormap = ListedColormap(my_colormap_vals)" - ], - "metadata": { - "id": "4PM8bf6lO0VE" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ONGRaQscfIOo" + }, + "outputs": [], "source": [ "# Probability distribution for normal\n", "def norm_pdf(x, mu, sigma):\n", " return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)" - ], - "metadata": { - "id": "ONGRaQscfIOo" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZvG0MKhfY8Y" + }, + "outputs": [], "source": [ "# True distribution is a mixture of four Gaussians\n", "class TrueDataDistribution:\n", @@ -110,15 +101,15 @@ " mu_list = list(itemgetter(*hidden)(self.mu))\n", " sigma_list = list(itemgetter(*hidden)(self.sigma))\n", " return mu_list + sigma_list * epsilon" - ], - "metadata": { - "id": "gZvG0MKhfY8Y" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iJu_uBiaeUVv" + }, + "outputs": [], "source": [ "# Define ground truth probability distribution that we will model\n", "true_dist = TrueDataDistribution()\n", @@ -133,25 +124,26 @@ "ax.set_ylim(0,1.0)\n", "ax.set_xlim(-3,3)\n", "plt.show()" - ], - "metadata": { - "id": "iJu_uBiaeUVv" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "DRHUG_41i4t_" + }, "source": [ "To train the model to describe this distribution, we'll need to generate pairs of samples drawn from $Pr(z_t|x)$ (diffusion kernel) and $q(z_{t-1}|z_{t},x)$ (equation 18.15).\n", "\n" - ], - "metadata": { - "id": "DRHUG_41i4t_" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x6B8t72Ukscd" + }, + "outputs": [], "source": [ "# Return z_t (the argument of g_{t}[] in the loss function in algorithm 18.1) and epsilon\n", "def get_data_pairs(x_train,t,beta):\n", @@ -161,24 +153,25 @@ " z_t = x_train * np.sqrt(alpha_t) + np.sqrt(1-alpha_t) * epsilon\n", "\n", " return z_t, epsilon" - ], - "metadata": { - "id": "x6B8t72Ukscd" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "We also need models $\\mbox{g}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the noise $\\epsilon$ that was added. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." - ], "metadata": { "id": "aSG_4uA8_zZ-" - } + }, + "source": [ + "We also need models $\\text{g}_t[z_{t},\\phi_{t}]$ that map from $z_{t}$ to the noise $\\epsilon$ that was added. We're just going to use a very hacky non-parametric model (basically a lookup table) that tells you the result based on the (quantized) input." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZHViC0pL_yy5" + }, + "outputs": [], "source": [ "# This code is really ugly! Don't look too closely at it!\n", "# All you need to know is that it is a model that trains from pairs zt, zt_minus1\n", @@ -204,15 +197,15 @@ " bin_index = np.floor((zt+self.max_val)/self.inc)\n", " bin_index = np.clip(bin_index,0, len(self.model)-1).astype('uint32')\n", " return self.model[bin_index]" - ], - "metadata": { - "id": "ZHViC0pL_yy5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CzVFybWoBygu" + }, + "outputs": [], "source": [ "# Sample data from distribution (this would usually be our collected training set)\n", "n_sample = 100000\n", @@ -230,15 +223,14 @@ " all_models.append(NonParametricModel())\n", " # The model at index t maps data from z_{t+1} to epsilon\n", " all_models[t].train(zt,epsilon)" - ], - "metadata": { - "id": "CzVFybWoBygu" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "ZPc9SEvtl14U" + }, "source": [ "Now that we've learned the model, let's draw some samples from it. We start at $z_{100}$ and use the model to predict $z_{99}$, then $z_{98}$ and so on until finally we get to $z_{1}$ and then $x$ (represented as $z_{0}$ here). We'll store all of the intermediate stages as well, so we can plot the trajectories.\n", "\n", @@ -247,17 +239,19 @@ "One such model is the denoising diffusion implicit model, which has a sampling step:\n", "\n", "\\begin{equation}\n", - "\\mathbf{z}_{t-1} = \\sqrt{\\alpha_{t-1}}\\left(\\frac{\\mathbf{z}_{t}-\\sqrt{1-\\alpha_{t}}\\mbox{g}_t[\\mathbf{z}_{t},\\boldsymbol\\phi]}{\\sqrt{\\alpha_{t}}}\\right) + \\sqrt{1-\\alpha_{t-1}-\\sigma^2}\\mbox{g}_t[\\mathbf{z}_{t},\\boldsymbol\\phi]+\\sigma\\epsilon\n", + "\\mathbf{z}_{t-1} = \\sqrt{\\alpha_{t-1}}\\left(\\frac{\\mathbf{z}_{t}-\\sqrt{1-\\alpha_{t}}\\text{g}_t[\\mathbf{z}_{t},\\boldsymbol\\phi]}{\\sqrt{\\alpha_{t}}}\\right) + \\sqrt{1-\\alpha_{t-1}-\\sigma^2}\\text{g}_t[\\mathbf{z}_{t},\\boldsymbol\\phi]+\\sigma\\epsilon\n", "\\end{equation}\n", "\n", "(see equation 12 of the denoising [diffusion implicit models paper ](https://arxiv.org/pdf/2010.02502.pdf).\n" - ], - "metadata": { - "id": "ZPc9SEvtl14U" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "A-ZMFOvACIOw" + }, + "outputs": [], "source": [ "def sample_ddim(model, T, sigma_t, n_samples):\n", " # Create the output array\n", @@ -283,24 +277,25 @@ " if t>0:\n", " samples[t-1,:] = samples[t-1,:]+ np.random.standard_normal(n_samples) * sigma_t\n", " return samples" - ], - "metadata": { - "id": "A-ZMFOvACIOw" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now let's run the diffusion process for a whole bunch of samples" - ], "metadata": { "id": "ECAUfHNi9NVW" - } + }, + "source": [ + "Now let's run the diffusion process for a whole bunch of samples" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M-TY5w9Q8LYW" + }, + "outputs": [], "source": [ "# Now we'll set the noise to a MUCH smaller level\n", "sigma_t=0.001\n", @@ -318,24 +313,25 @@ "plt.hist(sampled_data, bins=bins, density =True)\n", "ax.set_ylim(0, 0.8)\n", "plt.show()" - ], - "metadata": { - "id": "M-TY5w9Q8LYW" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." - ], "metadata": { "id": "jYrAW6tN-gJ4" - } + }, + "source": [ + "Let's, plot the evolution of a few of the paths as in figure 18.7 (paths are from bottom to top now)." + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4XU6CDZC_kFo" + }, + "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "t_vals = np.arange(0,101,1)\n", @@ -349,35 +345,37 @@ "ax.set_xlabel('value')\n", "ax.set_ylabel('z_{t}')\n", "plt.show()" - ], - "metadata": { - "id": "4XU6CDZC_kFo" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "The samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" - ], "metadata": { "id": "SGTYGGevAktz" - } + }, + "source": [ + "The samples have a tendency to move from positions that are near the center at time 100 to positions that are high in the true probability distribution at time 0" + ] }, { + "attachments": {}, "cell_type": "markdown", + "metadata": { + "id": "Z-LZp_fMXxRt" + }, "source": [ "Let's now sample from the accelerated model, that requires fewer models. Again, we don't need to learn anything new -- this is just the reverse process that corresponds to a different forward process that is compatible with the same diffusion kernel.\n", "\n", "There's nothing to do here except read the code. It uses the same DDIM model as you just implemented in the previous step, but it jumps timesteps five at a time." - ], - "metadata": { - "id": "Z-LZp_fMXxRt" - } + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3Z0erjGbYj1u" + }, + "outputs": [], "source": [ "def sample_accelerated(model, T, sigma_t, n_steps, n_samples):\n", " # Create the output array\n", @@ -403,24 +401,25 @@ " if t>0:\n", " samples[c_step-1,:] = samples[c_step-1,:]+ np.random.standard_normal(n_samples) * sigma_t\n", " return samples" - ], - "metadata": { - "id": "3Z0erjGbYj1u" - }, - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", - "source": [ - "Now let's draw a bunch of samples from the model" - ], "metadata": { "id": "D3Sm_WYrcuED" - } + }, + "source": [ + "Now let's draw a bunch of samples from the model" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UB45c7VMcGy-" + }, + "outputs": [], "source": [ "sigma_t=0.11\n", "n_samples = 100000\n", @@ -438,15 +437,15 @@ "plt.hist(sampled_data, bins=bins, density =True)\n", "ax.set_ylim(0, 0.9)\n", "plt.show()" - ], - "metadata": { - "id": "UB45c7VMcGy-" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Luv-6w84c_qO" + }, + "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "step_increment = 100/ n_steps\n", @@ -464,21 +463,32 @@ "ax.set_xlabel('value')\n", "ax.set_ylabel('z_{t}')\n", "plt.show()" - ], - "metadata": { - "id": "Luv-6w84c_qO" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [], + "execution_count": null, "metadata": { "id": "LSJi72f0kw_e" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [] } - ] -} \ No newline at end of file + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyNFSvISBXo/Z1l+onknF2Gw", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}