326 lines
27 KiB
Plaintext
326 lines
27 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": [],
|
|
"authorship_tag": "ABX9TyMD3zdteWU9gy7nYCZvDeGT",
|
|
"include_colab_link": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Trees/LinearRegression_FitModel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"<img src=\"data:image/svg+xml;base64,<?xml version="1.0" encoding="UTF-8"?>
<svg width="764.45" height="63.759" version="1.1" viewBox="0 0 764.45 63.759" xmlns="http://www.w3.org/2000/svg">
 <g transform="matrix(.73548 0 0 .73548 0 3.388)" stroke-width="1.3597">
  <rect x="5e-7" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="10.330567" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="10.330567" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">I </tspan></text>
  <rect x="96" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="106.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="106.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">C </tspan></text>
  <rect x="180" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="190.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="190.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">L </tspan></text>
  <rect x="264" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="274.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="274.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">I </tspan></text>
  <rect x="348" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="358.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="358.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">M </tspan></text>
  <rect x="432" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="442.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="442.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">B </tspan></text>
  <g transform="translate(2 1.5376)">
   <rect x="624" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="634.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="634.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">T </tspan></text>
   <rect x="708" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="718.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="718.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">R </tspan></text>
   <rect x="792" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="802.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="802.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">E</tspan></text>
   <rect x="876" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="886.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="886.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">E </tspan></text>
   <rect x="960" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="970.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="970.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">S </tspan></text>
  </g>
  <g transform="matrix(1.0499 0 0 1.0499 -28.092 -.27293)" fill="#4469d8" stroke="#fdffff">
   <rect x="528" y="-1.5376" width="77.387" height="77.387" ry="11.35" opacity=".98" stroke-width="5.18"/>
   <g transform="matrix(.74592 0 0 .74367 530.84 1.6744)" stroke-width="5.2162" featureKey="inlineSymbolFeature-0">
    <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m47.659 81.427c0.358-7.981 1.333-15.917 1.152-23.917-0.01-0.425-0.544-0.843-0.94-0.54-2.356 1.801-4.811 3.219-7.664 4.104-3.649 1.132-7.703-2.328-5.814-5.981 0.758-1.466 2.146-2.708 3.447-3.672 0.467-0.346 0.358-1.176-0.315-1.165-3.154 0.054-10.835 1.149-10.042-4.386 0.481-3.365 6.29-5.458 8.917-6.84 0.333-0.175 0.435-0.73 0.127-0.981-6.663-5.431-3.069-14.647 5.731-12.788 0.272 0.058 0.563-0.033 0.706-0.287 2.235-3.995 4.276-8.063 7.106-11.688-0.356-0.147-0.712-0.294-1.067-0.442 0.294 3.116 2.036 5.269 4.337 7.272 2.459 2.142 7.634 4.27 8.085 7.845 0.481 3.821-6.549 4.356-6.054 7.588 0.33 2.147 1.354 3.423 3.021 4.74 1.052 0.831 1.968 1.405 3.017 2.329 1.818 2.036 1.596 4.223-0.667 6.561-1.486 0.252-2.927 0.138-4.32-0.341-0.556-0.144-0.945 0.435-0.706 0.918 1.412 2.842 3.23 5.449 3.529 8.707 0.821 8.969-7.237 1.748-8.13 0.875-0.813-0.793-1.6-1.561-2.486-2.27-0.623-0.498-1.514 0.38-0.885 0.884 3.399 2.717 6.507 7.782 11.132 4.42 4.323-3.142-0.524-10.114-2.08-13.246-0.235 0.306-0.471 0.612-0.706 0.918 3.9 1.01 8.231 0.447 7.941-4.452-0.117-1.973-1.259-3.644-2.8-4.778-1.468-1.081-6.729-4.234-3.68-6.41 1.261-0.899 2.453-1.826 3.548-2.929 2.294-2.311 1.726-4.94-0.326-7.105-3.535-3.732-9.97-5.682-10.521-11.525-0.044-0.47-0.692-0.921-1.067-0.442-1.267 1.622-6.265 11.724-7.841 11.391-2.234-0.472-4.485 0.06-6.418 1.186-4.105 2.391-3.919 7.903-1.738 11.448 0.122 0.199 1.517 2.084 1.782 1.944-1.682 0.885-3.351 1.737-4.951 2.768-1.664 1.072-4.177 3.262-3.904 5.54 0.671 5.619 7.144 4.902 11.409 4.829-0.105-0.388-0.21-0.776-0.315-1.165-3.56 2.636-8.58 11.381-0.562 12.174 2.34 0.231 4.247-0.259 6.423-1.142 0.883-0.358 1.698-0.845 2.525-1.311 0.775-0.437 1.976-2.122 2.008-0.692 0.166 7.357-0.865 14.714-1.194 22.056-0.036 0.804 1.214 0.801 1.25-2e-3z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m22.301 83.156c-0.441-6.032-1.072-12.618 0.266-18.564 0.138-0.613-0.578-1.042-1.045-0.608-1.743 1.625-3.443 2.831-5.732 3.604-6.34-3.393-7.913-6.373-4.717-8.939 0.988-0.856 2.034-1.633 3.139-2.329 0.287-0.191 0.397-0.544 0.225-0.855-0.658-1.178-1.392-2.163-2.251-3.191-4.397-5.264-0.382-9.414 4.759-10.875 0.271-0.077 0.455-0.322 0.459-0.603 0.036-2.864 0.313-5.642 1.094-8.407 1.865-6.606 10.255-9.181 13.143-1.487 0.28 0.748 1.489 0.424 1.205-0.332-2.517-6.706-9.574-7.649-13.918-2.003-2.305 2.996-2.61 7.466-2.759 11.084-0.035 0.85-3.839 2.269-4.496 2.694-1.034 0.669-2.219 2.098-2.45 3.312-0.808 4.233 1.103 6.056 3.512 9.323 0.405 0.548-5.327 5.252-5.317 7.279 0.016 3.468 2.455 5.64 5.605 6.645 3.404 1.086 7.127-1.932 9.386-4.037-0.349-0.203-0.697-0.405-1.045-0.608-1.368 6.079-0.762 12.734-0.311 18.896 0.056 0.8 1.306 0.806 1.248 1e-3z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m21.424 64.741c1.983 2.707 4.981 4.199 8.349 3.637 3.594-0.6 5.191-4.13 5.291-7.411 0.024-0.807-1.226-0.804-1.25 0-0.202 6.67-7.523 8.313-11.31 3.143-0.472-0.643-1.557-0.02-1.08 0.631z" fill="#4469d8" stroke="#fdffff" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m74.661 80.878c2.869-5.406 3.251-12.191 2.679-18.182-0.036-0.381-0.375-0.742-0.791-0.603-1.482 0.496-9.677 1.84-5.634-4.557 0.251-0.397-0.075-0.952-0.54-0.94-4.913 0.123-9.233-0.937-9.57-6.683-0.047-0.801-1.297-0.806-1.25 0 0.201 3.426 1.375 5.828 4.622 7.214 1.514 0.646 3.278 0.7 4.894 0.751-0.658-0.021-0.338 3.074-0.216 3.489 0.625 2.13 4.101 2.773 5.896 2.466 2.606-0.446 1.551 3.288 1.477 5.177-0.15 3.833-0.832 7.82-2.646 11.236-0.378 0.713 0.701 1.345 1.079 0.632z" fill="#4469d8" stroke="#fdffff" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m76.881 63.299c3.341-0.618 7.425-1.372 7.423-5.67 0-1.473-0.141-3.462-1.403-4.486 0.524 0.425 2.703-1.287 3.381-1.885 5.097-4.499 1.607-12.585-4.301-13.85-0.222-0.047 2.216-4.5 2.515-5.157 0.832-1.834 0.614-3.634-8e-3 -5.472-1.133-3.347-6.327-9.06-10.153-9.283-1.411-0.082-2.449-0.077-3.515 0.881-1.212 1.09 0.842 3.98-1.963 2.484-4.82-2.573-5.125 2.25-7.856 4.852-0.584 0.557 0.301 1.439 0.885 0.884 1.199-1.143 0.961-0.736 1.574-2.026 2.202-4.641 4.768-2.589 7.178-1.388 0.334 0.167 0.839 0.047 0.918-0.374 0.208-1.098 0.205-1.025 0.186-2.169 2.787-1.84 5.084-1.596 6.891 0.731 0.745 0.715 1.449 1.469 2.113 2.261 4.874 5.507 2.097 8.833-0.535 13.968-0.228 0.445 0.06 0.897 0.54 0.94 8.368 0.749 8.684 11.983 0.698 13.757-0.432 0.096-0.64 0.75-0.276 1.044 4.99 4.046-0.386 7.969-4.622 8.753-0.794 0.147-0.458 1.351 0.33 1.205z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
    </g>
   </g>
  </g>
 </g>
</svg>
\">"
|
|
],
|
|
"metadata": {
|
|
"id": "9Qgup23vwdll"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Fitting 1D regression model\n",
|
|
"\n",
|
|
"The purpose of this Python notebook is to experiment with fitting the 1D regression model with a least squares loss using coordinate descent.\n",
|
|
"\n",
|
|
"Work through the cells below, running each cell in turn. In various places you will see the words \"TODO\". Follow the instructions at these places and write code to complete the functions.\n",
|
|
"\n",
|
|
"You can save a local copy of this notebook in your Google account and work through it in Colab (recommended) or you can download the notebook and run it locally using Jupyter notebook or similar. If you are using Colab, we recommend that you *turn off* AI autocomplete (under the cog icon in the top-right corner), which would otherwise give you the answers and defeat the purpose of the exercise.\n",
|
|
"\n",
|
|
"A fully working version of this notebook with the complete answers can be found [here](https://colab.research.google.com/github/udlbook/udlbook/blob/main/Trees/LinearRegression_LossFunction_Answers.ipynb).\n",
|
|
"\n",
|
|
"Contact me at iclimbtreesmail@gmail.com if you find any mistakes or have any suggestions."
|
|
],
|
|
"metadata": {
|
|
"id": "uORlKyPv02ge"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "bbF6SE_F0tU8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Math library\n",
|
|
"import numpy as np\n",
|
|
"# Plotting library\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from matplotlib import cm\n",
|
|
"from matplotlib.colors import ListedColormap\n",
|
|
"# Time library\n",
|
|
"import time\n",
|
|
"# Used to update figures\n",
|
|
"from IPython import display"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Create the same input / output data as used in the unit\n",
|
|
"x = np.array([0.03, 0.19, 0.34, 0.46, 0.78, 0.81, 1.08, 1.18, 1.39, 1.60, 1.65, 1.90])\n",
|
|
"y = np.array([0.67, 0.85, 1.05, 1.0, 1.40, 1.5, 1.3, 1.54, 1.55, 1.68, 1.73, 1.6 ])"
|
|
],
|
|
"metadata": {
|
|
"id": "9fGAobBnyI7Z"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Define the model and the least squares loss\n",
|
|
"\n",
|
|
"The linear regression model is defined as:\n",
|
|
"\n",
|
|
"$$\\textrm{f}[x,\\boldsymbol\\phi] = \\phi_0+\\phi_1 x$$\n",
|
|
"\n",
|
|
"where $\\phi_0$ is the y-intercept and $\\phi_1$ is the slope."
|
|
],
|
|
"metadata": {
|
|
"id": "FylovB6YyhWA"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"def f(x, phi0, phi1):\n",
|
|
" return phi0 + phi1 * x"
|
|
],
|
|
"metadata": {
|
|
"id": "fpgM_LstyLwt"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"\n",
|
|
"The least squares loss is defined as the sum of the squared deviations between the model output $\\textrm{f}[x_i,\\boldsymbol\\phi]$ and the true output target $y_i$:\n",
|
|
"\n",
|
|
" \\begin{align} L[\\boldsymbol\\phi] & = \\sum_{i=1}^{I} \\bigl(\\textrm{f}[x_{i}, \\boldsymbol\\phi]-y_{i}\\bigr)^{2} \\\\ &= \\sum_{i=1}^{I} \\bigl(\\phi_{0}+\\phi_{1}x_i-y_{i}\\bigr)^{2} \\tag{1.2}\\end{align}"
|
|
],
|
|
"metadata": {
|
|
"id": "cG5kwmmPybZK"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Function to calculate the loss\n",
|
|
"def compute_loss(x,y,f,phi0,phi1):\n",
|
|
"\n",
|
|
" signed_distance = f(x,phi0,phi1)-y\n",
|
|
" loss = np.sum(signed_distance * signed_distance)\n",
|
|
"\n",
|
|
" return loss"
|
|
],
|
|
"metadata": {
|
|
"id": "I1vBlFMAyfzp"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Fit the model\n",
|
|
"\n",
|
|
"We'll fit the model using a version of coordinate descent. We first choose a step size $\\alpha$ and then we alternate between updating the intercept parameter $\\phi_0$ and the slope parameter $\\phi_1$. \n",
|
|
"\n",
|
|
"1. Compare the loss for models with $[\\phi_0, \\phi_1]$, $[\\phi_0+\\alpha, \\phi_1]$, and $[\\phi_0-\\alpha, \\phi_1]$. Update the parameters according to the set that has the minimum loss.\n",
|
|
"\n",
|
|
"2. Compare the loss for models with $[\\phi_0, \\phi_1]$, $[\\phi_0,\\phi_1+\\alpha]$, and $[\\phi_0, \\phi_1-\\alpha]$. Update the parameters according to the set that has the minimum loss.\n",
|
|
"\n",
|
|
"We'll alternate these two steps until we cannot improve any further."
|
|
],
|
|
"metadata": {
|
|
"id": "4BrOiVY0zTY4"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Utility function for plotting the three models at each stage\n",
|
|
"def plot(fig,ax, x,y, f, phi0_1, phi1_1, phi0_2, phi1_2, phi0_3, phi1_3, loss1, loss2, loss3):\n",
|
|
" x_plot = np.linspace(0,2,100)\n",
|
|
"\n",
|
|
" # Clear previous drawing on these axes\n",
|
|
" ax.clear()\n",
|
|
" # Plotting code\n",
|
|
" ax.plot(x,y,'bo')\n",
|
|
" ax.plot(x_plot,f(x_plot, phi0_1, phi1_1), 'r-')\n",
|
|
" ax.plot(x_plot,f(x_plot, phi0_2, phi1_2), 'g-')\n",
|
|
" ax.plot(x_plot,f(x_plot, phi0_3, phi1_3), 'b-')\n",
|
|
" ax.set_xlim(0,2)\n",
|
|
" ax.set_ylim(0,2)\n",
|
|
" ax.set_title('Losses: {:.2f}(red), {:.2f} (green), {:.2f} (blue)'.format(loss1, loss2, loss3))\n",
|
|
" ax.set_aspect('equal', adjustable='box')\n",
|
|
"\n",
|
|
" # Show the figure and wait 0.1 sec\n",
|
|
" display.display(fig)\n",
|
|
" time.sleep(0.1)\n",
|
|
" display.clear_output(wait=True)\n"
|
|
],
|
|
"metadata": {
|
|
"id": "UbhOL6ob6m6Y"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Main fitting algorithm\n",
|
|
"def fit_model(x,y,f,compute_loss,phi0_init, phi1_init, alpha, n_iter):\n",
|
|
"\n",
|
|
" # Create figure to display results\n",
|
|
" fig,ax = plt.subplots()\n",
|
|
"\n",
|
|
" # These two variables store the evolution of the parameters\n",
|
|
" phi0_progress = np.zeros(n_iter)\n",
|
|
" phi1_progress = np.zeros(n_iter)\n",
|
|
"\n",
|
|
" # Initialize the history with the provided values\n",
|
|
" phi0_progress[0] = phi0_init\n",
|
|
" phi1_progress[0] = phi1_init\n",
|
|
"\n",
|
|
" # Main iteration loop\n",
|
|
" for c_iter in range(1, n_iter):\n",
|
|
" # TODO Choose parameters for first model [phi0, phi1]\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_1 = 0\n",
|
|
" phi1_1 = 0\n",
|
|
"\n",
|
|
" # Change the intercept phi0 every other iteration\n",
|
|
" if (c_iter%2==0):\n",
|
|
" # TODO Choose parameters for second model [phi_0+alpha, phi1]\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_2 = 0\n",
|
|
" phi1_2 = 0\n",
|
|
"\n",
|
|
" # TODO Choose parameters for third model [phi_0-alpha, phi1]\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_3 = 0\n",
|
|
" phi1_3 = 0\n",
|
|
"\n",
|
|
" # Change the slope phi1 every other iteration\n",
|
|
" else:\n",
|
|
" # TODO Choose parameters for second model [phi_0, phi1+alpha]\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_2 = 0\n",
|
|
" phi1_2 = 0\n",
|
|
"\n",
|
|
" # TODO Choose parameters for third model [phi_0, phi1-alpha]\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_3 = 0\n",
|
|
" phi1_3 = 0\n",
|
|
"\n",
|
|
" # TODO Compute the loss for the three models\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" loss1 = 0\n",
|
|
" loss2 = 0\n",
|
|
" loss3 = 0\n",
|
|
"\n",
|
|
" # TODO Set the parameters to whichever model has the lowest loss\n",
|
|
" # REPLACE THIS CODE\n",
|
|
" phi0_progress[c_iter] = 0\n",
|
|
" phi1_progress[c_iter] = 0\n",
|
|
"\n",
|
|
" # Plot the progress\n",
|
|
" plot(fig, ax, x,y, f, phi0_1, phi1_1, phi0_2, phi1_2, phi0_3, phi1_3, loss1, loss2, loss3)\n",
|
|
"\n",
|
|
" return phi0_progress, phi1_progress"
|
|
],
|
|
"metadata": {
|
|
"id": "VaonEi8gzf3z"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Run the fitting algorithm\n",
|
|
"phi0_progress, phi1_progress = fit_model(x,y,f,compute_loss,phi0_init=1.35, phi1_init=-0.55, alpha=0.125, n_iter=20)"
|
|
],
|
|
"metadata": {
|
|
"id": "STyOoYYv9Ddz"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Helper code to do the drawing\n",
|
|
"def draw_loss_function(compute_loss, f, x_in, y_in, phi0_progress, phi1_progress):\n",
|
|
" # Define pretty colormap\n",
|
|
" my_colormap_vals_hex =('2a0902', '2b0a03', '2c0b04', '2d0c05', '2e0c06', '2f0d07', '300d08', '310e09', '320f0a', '330f0b', '34100b', '35110c', '36110d', '37120e', '38120f', '39130f', '3a1410', '3b1411', '3c1511', '3d1612', '3e1613', '3f1713', '401714', '411814', '421915', '431915', '451a16', '461b16', '471b17', '481c17', '491d18', '4a1d18', '4b1e19', '4c1f19', '4d1f1a', '4e201b', '50211b', '51211c', '52221c', '53231d', '54231d', '55241e', '56251e', '57261f', '58261f', '592720', '5b2821', '5c2821', '5d2922', '5e2a22', '5f2b23', '602b23', '612c24', '622d25', '632e25', '652e26', '662f26', '673027', '683027', '693128', '6a3229', '6b3329', '6c342a', '6d342a', '6f352b', '70362c', '71372c', '72372d', '73382e', '74392e', '753a2f', '763a2f', '773b30', '783c31', '7a3d31', '7b3e32', '7c3e33', '7d3f33', '7e4034', '7f4134', '804235', '814236', '824336', '834437', '854538', '864638', '874739', '88473a', '89483a', '8a493b', '8b4a3c', '8c4b3c', '8d4c3d', '8e4c3e', '8f4d3f', '904e3f', '924f40', '935041', '945141', '955242', '965343', '975343', '985444', '995545', '9a5646', '9b5746', '9c5847', '9d5948', '9e5a49', '9f5a49', 'a05b4a', 'a15c4b', 'a35d4b', 'a45e4c', 'a55f4d', 'a6604e', 'a7614e', 'a8624f', 'a96350', 'aa6451', 'ab6552', 'ac6552', 'ad6653', 'ae6754', 'af6855', 'b06955', 'b16a56', 'b26b57', 'b36c58', 'b46d59', 'b56e59', 'b66f5a', 'b7705b', 'b8715c', 'b9725d', 'ba735d', 'bb745e', 'bc755f', 'bd7660', 'be7761', 'bf7862', 'c07962', 'c17a63', 'c27b64', 'c27c65', 'c37d66', 'c47e67', 'c57f68', 'c68068', 'c78169', 'c8826a', 'c9836b', 'ca846c', 'cb856d', 'cc866e', 'cd876f', 'ce886f', 'ce8970', 'cf8a71', 'd08b72', 'd18c73', 'd28d74', 'd38e75', 'd48f76', 'd59077', 'd59178', 'd69279', 'd7937a', 'd8957b', 'd9967b', 'da977c', 'da987d', 'db997e', 'dc9a7f', 'dd9b80', 'de9c81', 'de9d82', 'df9e83', 'e09f84', 'e1a185', 'e2a286', 'e2a387', 'e3a488', 'e4a589', 'e5a68a', 'e5a78b', 'e6a88c', 'e7aa8d', 'e7ab8e', 'e8ac8f', 'e9ad90', 'eaae91', 'eaaf92', 'ebb093', 'ecb295', 'ecb396', 'edb497', 
'eeb598', 'eeb699', 'efb79a', 'efb99b', 'f0ba9c', 'f1bb9d', 'f1bc9e', 'f2bd9f', 'f2bfa1', 'f3c0a2', 'f3c1a3', 'f4c2a4', 'f5c3a5', 'f5c5a6', 'f6c6a7', 'f6c7a8', 'f7c8aa', 'f7c9ab', 'f8cbac', 'f8ccad', 'f8cdae', 'f9ceb0', 'f9d0b1', 'fad1b2', 'fad2b3', 'fbd3b4', 'fbd5b6', 'fbd6b7', 'fcd7b8', 'fcd8b9', 'fcdaba', 'fddbbc', 'fddcbd', 'fddebe', 'fddfbf', 'fee0c1', 'fee1c2', 'fee3c3', 'fee4c5', 'ffe5c6', 'ffe7c7', 'ffe8c9', 'ffe9ca', 'ffebcb', 'ffeccd', 'ffedce', 'ffefcf', 'fff0d1', 'fff2d2', 'fff3d3', 'fff4d5', 'fff6d6', 'fff7d8', 'fff8d9', 'fffada', 'fffbdc', 'fffcdd', 'fffedf', 'ffffe0')\n",
|
|
" my_colormap_vals_dec = np.array([int(element,base=16) for element in my_colormap_vals_hex])\n",
|
|
" r = np.floor(my_colormap_vals_dec/(256*256))\n",
|
|
" g = np.floor((my_colormap_vals_dec - r *256 *256)/256)\n",
|
|
" b = np.floor(my_colormap_vals_dec - r * 256 *256 - g * 256)\n",
|
|
" my_colormap = ListedColormap(np.vstack((r,g,b)).transpose()/255.0)\n",
|
|
"\n",
|
|
" # Make grid of intercept/slope values to plot\n",
|
|
" intercepts_mesh, slopes_mesh = np.meshgrid(np.arange(0.0,2.0,0.02), np.arange(-1.0,1.0,0.002))\n",
|
|
" loss_mesh = np.zeros_like(slopes_mesh)\n",
|
|
"\n",
|
|
" # Compute loss for every set of parameters\n",
|
|
" for idslope, slope in np.ndenumerate(slopes_mesh):\n",
|
|
" loss_mesh[idslope] = compute_loss(x_in, y_in, f, intercepts_mesh[idslope], slope)\n",
|
|
"\n",
|
|
" fig,ax = plt.subplots()\n",
|
|
" fig.set_size_inches(6,6)\n",
|
|
" ax.contourf(intercepts_mesh,slopes_mesh,loss_mesh,256,cmap=my_colormap)\n",
|
|
" ax.contour(intercepts_mesh,slopes_mesh,loss_mesh,40,colors=['#80808080'])\n",
|
|
" ax.plot(phi0_progress, phi1_progress, 'o-', color='#7fe7dc')\n",
|
|
"\n",
|
|
" ax.set_ylim([1,-1])\n",
|
|
" ax.set_xlabel('Intercept, $\\phi_0$')\n",
|
|
" ax.set_ylabel('Slope, $\\phi_1$')\n",
|
|
" ax.set_aspect('equal')\n",
|
|
" plt.show()"
|
|
],
|
|
"metadata": {
|
|
"id": "Q8DEVnj992AW"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"draw_loss_function(compute_loss, f, x, y, phi0_progress, phi1_progress)"
|
|
],
|
|
"metadata": {
|
|
"id": "RMi7WItB-I05"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |