changed to ax.set_title(title) so title appears on plots

This commit is contained in:
tonyjo
2023-11-29 13:45:30 -05:00
parent 36d2695a41
commit aea371dc7d

View File

@@ -1,26 +1,10 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMpC8kgLnXx0XQBtwNAQ4jJ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap18/18_1_Diffusion_Encoder.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
@@ -28,6 +12,9 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "t9vk9Elugvmi"
},
"source": [
"# **Notebook 18.1: Diffusion Encoder**\n",
"\n",
@@ -36,27 +23,29 @@
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
],
"metadata": {
"id": "t9vk9Elugvmi"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OLComQyvCIJ7"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"from operator import itemgetter"
],
"metadata": {
"id": "OLComQyvCIJ7"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4PM8bf6lO0VE"
},
"outputs": [],
"source": [
"#Create pretty colormap as in book\n",
"my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
@@ -66,28 +55,28 @@
"b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
"my_colormap_vals = np.vstack((r,g,b)).transpose()/255.0\n",
"my_colormap = ListedColormap(my_colormap_vals)"
],
"metadata": {
"id": "4PM8bf6lO0VE"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ONGRaQscfIOo"
},
"outputs": [],
"source": [
"# Probability distribution for normal\n",
"def norm_pdf(x, mu, sigma):\n",
" return np.exp(-0.5 * (x-mu) * (x-mu) / (sigma * sigma)) / np.sqrt(2*np.pi*sigma*sigma)"
],
"metadata": {
"id": "ONGRaQscfIOo"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gZvG0MKhfY8Y"
},
"outputs": [],
"source": [
"# True distribution is a mixture of four Gaussians\n",
"class TrueDataDistribution:\n",
@@ -108,15 +97,15 @@
" mu_list = list(itemgetter(*hidden)(self.mu))\n",
" sigma_list = list(itemgetter(*hidden)(self.sigma))\n",
" return mu_list + sigma_list * epsilon"
],
"metadata": {
"id": "gZvG0MKhfY8Y"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qXmej3TUuQyp"
},
"outputs": [],
"source": [
"# Define ground truth probability distribution that we will model\n",
"true_dist = TrueDataDistribution()\n",
@@ -130,24 +119,24 @@
"ax.set_ylim(0,1.0)\n",
"ax.set_xlim(-3,3)\n",
"plt.show()"
],
"metadata": {
"id": "qXmej3TUuQyp"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"Let's first implement the forward process"
],
"metadata": {
"id": "XHdtfRP47YLy"
}
},
"source": [
"Let's first implement the forward process"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hkApJ2VJlQuk"
},
"outputs": [],
"source": [
"# Do one step of diffusion (equation 18.1)\n",
"def diffuse_one_step(z_t_minus_1, beta_t):\n",
@@ -157,24 +146,24 @@
" z_t = np.zeros_like(z_t_minus_1)\n",
"\n",
" return z_t"
],
"metadata": {
"id": "hkApJ2VJlQuk"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"Now let's run the diffusion process for a whole bunch of samples"
],
"metadata": {
"id": "ECAUfHNi9NVW"
}
},
"source": [
"Now let's run the diffusion process for a whole bunch of samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M-TY5w9Q8LYW"
},
"outputs": [],
"source": [
"# Generate some samples\n",
"n_sample = 10000\n",
@@ -192,24 +181,24 @@
"\n",
"for t in range(T):\n",
" samples[t+1,:] = diffuse_one_step(samples[t,:], beta)"
],
"metadata": {
"id": "M-TY5w9Q8LYW"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"Let's, plot the evolution of a few paths as in figure 18.2"
],
"metadata": {
"id": "jYrAW6tN-gJ4"
}
},
"source": [
"Let's, plot the evolution of a few paths as in figure 18.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4XU6CDZC_kFo"
},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"t_vals = np.arange(0,101,1)\n",
@@ -223,24 +212,24 @@
"ax.set_xlabel('value')\n",
"ax.set_ylabel('z_{t}')\n",
"plt.show()"
],
"metadata": {
"id": "4XU6CDZC_kFo"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"Notice that the samples have a tendency to move toward the center. Now let's look at the histogram of the samples at each stage"
],
"metadata": {
"id": "SGTYGGevAktz"
}
},
"source": [
"Notice that the samples have a tendency to move toward the center. Now let's look at the histogram of the samples at each stage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bn5E5NzL-evM"
},
"outputs": [],
"source": [
"def draw_hist(z_t,title=''):\n",
" fig, ax = plt.subplots()\n",
@@ -248,17 +237,17 @@
" plt.hist(z_t , bins=np.arange(-3,3, 0.1), density = True)\n",
" ax.set_xlim([-3,3])\n",
" ax.set_ylim([0,1.0])\n",
" ax.set_title('title')\n",
" ax.set_title(title)\n",
" plt.show()"
],
"metadata": {
"id": "bn5E5NzL-evM"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pn_XD-EhBlwk"
},
"outputs": [],
"source": [
"draw_hist(samples[0,:],'Original data')\n",
"draw_hist(samples[5,:],'Time step 5')\n",
@@ -267,33 +256,33 @@
"draw_hist(samples[40,:],'Time step 40')\n",
"draw_hist(samples[80,:],'Time step 80')\n",
"draw_hist(samples[100,:],'Time step 100')"
],
"metadata": {
"id": "pn_XD-EhBlwk"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"You can clearly see that as the diffusion process continues, the data becomes more Gaussian."
],
"metadata": {
"id": "skuLfGl5Czf4"
}
},
"source": [
"You can clearly see that as the diffusion process continues, the data becomes more Gaussian."
]
},
{
"cell_type": "markdown",
"source": [
"Now let's investigate the diffusion kernel as in figure 18.3 of the book.\n"
],
"metadata": {
"id": "s37CBSzzK7wh"
}
},
"source": [
"Now let's investigate the diffusion kernel as in figure 18.3 of the book.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vL62Iym0LEtY"
},
"outputs": [],
"source": [
"def diffusion_kernel(x, t, beta):\n",
" # TODO -- write this function\n",
@@ -301,15 +290,15 @@
" dk_mean = 0.0 ; dk_std = 1.0\n",
"\n",
" return dk_mean, dk_std"
],
"metadata": {
"id": "vL62Iym0LEtY"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KtP1KF8wMh8o"
},
"outputs": [],
"source": [
"def draw_prob_dist(x_plot_vals, prob_dist, title=''):\n",
" fig, ax = plt.subplots()\n",
@@ -363,47 +352,47 @@
" draw_prob_dist(x_plot_vals, diffusion_kernels[20,:],'$q(z_{20}|x)$')\n",
" draw_prob_dist(x_plot_vals, diffusion_kernels[40,:],'$q(z_{40}|x)$')\n",
" draw_prob_dist(x_plot_vals, diffusion_kernels[80,:],'$q(z_{80}|x)$')"
],
"metadata": {
"id": "KtP1KF8wMh8o"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"source": [
"x = -2\n",
"compute_and_plot_diffusion_kernels(x, T, beta, my_colormap)"
],
"execution_count": null,
"metadata": {
"id": "g8TcI5wtRQsx"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"x = -2\n",
"compute_and_plot_diffusion_kernels(x, T, beta, my_colormap)"
]
},
{
"cell_type": "markdown",
"source": [
"TODO -- Run this for different version of $x$ and check that you understand how the graphs change"
],
"metadata": {
"id": "-RuN2lR28-hK"
}
},
"source": [
"TODO -- Run this for different version of $x$ and check that you understand how the graphs change"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n-x6Whz2J_zy"
},
"source": [
"Finally, let's estimate the marginal distributions empirically and visualize them as in figure 18.4 of the book. This is only tractable because the data is in one dimension and we know the original distribution.\n",
"\n",
"The marginal distribution at time t is the sum of the diffusion kernels for each position x, weighted by the probability of seeing that value of x in the true distribution."
],
"metadata": {
"id": "n-x6Whz2J_zy"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YzN5duYpg7C-"
},
"outputs": [],
"source": [
"def diffusion_marginal(x_plot_vals, pr_x_true, t, beta):\n",
" # If time is zero then marginal is just original distribution\n",
@@ -427,15 +416,15 @@
"\n",
"\n",
" return marginal_at_time_t"
],
"metadata": {
"id": "YzN5duYpg7C-"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OgEU9sxjRaeO"
},
"outputs": [],
"source": [
"x_plot_vals = np.arange(-3,3,0.01)\n",
"marginal_distributions = np.zeros((T+1,len(x_plot_vals)))\n",
@@ -460,12 +449,23 @@
"draw_prob_dist(x_plot_vals, marginal_distributions[0,:],'$q(z_{0})$')\n",
"draw_prob_dist(x_plot_vals, marginal_distributions[20,:],'$q(z_{20})$')\n",
"draw_prob_dist(x_plot_vals, marginal_distributions[60,:],'$q(z_{60})$')"
]
}
],
"metadata": {
"id": "OgEU9sxjRaeO"
"colab": {
"authorship_tag": "ABX9TyMpC8kgLnXx0XQBtwNAQ4jJ",
"include_colab_link": true,
"provenance": []
},
"execution_count": null,
"outputs": []
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
]
},
"nbformat": 4,
"nbformat_minor": 0
}