{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap17/17_1_Latent_Variable_Models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t9vk9Elugvmi"
},
"source": [
"# **Notebook 17.1: Latent variable models**\n",
"\n",
"This notebook investigates a non-linear latent variable model similar to that in figures 17.2 and 17.3 of the book.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OLComQyvCIJ7"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy\n",
"from matplotlib.colors import ListedColormap\n",
"from matplotlib import cm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IyVn-Gi-p7wf"
},
"source": [
"We'll assume that our base distribution over the latent variables is a 1D standard normal so that\n",
"\n",
"\\begin{equation}\n",
"Pr(z) = \\text{Norm}_{z}[0,1]\n",
"\\end{equation}\n",
"\n",
"As in figure 17.2, we'll assume that the output is two dimensional, we need to define a function that maps from the 1D latent variable to two dimensions. Usually, we would use a neural network, but in this case, we'll just define an arbitrary relationship.\n",
"\n",
"\\begin{align}\n",
"x_{1} &=& 0.5\\cdot\\exp\\Bigl[\\sin\\bigl[2+ 3.675 z \\bigr]\\Bigr]\\\\\n",
"x_{2} &=& \\sin\\bigl[2+ 2.85 z \\bigr]\n",
"\\end{align}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZIfQwhd-AV6L"
},
"outputs": [],
"source": [
"# The function that maps z to x1 and x2\n",
"def f(z):\n",
" x_1 = np.exp(np.sin(2+z*3.675)) * 0.5\n",
" x_2 = np.cos(2+z*2.85)\n",
" return x_1, x_2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KB9FU34onW1j"
},
"source": [
"Let's plot the 3D relation between the two observed variables $x_{1}$ and $x_{2}$ and the latent variables $z$ as in figure 17.2 of the book. We'll use the opacity to represent the prior probability $Pr(z)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lW08xqAgnP4q"
},
"outputs": [],
"source": [
"def draw_3d_projection(z,pr_z, x1,x2):\n",
" alpha = pr_z / np.max(pr_z)\n",
" ax = plt.axes(projection='3d')\n",
" fig = plt.gcf()\n",
" fig.set_size_inches(18.5, 10.5)\n",
" for i in range(len(z)-1):\n",
" ax.plot([z[i],z[i+1]],[x1[i],x1[i+1]],[x2[i],x2[i+1]],'r-', alpha=pr_z[i])\n",
" ax.set_xlabel('$z$',)\n",
" ax.set_ylabel('$x_1$')\n",
" ax.set_zlabel('$x_2$')\n",
" ax.set_xlim(-3,3)\n",
" ax.set_ylim(0,2)\n",
" ax.set_zlim(-1,1)\n",
" ax.set_box_aspect((3,1,1))\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9DUTauMi6tPk"
},
"outputs": [],
"source": [
"# Compute the prior\n",
"def get_prior(z):\n",
" return scipy.stats.multivariate_normal.pdf(z)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PAzHq461VqvF"
},
"outputs": [],
"source": [
"# Define the latent variable values\n",
"z = np.arange(-3.0,3.0,0.01)\n",
"# Find the probability distribution over z\n",
"pr_z = get_prior(z)\n",
"# Compute x1 and x2 for each z\n",
"x1,x2 = f(z)\n",
"# Plot the function\n",
"draw_3d_projection(z,pr_z, x1,x2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQg2gKR5zMrF"
},
"source": [
"The likelihood is defined as:\n",
"\\begin{align}\n",
" Pr(x_1,x_2|z) &=& \\text{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\n",
"\\end{align}\n",
"\n",
"so we will also need to define the noise level $\\sigma^2$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "In_Vg4_0nva3"
},
"outputs": [],
"source": [
"sigma_sq = 0.04"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6P6d-AgAqxXZ"
},
"outputs": [],
"source": [
"# Draws a heatmap to represent a probability distribution, possibly with samples overlaed\n",
"def plot_heatmap(x1_mesh,x2_mesh,y_mesh, x1_samples=None, x2_samples=None, title=None):\n",
" # Define pretty colormap\n",
" my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 
'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
" my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n",
" r = np.floor(my_colormap_vals_dec/(256*256))\n",
" g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n",
" b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
" my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
"\n",
" fig,ax = plt.subplots()\n",
" fig.set_size_inches(8,8)\n",
" ax.contourf(x1_mesh,x2_mesh,y_mesh,256,cmap=my_colormap)\n",
" ax.contour(x1_mesh,x2_mesh,y_mesh,8,colors=['#80808080'])\n",
" if title is not None:\n",
" ax.set_title(title);\n",
" if x1_samples is not None:\n",
" ax.plot(x1_samples, x2_samples, 'c.')\n",
" ax.set_xlim([-0.5,2.5])\n",
" ax.set_ylim([-1.5,1.5])\n",
" ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')\n",
" plt.show()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "diYKb7_ZgjlJ"
},
"outputs": [],
"source": [
"# Returns the likelihood\n",
"def get_likelihood(x1_mesh, x2_mesh, z_val):\n",
" # Find the corresponding x1 and x2 values\n",
" x1,x2 = f(z_val)\n",
"\n",
" # Calculate the probability for a mesh of x1,x2 values.\n",
" mn = scipy.stats.multivariate_normal([x1, x2], [[sigma_sq, 0], [0, sigma_sq]])\n",
" pr_x1_x2_given_z_val = mn.pdf(np.dstack((x1_mesh, x2_mesh)))\n",
" return pr_x1_x2_given_z_val"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0X4NwixzqxtZ"
},
"source": [
"Now let's plot the likelihood $Pr(x_1,x_2|z)$ as in fig 17.3b in the book."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hWfqK-Oz5_DT"
},
"outputs": [],
"source": [
"# Choose some z value\n",
"z_val = 1.8\n",
"\n",
"# Compute the conditional distribution on a grid\n",
"x1_mesh, x2_mesh = np.meshgrid(np.arange(-0.5,2.5,0.01), np.arange(-1.5,1.5,0.01))\n",
"pr_x1_x2_given_z_val = get_likelihood(x1_mesh,x2_mesh, z_val)\n",
"\n",
"# Plot the result\n",
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2_given_z_val, title=\"Conditional distribution $Pr(x_1,x_2|z)$\")\n",
"\n",
"# TODO -- Experiment with different values of z and make sure that you understand the what is happening."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "25xqXnmFo-PH"
},
"source": [
"The data density is found by marginalizing over the latent variables $z$:\n",
"\n",
"\\begin{align}\n",
" Pr(x_1,x_2) &=& \\int Pr(x_1,x_2, z) dz \\nonumber \\\\\n",
" &=& \\int Pr(x_1,x_2 | z) \\cdot Pr(z)dz\\nonumber \\\\\n",
" &=& \\int \\text{Norm}_{[x_1,x_2]}\\Bigl[\\mathbf{f}[z],\\sigma^{2}\\mathbf{I}\\Bigr]\\cdot \\text{Norm}_{z}\\left[\\mathbf{0},\\mathbf{I}\\right]dz.\n",
"\\end{align}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H0Ijce9VzeCO"
},
"outputs": [],
"source": [
"# TODO Compute the data density\n",
"# We can't integrate this function in closed form\n",
"# So let's approximate it as a sum over the z values (z = np.arange(-3,3,0.01))\n",
"# You will need the functions get_likelihood() and get_prior()\n",
"# To make this a valid probability distribution, you need to multiply\n",
"# By the z-increment (0.01)\n",
"# Replace this line\n",
"pr_x1_x2 = np.zeros_like(x1_mesh)\n",
"\n",
"\n",
"# Plot the result\n",
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, title=\"Data density $Pr(x_1,x_2)$\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W264N7By_h9y"
},
"source": [
"Now let's draw some samples from the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Li3mK_I48k0k"
},
"outputs": [],
"source": [
"def draw_samples(n_sample):\n",
" # TODO Write this routine to draw n_sample samples from the model\n",
" # First draw a random value of z from the prior (a standard normal distribution)\n",
" # Then draw a sample from Pr(x1,x2|z)\n",
" # Replace this line\n",
" x1_samples=0; x2_samples = 0;\n",
"\n",
" return x1_samples, x2_samples"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D7N7oqLe-eJO"
},
"source": [
"Let's plot those samples on top of the heat map."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XRmWv99B-BWO"
},
"outputs": [],
"source": [
"x1_samples, x2_samples = draw_samples(500)\n",
"# Plot the result\n",
"plot_heatmap(x1_mesh, x2_mesh, pr_x1_x2, x1_samples, x2_samples, title=\"Data density $Pr(x_1,x_2)$\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PwOjzPD5_1OF"
},
"outputs": [],
"source": [
"# Return the posterior distribution\n",
"def get_posterior(x1,x2):\n",
" z = np.arange(-3,3, 0.01)\n",
" # TODO -- write this function\n",
" # Again, we can't integrate, but we can sum\n",
" # We don't know the constant in the denominator of equation 17.19, but we can just normalize\n",
" # by the sum of the numerators for all values of z\n",
" # Replace this line:\n",
" pr_z_given_x1_x2 = np.ones_like(z)\n",
"\n",
"\n",
" return z, pr_z_given_x1_x2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PKFUY42K-Tp7"
},
"outputs": [],
"source": [
"x1 = 0.9; x2 = -0.9\n",
"z, pr_z_given_x1_x2 = get_posterior(x1,x2)\n",
"\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(z, pr_z_given_x1_x2, 'r-')\n",
"ax.set_xlabel(\"Latent variable $z$\")\n",
"ax.set_ylabel(\"Posterior probability $Pr(z|x_{1},x_{2})$\")\n",
"ax.set_xlim([-3,3])\n",
"ax.set_ylim([0,1.5 * np.max(pr_z_given_x1_x2)])\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}