From 934f5f77489114945efcf0da506c1ee07b226e4d Mon Sep 17 00:00:00 2001 From: udlbook <110402648+udlbook@users.noreply.github.com> Date: Wed, 24 Jan 2024 10:56:22 -0500 Subject: [PATCH] Created using Colaboratory --- Blogs/BorealisGradientFlow.ipynb | 401 +++++++++++++++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 Blogs/BorealisGradientFlow.ipynb diff --git a/Blogs/BorealisGradientFlow.ipynb b/Blogs/BorealisGradientFlow.ipynb new file mode 100644 index 0000000..bce503d --- /dev/null +++ b/Blogs/BorealisGradientFlow.ipynb @@ -0,0 +1,401 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyO6cFY1oR4CmbHL2QywgTXm", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Gradient flow\n", + "\n", + "This notebook replicates some of the results in the the Borealis AI blog on gradient flow. \n" + ], + "metadata": { + "id": "ucrRRJ4dq8_d" + } + }, + { + "cell_type": "code", + "source": [ + "# Import relevant libraries\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from scipy.linalg import expm\n", + "from matplotlib import cm\n", + "from matplotlib.colors import ListedColormap" + ], + "metadata": { + "id": "_IQFHZEMZE8T" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Create the three data points that are used to train the linear model in the blog. Each input point is a column in $\\mathbf{X}$ and consists of the $x$ position in the plot and the value 1, which is used to allow the model to fit bias terms neatly." + ], + "metadata": { + "id": "NwgUP3MSriiJ" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cJNZ2VIcYsD8" + }, + "outputs": [], + "source": [ + "X = np.array([[0.2, 0.4, 0.8],[1,1,1]])\n", + "y = np.array([[-0.1],[0.15],[0.3]])\n", + "D = X.shape[0]\n", + "I = X.shape[1]\n", + "\n", + "print(\"X=\\n\",X)\n", + "print(\"y=\\n\",y)" + ] + }, + { + "cell_type": "code", + "source": [ + "# Draw the three data points\n", + "fig, ax = plt.subplots()\n", + "ax.plot(X[0:1,:],y.T,'ro')\n", + "ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n", + "ax.set_xlabel('x'); ax.set_ylabel('y')\n", + "plt.show()" + ], + "metadata": { + "id": "FpFlD4nUZDRt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Compute the evolution of the residuals, loss, and parameters as a function of time." + ], + "metadata": { + "id": "H2LBR1DasQej" + } + }, + { + "cell_type": "code", + "source": [ + "# Discretized time to evaluate quantities at\n", + "t_all = np.arange(0,20,0.01)\n", + "nT = t_all.shape[0]\n", + "\n", + "# Initial parameters, and initial function output at training points\n", + "phi_0 = np.array([[-0.05],[-0.4]])\n", + "f_0 = X.T @ phi_0\n", + "\n", + "# Precompute pseudoinverse term (not a very sensible numerical implementation, but it works...)\n", + "XXTInvX = np.linalg.inv(X@X.T)@X\n", + "\n", + "# Create arrays to hold function at data points over time, residual over time, parameters over time\n", + "f_all = np.zeros((I,nT))\n", + "f_minus_y_all = np.zeros((I,nT))\n", + "phi_t_all = np.zeros((D,nT))\n", + "\n", + "# For each time, compute function, residual, and parameters at each time.\n", + "for t in range(len(t_all)):\n", + " f = y + expm(-X.T@X * t_all[t]) @ (f_0-y)\n", + " f_all[:,t:t+1] = f\n", + " f_minus_y_all[:,t:t+1] = f-y\n", + " phi_t_all[:,t:t+1] = phi_0 - XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ (f_0-y)" + ], + "metadata": { + "id": "wfF_oTS5Z4Wi" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Plot the results that were calculated in the previous cell" + ], + "metadata": { + "id": "9jSjOOFutJUE" + } + }, + { + "cell_type": "code", + "source": [ + "# Plot function at data points\n", + "fig, ax = plt.subplots()\n", + "ax.plot(t_all,np.squeeze(f_all[0,:]),'r-', label='$f[x_{0},\\phi]$')\n", + "ax.plot(t_all,np.squeeze(f_all[1,:]),'g-', label='$f[x_{1},\\phi]$')\n", + "ax.plot(t_all,np.squeeze(f_all[2,:]),'b-', label='$f[x_{2},\\phi]$')\n", + "ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])\n", + "ax.set_xlabel('t'); ax.set_ylabel('f')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()\n", + "\n", + "# Plot residual\n", + "fig, ax = plt.subplots()\n", + "ax.plot(t_all,np.squeeze(f_minus_y_all[0,:]),'r-', label='$f[x_{0},\\phi]-y_{0}$')\n", + "ax.plot(t_all,np.squeeze(f_minus_y_all[1,:]),'g-', label='$f[x_{1},\\phi]-y_{1}$')\n", + "ax.plot(t_all,np.squeeze(f_minus_y_all[2,:]),'b-', label='$f[x_{2},\\phi]-y_{2}$')\n", + "ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])\n", + "ax.set_xlabel('t'); ax.set_ylabel('f-y')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()\n", + "\n", + "# Plot loss (sum of residuals)\n", + "fig, ax = plt.subplots()\n", + "square_error = 0.5 * np.sum(f_minus_y_all * f_minus_y_all, axis=0)\n", + "ax.plot(t_all, square_error,'k-')\n", + "ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.0,0.25])\n", + "ax.set_xlabel('t'); ax.set_ylabel('Loss')\n", + "plt.show()\n", + "\n", + "# Plot parameters\n", + "fig, ax = plt.subplots()\n", + "ax.plot(t_all, np.squeeze(phi_t_all[0,:]),'c-',label='$\\phi_{0}$')\n", + "ax.plot(t_all, np.squeeze(phi_t_all[1,:]),'m-',label='$\\phi_{1}$')\n", + "ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-1,1])\n", + "ax.set_xlabel('t'); ax.set_ylabel('$\\phi$')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ], + "metadata": { + "id": "G9IwgwKltHz5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Define the model and the loss function" + ], + "metadata": { + "id": "N6VaUq2swa8D" + } + }, + { + "cell_type": "code", + "source": [ + "# Model is just a straight line with intercept phi[0] and slope phi[1]\n", + "def model(phi,x):\n", + " y_pred = phi[0]+phi[1] * x\n", + " return y_pred\n", + "\n", + "# Loss function is 0.5 times sum of squares of residuals for training data\n", + "def compute_loss(data_x, data_y, model, phi):\n", + " pred_y = model(phi, data_x)\n", + " loss = 0.5 * np.sum((pred_y-data_y)*(pred_y-data_y))\n", + " return loss" + ], + "metadata": { + "id": "LGHEVUWWiB4f" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Draw the loss function" + ], + "metadata": { + "id": "hr3hs7pKwo0g" + } + }, + { + "cell_type": "code", + "source": [ + "def draw_loss_function(compute_loss, X, y, model, phi_iters):\n", + " # Define pretty colormap\n", + " my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n", + " my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n", + " r = np.floor(my_colormap_vals_dec/(256*256))\n", + " g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n", + " b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n", + " my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n", + "\n", + " # Make grid of intercept/slope values to plot\n", + " intercepts_mesh, slopes_mesh = np.meshgrid(np.arange(-1.0,1.0,0.005), np.arange(-1.0,1.0,0.005))\n", + " loss_mesh = np.zeros_like(slopes_mesh)\n", + " # Compute loss for every set of parameters\n", + " for idslope, slope in np.ndenumerate(slopes_mesh):\n", + " loss_mesh[idslope] = compute_loss(X, y, model, np.array([[intercepts_mesh[idslope]], [slope]]))\n", + "\n", + " fig,ax = plt.subplots()\n", + " fig.set_size_inches(8,8)\n", + " ax.contourf(intercepts_mesh,slopes_mesh,loss_mesh,256,cmap=my_colormap)\n", + " ax.contour(intercepts_mesh,slopes_mesh,loss_mesh,40,colors=['#80808080'])\n", + " ax.set_ylim([1,-1]); ax.set_xlim([-1,1])\n", + "\n", + " ax.plot(phi_iters[1,:], phi_iters[0,:],'g-')\n", + " ax.set_xlabel('Intercept'); ax.set_ylabel('Slope')\n", + " plt.show()" + ], + "metadata": { + "id": "UCxa3tZ8a9kz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "draw_loss_function(compute_loss, X[0:1,:], y.T, model, phi_t_all)" + ], + "metadata": { + "id": "pXLLBaSaiI2A" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Draw the evolution of the function" + ], + "metadata": { + "id": "ZsremHW-xFi5" + } + }, + { + "cell_type": "code", + "source": [ + "fig, ax = plt.subplots()\n", + "ax.plot(X[0:1,:],y.T,'ro')\n", + "x_vals = np.arange(0,1,0.001)\n", + "ax.plot(x_vals, phi_t_all[0,0]*x_vals + phi_t_all[1,0],'r-', label='t=0.00')\n", + "ax.plot(x_vals, phi_t_all[0,10]*x_vals + phi_t_all[1,10],'g-', label='t=0.10')\n", + "ax.plot(x_vals, phi_t_all[0,30]*x_vals + phi_t_all[1,30],'b-', label='t=0.30')\n", + "ax.plot(x_vals, phi_t_all[0,200]*x_vals + phi_t_all[1,200],'c-', label='t=2.00')\n", + "ax.plot(x_vals, phi_t_all[0,1999]*x_vals + phi_t_all[1,1999],'y-', label='t=20.0')\n", + "ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n", + "ax.set_xlabel('x'); ax.set_ylabel('y')\n", + "plt.legend(loc=\"upper left\")\n", + "plt.show()" + ], + "metadata": { + "id": "cv9ZrUoRkuhI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Compute MAP and ML solutions\n", + "MLParams = np.linalg.inv(X@X.T)@X@y\n", + "sigma_sq_p = 3.0\n", + "sigma_sq = 0.05\n", + "MAPParams = np.linalg.inv(X@X.T+np.identity(X.shape[0])*sigma_sq/sigma_sq_p)@X@y" + ], + "metadata": { + "id": "OU9oegSOof-o" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, we predict both the mean and the uncertainty in the fitted model as a function of time" + ], + "metadata": { + "id": "Ul__XvOgyYSA" + } + }, + { + "cell_type": "code", + "source": [ + "# Define x positions to make predictions (appending a 1 to each column)\n", + "x_predict = np.arange(0,1,0.01)[None,:]\n", + "x_predict = np.concatenate((x_predict,np.ones_like(x_predict)))\n", + "nX = x_predict.shape[1]\n", + "\n", + "# Create variables to store evolution of mean and variance of prediction over time\n", + "predict_mean_all = np.zeros((nT,nX))\n", + "predict_var_all = np.zeros((nT,nX))\n", + "\n", + "# Initial covariance\n", + "sigma_sq_p = 2.0\n", + "cov_init = sigma_sq_p * np.identity(2)\n", + "\n", + "# Run through each time computing a and b and hence mean and variance of prediction\n", + "for t in range(len(t_all)):\n", + " a = x_predict.T @(XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ y)\n", + " b = x_predict.T -x_predict.T@XXTInvX @ (np.identity(3)-expm(-X.T@X * t_all[t])) @ X.T\n", + " predict_mean_all[t:t+1,:] = a.T\n", + " predict_cov = b@ cov_init @b.T\n", + " # We just want the diagonal of the covariance to plot the uncertainty\n", + " predict_var_all[t:t+1,:] = np.reshape(np.diag(predict_cov),(1,nX))" + ], + "metadata": { + "id": "aMPADCuByKWr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Plot the mean and variance at various times" + ], + "metadata": { + "id": "PZTj93KK7QH6" + } + }, + { + "cell_type": "code", + "source": [ + "def plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t, sigma_sq = 0.00001):\n", + " fig, ax = plt.subplots()\n", + " ax.plot(X[0:1,:],y.T,'ro')\n", + " ax.plot(x_predict[0:1,:].T, predict_mean_all[this_t:this_t+1,:].T,'r-')\n", + " lower = np.squeeze(predict_mean_all[this_t:this_t+1,:].T-np.sqrt(predict_var_all[this_t:this_t+1,:].T+np.sqrt(sigma_sq)))\n", + " upper = np.squeeze(predict_mean_all[this_t:this_t+1,:].T+np.sqrt(predict_var_all[this_t:this_t+1,:].T+np.sqrt(sigma_sq)))\n", + " ax.fill_between(np.squeeze(x_predict[0:1,:]), lower, upper, color='lightgray')\n", + " ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])\n", + " ax.set_xlabel('x'); ax.set_ylabel('y')\n", + " plt.show()\n", + "\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=0)\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=40)\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=80)\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=200)\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=500)\n", + "plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=1000)" + ], + "metadata": { + "id": "bYAFxgB880-v" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file