{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyP9fLqBQPgcYJB1KXs3Scp/", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# Gradient flow\n", "\n", "This notebook replicates some of the results in the Borealis AI [blog](https://www.borealisai.com/research-blogs/gradient-flow/) on gradient flow. \n" ], "metadata": { "id": "ucrRRJ4dq8_d" } }, { "cell_type": "code", "source": [ "# Import relevant libraries\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from scipy.linalg import expm\n", "from matplotlib import cm\n", "from matplotlib.colors import ListedColormap" ], "metadata": { "id": "_IQFHZEMZE8T" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Create the three data points that are used to train the linear model in the blog. Each input point is a column in $\\mathbf{X}$ and consists of the $x$ position in the plot and the value 1, which is used to allow the model to fit bias terms neatly." 
# %% Training data
# Each column of X is one training input: the x position stacked on a constant 1,
# so that the second parameter plays the role of the bias term.
X = np.array([[0.2, 0.4, 0.8],[1,1,1]])
y = np.array([[-0.1],[0.15],[0.3]])
D = X.shape[0]   # number of rows of X (one per parameter)
I = X.shape[1]   # number of columns of X (one per training point)

print("X=\n",X)
print("y=\n",y)

# %% Draw the three data points
fig, ax = plt.subplots()
ax.plot(X[0:1,:],y.T,'ro')
ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])
ax.set_xlabel('x'); ax.set_ylabel('y')
plt.show()

# %% Evolution of the function values, residuals and parameters over time
# Discretized time to evaluate quantities at
t_all = np.arange(0,20,0.01)
nT = t_all.shape[0]

# Initial parameters, and initial function output at the training points
phi_0 = np.array([[-0.05],[-0.4]])
f_0 = X.T @ phi_0

# Pseudoinverse-style term (X X^T)^{-1} X, obtained by solving the linear system
# rather than forming the explicit inverse.
XXTInvX = np.linalg.solve(X @ X.T, X)

# Arrays holding, for every time step (one column each): the function at the data
# points, the residual f - y, and the parameters.
f_all = np.zeros((I,nT))
f_minus_y_all = np.zeros((I,nT))
phi_t_all = np.zeros((D,nT))

# Loop-invariant pieces of the closed-form expressions
gram = X.T @ X              # I x I matrix that appears inside the matrix exponential
residual_0 = f_0 - y        # residual at t = 0
eye_I = np.identity(I)      # sized from the data instead of a hard-coded 3

for t_idx, t_val in enumerate(t_all):
    # exp(-X^T X t): computed once per step and shared by both expressions below
    decay = expm(-gram * t_val)
    f = y + decay @ residual_0
    f_all[:,t_idx:t_idx+1] = f
    f_minus_y_all[:,t_idx:t_idx+1] = f - y
    phi_t_all[:,t_idx:t_idx+1] = phi_0 - XXTInvX @ (eye_I - decay) @ residual_0
# %% Plot the results that were calculated in the previous cell
# One line style per training point (rows 0..2 of f_all / f_minus_y_all).
point_styles = ('r-', 'g-', 'b-')

# Plot function at data points
fig, ax = plt.subplots()
for idx, fmt in enumerate(point_styles):
    # Raw strings: '\p' is not a valid escape sequence in a normal string literal
    ax.plot(t_all, f_all[idx,:], fmt, label=r'$f[x_{%d},\phi]$' % idx)
ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])
ax.set_xlabel('t'); ax.set_ylabel('f')
ax.legend(loc="lower right")
plt.show()

# Plot residual
fig, ax = plt.subplots()
for idx, fmt in enumerate(point_styles):
    ax.plot(t_all, f_minus_y_all[idx,:], fmt, label=r'$f[x_{%d},\phi]-y_{%d}$' % (idx, idx))
ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-0.5,0.5])
ax.set_xlabel('t'); ax.set_ylabel('f-y')
ax.legend(loc="lower right")
plt.show()

# Plot loss (half the sum of squared residuals at each time)
fig, ax = plt.subplots()
square_error = 0.5 * np.sum(f_minus_y_all * f_minus_y_all, axis=0)
ax.plot(t_all, square_error,'k-')
ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([0.0,0.25])
ax.set_xlabel('t'); ax.set_ylabel('Loss')
plt.show()

# Plot parameters
fig, ax = plt.subplots()
for idx, fmt in enumerate(('c-', 'm-')):
    ax.plot(t_all, phi_t_all[idx,:], fmt, label=r'$\phi_{%d}$' % idx)
ax.set_xlim([0,np.max(t_all)]); ax.set_ylim([-1,1])
ax.set_xlabel('t'); ax.set_ylabel(r'$\phi$')
ax.legend(loc="lower right")
plt.show()

# %% Define the model and the loss function
def model(phi,x):
    """Straight line with intercept phi[0] and slope phi[1], evaluated at x."""
    return phi[0] + phi[1] * x

def compute_loss(data_x, data_y, model, phi):
    """Return 0.5 times the sum of squared residuals of model(phi, data_x) against data_y."""
    residual = model(phi, data_x) - data_y
    return 0.5 * np.sum(residual * residual)

# %% Draw the loss function
def draw_loss_function(compute_loss, X, y, model, phi_iters):
    """Draw the loss surface over (intercept, slope) and overlay a parameter trajectory.

    compute_loss -- callable invoked as compute_loss(X, y, model, phi), where phi is a
                    2x1 array [[intercept], [slope]]
    X, y         -- passed through unchanged to compute_loss
    model        -- passed through unchanged to compute_loss
    phi_iters    -- 2-row array; row 1 is drawn along the 'Intercept' axis and
                    row 0 along the 'Slope' axis
    """
    # Colormap given as 256 'rrggbb' hex strings
    colormap_hex = (
        '2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09',
        '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f',
        '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814',
        '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18',
        '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d',
        '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821',
        '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26',
        '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a',
        '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f',
        '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134',
        '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a',
        '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f',
        '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545',
        '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b',
        'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451',
        'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57',
        'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d',
        'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64',
        'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b',
        'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72',
        'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a',
        'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81',
        'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589',
        'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91',
        'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 'eeb598', 'eeb699', 'efb79a',
        'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3',
        'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac',
        'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6',
        'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf',
        'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca',
        'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5',
        'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')
    # Split each 24-bit colour into its red / green / blue bytes and scale to [0, 1]
    colormap_int = np.array([int(element,base=16) for element in colormap_hex])
    r = colormap_int >> 16
    g = (colormap_int >> 8) & 0xFF
    b = colormap_int & 0xFF
    my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)

    # Make grid of intercept/slope values to plot
    grid_axis = np.arange(-1.0,1.0,0.005)
    intercepts_mesh, slopes_mesh = np.meshgrid(grid_axis, grid_axis)
    loss_mesh = np.zeros_like(slopes_mesh)
    # Compute loss for every (intercept, slope) pair on the grid
    for idslope, slope in np.ndenumerate(slopes_mesh):
        phi = np.array([[intercepts_mesh[idslope]], [slope]])
        loss_mesh[idslope] = compute_loss(X, y, model, phi)

    fig,ax = plt.subplots()
    fig.set_size_inches(8,8)
    ax.contourf(intercepts_mesh,slopes_mesh,loss_mesh,256,cmap=my_colormap)
    ax.contour(intercepts_mesh,slopes_mesh,loss_mesh,40,colors=['#80808080'])
    # The slope axis is deliberately drawn from 1 down to -1
    ax.set_ylim([1,-1]); ax.set_xlim([-1,1])

    # Trajectory: row 1 of phi_iters on the intercept axis, row 0 on the slope axis
    ax.plot(phi_iters[1,:], phi_iters[0,:],'g-')
    ax.set_xlabel('Intercept'); ax.set_ylabel('Slope')
    plt.show()

# %%
draw_loss_function(compute_loss, X[0:1,:], y.T, model, phi_t_all)

# %% Draw the evolution of the function
fig, ax = plt.subplots()
ax.plot(X[0:1,:],y.T,'ro')
x_vals = np.arange(0,1,0.001)
# Here row 0 of phi_t_all multiplies x and row 1 is the constant offset.
# Legend labels are read from t_all so they always match the time actually plotted;
# index -1 is the last computed time step.
for idx, fmt in ((0,'r-'), (10,'g-'), (30,'b-'), (200,'c-'), (-1,'y-')):
    ax.plot(x_vals, phi_t_all[0,idx]*x_vals + phi_t_all[1,idx], fmt, label='t=%.2f' % t_all[idx])
ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])
ax.set_xlabel('x'); ax.set_ylabel('y')
ax.legend(loc="upper left")
plt.show()

# %% Compute MAP and ML solutions
# Both are linear systems, solved directly instead of via an explicit inverse.
MLParams = np.linalg.solve(X @ X.T, X @ y)
sigma_sq_p = 3.0
sigma_sq = 0.05
MAPParams = np.linalg.solve(X @ X.T + np.identity(X.shape[0]) * sigma_sq / sigma_sq_p, X @ y)
# %% Predict both the mean and the uncertainty of the fitted model as a function of time
# Define x positions to make predictions (appending a 1 to each column)
x_predict = np.arange(0,1,0.01)[None,:]
x_predict = np.concatenate((x_predict,np.ones_like(x_predict)))
nX = x_predict.shape[1]

# Create variables to store evolution of mean and variance of prediction over time
predict_mean_all = np.zeros((nT,nX))
predict_var_all = np.zeros((nT,nX))

# Initial covariance: sigma_sq_p times an identity sized by the number of rows of X
sigma_sq_p = 2.0
cov_init = sigma_sq_p * np.identity(X.shape[0])

# Loop-invariant pieces
gram = X.T @ X
eye_I = np.identity(X.shape[1])   # sized from the data instead of a hard-coded 3

# Run through each time computing a and b and hence mean and variance of prediction
for t_idx, t_val in enumerate(t_all):
    # x*^T (X X^T)^{-1} X (I - exp(-X^T X t)) is shared by a and b, so it is formed once
    proj = x_predict.T @ XXTInvX @ (eye_I - expm(-gram * t_val))
    a = proj @ y
    b = x_predict.T - proj @ X.T
    predict_mean_all[t_idx,:] = a[:,0]
    # Only the diagonal of b @ cov_init @ b.T is needed to plot the uncertainty, so it
    # is computed row by row without building the full nX x nX matrix.
    predict_var_all[t_idx,:] = np.sum((b @ cov_init) * b, axis=1)

# %% Plot the mean and variance at various times
def plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t, sigma_sq = 0.00001):
    """Plot the data, the predicted mean at time index this_t, and a one-standard-deviation band.

    X, y             -- training data; row 0 of X is plotted against y.T
    x_predict        -- prediction inputs; row 0 holds the x positions
    predict_mean_all -- array indexed [time, position] of predicted means
    predict_var_all  -- array indexed [time, position] of predicted variances
    this_t           -- row (time index) of the two arrays above to draw
    sigma_sq         -- variance added to the predicted variance before taking the
                        square root
    """
    x_pos = x_predict[0,:]
    mean = predict_mean_all[this_t,:]
    # sigma_sq is already a variance, so it is added directly to the predicted
    # variance; the square root is taken once, on the sum.
    std = np.sqrt(predict_var_all[this_t,:] + sigma_sq)

    fig, ax = plt.subplots()
    ax.plot(X[0:1,:],y.T,'ro')
    ax.plot(x_pos, mean,'r-')
    ax.fill_between(x_pos, mean - std, mean + std, color='lightgray')
    ax.set_xlim([0,1]); ax.set_ylim([-0.5,0.5])
    ax.set_xlabel('x'); ax.set_ylabel('y')
    plt.show()

for this_t in (0, 40, 80, 200, 500, 1000):
    plot_mean_var(X,y,x_predict, predict_mean_all, predict_var_all, this_t=this_t)