udlbook/Notebooks/Chap03/3_1_Shallow_Networks_I.ipynb
2023-12-24 11:16:55 -05:00

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPBNztJrxnUt1ELWfm1Awa3",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap03/3_1_Shallow_Networks_I.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 3.1 -- Shallow neural networks I**\n",
"\n",
"The purpose of this notebook is to gain some familiarity with shallow neural networks with 1D inputs. It works through an example similar to figure 3.3 and experiments with different activation functions. <br>\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the word \"TODO\". Follow the instructions at these places and write code to complete the functions. There are also questions interspersed in the text.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
],
"metadata": {
"id": "1Z6LB4Ybn1oN"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hAM55ZjSncOk"
},
"outputs": [],
"source": [
"# Imports math library\n",
"import numpy as np\n",
"# Imports plotting library\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"source": [
"Let's first construct the shallow neural network with one input, three hidden units, and one output described in section 3.1 of the book."
],
"metadata": {
"id": "wQDy9UzXpnf5"
}
},
{
"cell_type": "code",
"source": [
"# Define the Rectified Linear Unit (ReLU) function\n",
"def ReLU(preactivation):\n",
" # TODO write code to implement the ReLU and compute the activation at the\n",
" # hidden unit from the preactivation\n",
" # This should work on every element of the ndarray \"preactivation\" at once\n",
" # One way to do this is with the ndarray \"clip\" function\n",
" # https://numpy.org/doc/stable/reference/generated/numpy.ndarray.clip.html\n",
" activation = np.zeros_like(preactivation)\n",
"\n",
" return activation"
],
"metadata": {
"id": "OT7h7sSwpkrt"
},
"execution_count": null,
"outputs": []
},
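The `clip` hint works because clipping an array from below at zero replaces every negative entry with 0 and leaves the rest unchanged, which is exactly ReLU[z] = max(0, z) applied elementwise. As a minimal sketch, assuming only NumPy, here is a reference version (the name `relu_reference` is hypothetical, not part of the notebook) that you could compare your own `ReLU` against:

```python
import numpy as np

def relu_reference(preactivation):
    # Elementwise max(0, z): negative preactivations become 0,
    # non-negative ones pass through unchanged
    return np.maximum(preactivation, 0.0)

z = np.arange(-5, 5, 0.1)

# ndarray.clip with only a lower bound computes the same thing
assert np.allclose(relu_reference(z), z.clip(min=0))

print(relu_reference(np.array([-2.0, 0.0, 3.0])))  # [0. 0. 3.]
```

Once your `ReLU` is complete, `np.allclose(ReLU(z), relu_reference(z))` should return `True`. `np.maximum` and `clip(min=0)` are interchangeable here; the notebook's hint uses `clip`.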
{
"cell_type": "code",
"source": [
"# Make an array of inputs\n",
"z = np.arange(-5,5,0.1)\n",
"ReLU_z = ReLU(z)\n",
"\n",
"# Plot the ReLU function\n",
"fig, ax = plt.subplots()\n",
"ax.plot(z,ReLU_z,'r-')\n",
"ax.set_xlim([-5,5]);ax.set_ylim([-5,5])\n",
"ax.set_xlabel('z'); ax.set_ylabel('ReLU[z]')\n",
"ax.set_aspect('equal')\n",
"plt.show()"
],
"metadata": {
"id": "okwJmSw9pVNF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define a shallow neural network with one input, one output, and three hidden units\n",
"def shallow_1_1_3(x, activation_fn, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31):\n",
" # TODO Replace the lines below to compute the three initial lines\n",
" # (figure 3.3a-c) from the theta parameters. These are the preactivations\n",
" pre_1 = np.zeros_like(x)\n",
" pre_2 = np.zeros_like(x)\n",
" pre_3 = np.zeros_like(x)\n",
"\n",
" # Pass these through the activation function (here, ReLU) to compute the\n",
" # activations as in figure 3.3 d-f\n",
" act_1 = activation_fn(pre_1)\n",
" act_2 = activation_fn(pre_2)\n",
" act_3 = activation_fn(pre_3)\n",
"\n",
" # TODO Replace the code below to weight the activations using phi_1, phi_2 and phi_3\n",
" # to create the equivalent of figure 3.3 g-i\n",
" w_act_1 = np.zeros_like(x)\n",
" w_act_2 = np.zeros_like(x)\n",
" w_act_3 = np.zeros_like(x)\n",
"\n",
" # TODO Replace the code below to combine the weighted activations and add\n",
" # phi_0 to create the output as in figure 3.3 j\n",
" y = np.zeros_like(x)\n",
"\n",
" # Return everything we have calculated\n",
" return y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3"
],
"metadata": {
"id": "epk68ZCBu7uJ"
},
"execution_count": null,
"outputs": []
},
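The function above implements the shallow network of section 3.1, $y = \phi_0 + \sum_{d=1}^{3}\phi_d\,\mbox{a}[\theta_{d0}+\theta_{d1}x]$. One way to cross-check a completed `shallow_1_1_3` is a vectorized sketch that computes all hidden units at once with NumPy broadcasting. This is an illustration under assumptions, not the book's code: `shallow_vectorized`, the packed `phi`/`theta` arrays and the `relu` lambda are hypothetical names.

```python
import numpy as np

def shallow_vectorized(x, activation_fn, phi, theta):
    """Hypothetical vectorized variant of shallow_1_1_3.

    x:     1D array of inputs, shape (N,)
    phi:   array [phi_0, phi_1, ..., phi_D], shape (D+1,)
    theta: array of rows [theta_d0, theta_d1], shape (D, 2)
    """
    x = np.asarray(x, dtype=float)
    # Preactivations for every hidden unit and every input at once, shape (D, N)
    pre = theta[:, [0]] + theta[:, [1]] * x[np.newaxis, :]
    act = activation_fn(pre)
    # Weight each hidden unit's activation by phi_d, sum over units, add phi_0
    return phi[0] + phi[1:] @ act

relu = lambda z: np.maximum(z, 0.0)

# The notebook's original parameters, packed into arrays
theta = np.array([[0.3, -1.0], [-1.0, 2.0], [-0.5, 0.65]])
phi = np.array([-0.3, 2.0, -1.0, 7.0])

y = shallow_vectorized(np.array([0.0, 0.5, 0.9]), relu, phi, theta)
print(y)  # approximately [0.3, -0.3, -0.505]
```

With the original parameters, these outputs should match what your `shallow_1_1_3` returns for the same inputs. Packing the parameters into arrays means the same function works for any number of hidden units D.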
{
"cell_type": "code",
"source": [
"# Plot the shallow neural network. We'll assume the input is in the range [0,1] and the output in [-1,1]\n",
"# If the plot_all flag is set to True, then we'll plot all the intermediate stages as in figure 3.3\n",
"def plot_neural(x, y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3, plot_all=False, x_data=None, y_data=None):\n",
"\n",
" # Plot intermediate plots if flag set\n",
" if plot_all:\n",
" fig, ax = plt.subplots(3,3)\n",
" fig.set_size_inches(8.5, 8.5)\n",
" fig.tight_layout(pad=3.0)\n",
" ax[0,0].plot(x,pre_1,'r-'); ax[0,0].set_ylabel('Preactivation')\n",
" ax[0,1].plot(x,pre_2,'b-'); ax[0,1].set_ylabel('Preactivation')\n",
" ax[0,2].plot(x,pre_3,'g-'); ax[0,2].set_ylabel('Preactivation')\n",
" ax[1,0].plot(x,act_1,'r-'); ax[1,0].set_ylabel('Activation')\n",
" ax[1,1].plot(x,act_2,'b-'); ax[1,1].set_ylabel('Activation')\n",
" ax[1,2].plot(x,act_3,'g-'); ax[1,2].set_ylabel('Activation')\n",
" ax[2,0].plot(x,w_act_1,'r-'); ax[2,0].set_ylabel('Weighted Act')\n",
" ax[2,1].plot(x,w_act_2,'b-'); ax[2,1].set_ylabel('Weighted Act')\n",
" ax[2,2].plot(x,w_act_3,'g-'); ax[2,2].set_ylabel('Weighted Act')\n",
"\n",
" for plot_y in range(3):\n",
" for plot_x in range(3):\n",
" ax[plot_y,plot_x].set_xlim([0,1]); ax[plot_y,plot_x].set_ylim([-1,1])\n",
" ax[plot_y,plot_x].set_aspect(0.5)\n",
" ax[2,plot_x].set_xlabel('Input, $x$')\n",
" plt.show()\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.plot(x,y)\n",
" ax.set_xlabel('Input, $x$'); ax.set_ylabel('Output, $y$')\n",
" ax.set_xlim([0,1]);ax.set_ylim([-1,1])\n",
" ax.set_aspect(0.5)\n",
" if x_data is not None:\n",
" ax.plot(x_data, y_data, 'mo')\n",
" plt.show()"
],
"metadata": {
"id": "CAr7n1lixuhQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Now let's define some parameters and run the neural network\n",
"theta_10 = 0.3 ; theta_11 = -1.0\n",
"theta_20 = -1.0 ; theta_21 = 2.0\n",
"theta_30 = -0.5 ; theta_31 = 0.65\n",
"phi_0 = -0.3; phi_1 = 2.0; phi_2 = -1.0; phi_3 = 7.0\n",
"\n",
"# Define a range of input values\n",
"x = np.arange(0,1,0.01)\n",
"\n",
"# We run the neural network for each of these input values\n",
"y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3 = \\\n",
" shallow_1_1_3(x, ReLU, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31)\n",
"# And then plot it\n",
"plot_neural(x, y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3, plot_all=True)"
],
"metadata": {
"id": "SzIVdp9U-JWb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"If your code is correct, then the final output should look like this:\n",
"[Figure: expected final output of the network, a plot of Output, $y$ against Input, $x$, with $x$ from 0 to 1 and $y$ from -1 to 1]"
],
"metadata": {
"id": "T34bszToImKQ"
}
},
{
"cell_type": "markdown",
"source": [
"Now let's play with the parameters to make sure we understand how they work. The original parameters were:\n",
"\n",
"$\\theta_{10} = 0.3$ ; $\\theta_{11} = -1.0$<br>\n",
"$\\theta_{20} = -1.0$ ; $\\theta_{21} = 2.0$<br>\n",
"$\\theta_{30} = -0.5$ ; $\\theta_{31} = 0.65$<br>\n",
"$\\phi_0 = -0.3; \\phi_1 = 2.0; \\phi_2 = -1.0; \\phi_3 = 7.0$"
],
"metadata": {
"id": "jhaBSS8oIWSX"
}
},
{
"cell_type": "code",
"source": [
"# TODO\n",
"# 1. Predict what effect changing phi_0 will have on the network.\n",
"\n",
"# 2. Predict what effect multiplying phi_1, phi_2, phi_3 by 0.5 would have. Check if you are correct.\n",
"\n",
"# 3. Predict what effect multiplying phi_1 by -1 will have. Check if you are correct.\n",
"\n",
"# 4. Predict what effect setting theta_20 to -1.2 will have. Check if you are correct.\n",
"\n",
"# 5. Change the parameters so that there are only two \"joints\" (including outside the range of the plot).\n",
"# There are actually three ways to do this. See if you can figure them all out.\n",
"\n",
"# 6. With the original parameters, the second line segment is flat (i.e., has slope zero).\n",
"# How could you change theta_10 so that all of the segments have non-zero slopes?\n",
"\n",
"# 7. What do you predict would happen if you multiply theta_20 and theta_21 by 0.5, and phi_2 by 2.0?\n",
"# Check if you are correct.\n",
"\n",
"# 8. What do you predict would happen if you multiply theta_20 and theta_21 by -0.5, and phi_2 by -2.0?\n",
"# Check if you are correct.\n",
"\n",
"theta_10 = 0.3 ; theta_11 = -1.0\n",
"theta_20 = -1.0 ; theta_21 = 2.0\n",
"theta_30 = -0.5 ; theta_31 = 0.65\n",
"phi_0 = -0.3; phi_1 = 2.0; phi_2 = -1.0; phi_3 = 7.0\n",
"\n",
"# Define a range of input values\n",
"x = np.arange(0,1,0.01)\n",
"\n",
"# We run the neural network for each of these input values\n",
"y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3 = \\\n",
" shallow_1_1_3(x, ReLU, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31)\n",
"# And then plot it\n",
"plot_neural(x, y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3, plot_all=True)"
],
"metadata": {
"id": "ur4arJ8KAQWe"
},
"execution_count": null,
"outputs": []
},
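Questions 5 and 6 are easier to reason about with two facts about ReLU units: hidden unit $d$ creates a joint where its preactivation crosses zero, at $x = -\theta_{d0}/\theta_{d1}$, and between joints the slope of the output is the sum of $\phi_d\theta_{d1}$ over the units that are active there. A minimal sketch computing both, assuming every $\theta_{d1}$ is non-zero and the joints are distinct (`joints_and_slopes` is a hypothetical helper, not part of the notebook):

```python
import numpy as np

def joints_and_slopes(phi, theta):
    """Hypothetical helper: joint positions and per-region slopes of a
    1-input ReLU shallow network.

    phi:   [phi_0, phi_1, ..., phi_D]
    theta: rows [theta_d0, theta_d1]; assumes every theta_d1 is non-zero
           and that no two joints coincide
    """
    phi = np.asarray(phi, dtype=float)
    theta = np.asarray(theta, dtype=float)
    # Unit d switches on or off where theta_d0 + theta_d1 * x = 0
    joints = np.sort(-theta[:, 0] / theta[:, 1])
    # Probe one point inside each region: before, between and after the joints
    edges = np.concatenate(([joints[0] - 1.0], joints, [joints[-1] + 1.0]))
    probes = (edges[:-1] + edges[1:]) / 2
    slopes = []
    for x in probes:
        active = theta[:, 0] + theta[:, 1] * x > 0
        # Only active units contribute phi_d * theta_d1 to the slope
        slopes.append(np.sum(phi[1:][active] * theta[active, 1]))
    return joints, np.array(slopes)

# The notebook's original parameters
theta = np.array([[0.3, -1.0], [-1.0, 2.0], [-0.5, 0.65]])
phi = np.array([-0.3, 2.0, -1.0, 7.0])

joints, slopes = joints_and_slopes(phi, theta)
print(joints)  # approximately [0.3, 0.5, 0.769]
print(slopes)  # approximately [-2, 0, -2, 2.55]
```

With the original parameters, the region between x = 0.3 and x = 0.5 has slope zero because no unit is active there, which is the flat segment question 6 refers to. The helper reports joints anywhere on the real line, including outside the plotted range [0, 1], which question 5 asks you to count.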
{
"cell_type": "markdown",
"source": [
"# Least squares loss\n",
"\n",
"Now let's consider fitting the network to data. First we need to define the loss function. We'll use the least squares loss:\n",
"\n",
"\\begin{equation}\n",
"L[\\boldsymbol\\phi] = \\sum_{i=1}^{I}(y_{i}-\\mbox{f}[x_{i},\\boldsymbol\\phi])^2\n",
"\\end{equation}\n",
"\n",
"where $(x_i,y_i)$ is an input/output training pair and $\\mbox{f}[\\bullet,\\boldsymbol\\phi]$ is the neural network with parameters $\\boldsymbol\\phi$. The first term in the parentheses is the ground truth output and the second term is the prediction of the model."
],
"metadata": {
"id": "osonHsEqVp2I"
}
},
{
"cell_type": "code",
"source": [
"# Least squares function\n",
"def least_squares_loss(y_train, y_predict):\n",
" # TODO Replace the line below to compute the sum of squared\n",
" # differences between the real values of y and the predicted values from the model f[x_i,phi]\n",
" # (see figure 2.2 of the book)\n",
" # you will need to use the function np.sum\n",
" loss = 0\n",
"\n",
" return loss"
],
"metadata": {
"id": "14d5II-TU46w"
},
"execution_count": null,
"outputs": []
},
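The least squares loss is a sum over training pairs of squared residuals, so with NumPy arrays it reduces to a single vectorized expression. A minimal sketch with made-up toy values (not the notebook's training data); `least_squares_loss_sketch` is a hypothetical name:

```python
import numpy as np

def least_squares_loss_sketch(y_train, y_predict):
    # Residuals y_i - f[x_i, phi] for every training pair at once,
    # squared elementwise and then summed over i
    return np.sum((np.asarray(y_train) - np.asarray(y_predict)) ** 2)

# Toy values, made up for illustration only
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.0, 1.0, 1.0])
print(least_squares_loss_sketch(y_true, y_pred))  # 5.0
```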
{
"cell_type": "code",
"source": [
"# Now let's define some parameters, run the neural network, and compute the loss\n",
"theta_10 = 0.3 ; theta_11 = -1.0\n",
"theta_20 = -1.0 ; theta_21 = 2.0\n",
"theta_30 = -0.5 ; theta_31 = 0.65\n",
"phi_0 = -0.3; phi_1 = 2.0; phi_2 = -1.0; phi_3 = 7.0\n",
"\n",
"# Define a range of input values\n",
"x = np.arange(0,1,0.01)\n",
"\n",
"x_train = np.array([0.09291784,0.46809093,0.93089486,0.67612654,0.73441752,0.86847339,\\\n",
" 0.49873225,0.51083168,0.18343972,0.99380898,0.27840809,0.38028817,\\\n",
" 0.12055708,0.56715537,0.92005746,0.77072270,0.85278176,0.05315950,\\\n",
" 0.87168699,0.58858043])\n",
"y_train = np.array([-0.15934537,0.18195445,0.451270150,0.13921448,0.09366691,0.30567674,\\\n",
" 0.372291170,0.40716968,-0.08131792,0.41187806,0.36943738,0.3994327,\\\n",
" 0.019062570,0.35820410,0.452564960,-0.0183121,0.02957665,-0.24354444, \\\n",
" 0.148038840,0.26824970])\n",
"\n",
"# We run the neural network for each of these input values\n",
"y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3 = \\\n",
" shallow_1_1_3(x, ReLU, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31)\n",
"# And then plot it\n",
"plot_neural(x, y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3, plot_all=True, x_data = x_train, y_data = y_train)\n",
"\n",
"# Run the neural network on the training data\n",
"y_predict, *_ = shallow_1_1_3(x_train, ReLU, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31)\n",
"\n",
"# Compute the least squares loss and print it out\n",
"loss = least_squares_loss(y_train,y_predict)\n",
"print('Your Loss = %3.3f, True value = 9.385' % loss)\n",
"\n",
"# TODO. Manipulate the parameters (by hand!) to make the function\n",
"# fit the data better and try to reduce the loss to as small a number\n",
"# as possible. The best that I could do was 0.181\n",
"# Tip... start by manipulating phi_0.\n",
"# It's not that easy, so don't spend too much time on this!\n"
],
"metadata": {
"id": "o6GXjtRubZ2U"
},
"execution_count": null,
"outputs": []
}
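The tip to start with `phi_0` has a simple justification. Adding a constant $c$ to $\phi_0$ adds $c$ to every prediction, so with residuals $r_i = y_i - \mbox{f}[x_i,\boldsymbol\phi]$ the loss becomes $\sum_{i=1}^{I}(r_i - c)^2$. Setting its derivative $-2\sum_{i=1}^{I}(r_i - c)$ to zero gives $c = \frac{1}{I}\sum_{i=1}^{I} r_i$, the mean residual. A minimal sketch of this one-parameter update, derived from the loss above rather than taken from the book; `best_phi_0_shift` is a hypothetical helper and the numbers are toy values:

```python
import numpy as np

def best_phi_0_shift(y_train, y_predict):
    """Hypothetical helper: the amount to add to phi_0 that minimizes the
    least squares loss when all other parameters are held fixed.

    Adding c to phi_0 adds c to every prediction, so the loss becomes
    sum_i (r_i - c)^2 with residuals r_i = y_i - f[x_i, phi]. Setting the
    derivative -2 * sum_i (r_i - c) to zero gives c = mean(r).
    """
    residuals = np.asarray(y_train) - np.asarray(y_predict)
    return np.mean(residuals)

# Toy values, made up for illustration only
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([0.0, 1.5, 1.5])
shift = best_phi_0_shift(y_true, y_pred)
loss_before = np.sum((y_true - y_pred) ** 2)
loss_after = np.sum((y_true - (y_pred + shift)) ** 2)
print(shift, loss_before, loss_after)  # 1.0 3.5 0.5
```

In the notebook, after computing `y_predict` on `x_train`, `phi_0 += best_phi_0_shift(y_train, y_predict)` would apply this shift. It only optimizes $\phi_0$ with the other parameters fixed, so the remaining parameters still need adjusting by hand.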
]
}