From f50de74496c806f93d5a9d66fedc7c2541904bf4 Mon Sep 17 00:00:00 2001 From: udlbook <110402648+udlbook@users.noreply.github.com> Date: Wed, 4 Oct 2023 17:03:39 +0100 Subject: [PATCH] Created using Colaboratory --- .../Chap15/15_2_Wasserstein_Distance.ipynb | 246 ++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb diff --git a/Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb b/Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb new file mode 100644 index 0000000..9314dba --- /dev/null +++ b/Notebooks/Chap15/15_2_Wasserstein_Distance.ipynb @@ -0,0 +1,246 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNyLnpoXgKN+RGCuTUszCAZ", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Notebook 15.2: Wassserstein Distance**\n", + "\n", + "This notebook investigates the GAN toy example as illustred in figure 15.1 in the book.\n", + "\n", + "Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n", + "\n", + "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions." + ], + "metadata": { + "id": "t9vk9Elugvmi" + } + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import cm\n", + "from matplotlib.colors import ListedColormap\n", + "from scipy.optimize import linprog" + ], + "metadata": { + "id": "OLComQyvCIJ7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define two probability distributions\n", + "p = np.array([5, 3, 2, 1, 8, 7, 5, 9, 2, 1])\n", + "q = np.array([4, 10,1, 1, 4, 6, 3, 2, 0, 1])\n", + "p = p/np.sum(p);\n", + "q= q/np.sum(q);\n", + "\n", + "# Draw those distributions\n", + "fig, ax =plt.subplots(2,1);\n", + "x = np.arange(0,p.size,1)\n", + "ax[0].bar(x,p, color=\"#cccccc\")\n", + "ax[0].set_ylim([0,0.35])\n", + "ax[0].set_ylabel(\"p(x=i)\")\n", + "\n", + "ax[1].bar(x,q,color=\"#f47a60\")\n", + "ax[1].set_ylim([0,0.35])\n", + "ax[1].set_ylabel(\"q(x=j)\")\n", + "plt.show()" + ], + "metadata": { + "id": "ZIfQwhd-AV6L" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# TODO Define the distance matrix from figure 15.8d\n", + "# Replace this line\n", + "dist_mat = np.zeros((10,10))\n", + "\n", + "# vectorize the distance matrix\n", + "c = dist_mat.flatten()" + ], + "metadata": { + "id": "EZSlZQzWBKTm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define pretty colormap\n", + "my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", + "my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n", + "r = np.floor(my_colormap_vals_dec/(256*256))\n", + "g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n", + "b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", + "my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n", + "\n", + "def draw_2D_heatmap(data, title, my_colormap):\n", + " # Make grid of intercept/slope values to plot\n", + " xv, yv = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 1, 10))\n", + " fig,ax = plt.subplots()\n", + " fig.set_size_inches(4,4)\n", + " plt.imshow(data, cmap=my_colormap)\n", + " ax.set_title(title)\n", + " ax.set_xlabel('$q$'); ax.set_ylabel('$p$')\n", + " plt.show()" + ], + "metadata": { + "id": "ABRANmp6F8iQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "draw_2D_heatmap(dist_mat,'Distance $|i-j|$', my_colormap)" + ], + "metadata": { + "id": "G0HFPBXyHT6V" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define b to be the verticalconcatenation of p and q\n", + "b = np.hstack((p,q))[np.newaxis].transpose()" + ], + "metadata": { + "id": "SfqeT3KlHWrt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# TODO: Now construct the matrix A that has the initial distribution constraints\n", + "# so that Ap=b where p is the transport plan P vectorized rows first so p = np.flatten(P)\n", + "# Replace this line:\n", + "A = np.zeros((20,100))\n" + ], + "metadata": { + "id": "7KrybL96IuNW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we have all of the things we need. The vectorized distance matrix $\\mathbf{c}$, the constraint matrix $\\mathbf{A}$, the vectorized and concatenated original distribution $\\mathbf{b}$. We can run the linear programming optimization." + ], + "metadata": { + "id": "zEuEtU33S8Ly" + } + }, + { + "cell_type": "code", + "source": [ + "# We don't need the constraint that p>0 as this is the default\n", + "opt = linprog(c, A_eq=A, b_eq=b)\n", + "print(opt)" + ], + "metadata": { + "id": "wCfsOVbeSmF5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Extract the answer and display" + ], + "metadata": { + "id": "vpkkOOI2agyl" + } + }, + { + "cell_type": "code", + "source": [ + "P = np.array(opt.x).reshape(10,10)\n", + "draw_2D_heatmap(P,'Transport plan $\\mathbf{P}$', my_colormap)" + ], + "metadata": { + "id": "nZGfkrbRV_D0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Compute the Wasserstein distance\n" + ], + "metadata": { + "id": "ZEiRYRVgalsJ" + } + }, + { + "cell_type": "code", + "source": [ + "was = np.sum(P * dist_mat)\n", + "print(\"Wasserstein distance = \", was)" + ], + "metadata": { + "id": "yiQ_8j-Raq3c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "TODO -- Compute the\n", + "\n", + "* Forward KL divergence $D_{KL}[p,q]$ between these distributions\n", + "* Reverse KL divergence $D_{KL}[q,p]$ between these distributions\n", + "* Jensen-Shannon divergence $D_{JS}[p,q]$ between these distributions\n", + "\n", + "What do you conclude?" + ], + "metadata": { + "id": "zf8yTusua71s" + } + } + ] +} \ No newline at end of file