diff --git a/CM20315_Deep2.ipynb b/CM20315_Deep2.ipynb new file mode 100644 index 0000000..b8d1039 --- /dev/null +++ b/CM20315_Deep2.ipynb @@ -0,0 +1,407 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyNg8z5TKbRHrYXsiWCEybnu", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Deep neural networks #2**\n", + "\n", + "In this notebook, we'll investigate converting neural networks to matrix form." + ], + "metadata": { + "id": "MaKn8CFlzN8E" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "8ClURpZQzI6L" + }, + "outputs": [], + "source": [ + "# Imports math library\n", + "import numpy as np\n", + "# Imports plotting library\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "source": [ + "# Define the Rectified Linear Unit (ReLU) function\n", + "def ReLU(preactivation):\n", + " activation = preactivation.clip(0.0)\n", + " return activation" + ], + "metadata": { + "id": "YdmveeAUz4YG" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Define a shallow neural network with, one input, one output, and three hidden units\n", + "def shallow_1_1_3(x, activation_fn, phi_0,phi_1,phi_2,phi_3, theta_10, theta_11, theta_20, theta_21, theta_30, theta_31):\n", + " # Initial lines\n", + " pre_1 = theta_10 + theta_11 * x\n", + " pre_2 = theta_20 + theta_21 * x\n", + " pre_3 = theta_30 + theta_31 * x\n", + " # Activation functions\n", + " act_1 = activation_fn(pre_1)\n", + " act_2 = activation_fn(pre_2)\n", + " act_3 = activation_fn(pre_3)\n", + " # Weight activations\n", + " w_act_1 = phi_1 * act_1\n", + " w_act_2 = phi_2 * act_2\n", + " w_act_3 = phi_3 * act_3\n", + " # Combine weighted activation and add y offset\n", + " y = phi_0 + w_act_1 + w_act_2 + w_act_3\n", + " # Return everything we have calculated\n", + " return y, pre_1, pre_2, pre_3, act_1, act_2, act_3, w_act_1, w_act_2, w_act_3" + ], + "metadata": { + "id": "ximCLwIfz8kj" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# # Plot the shallow neural network. We'll assume input in is range [-1,1] and output [-1,1]\n", + "# If the plot_all flag is set to true, then we'll plot all the intermediate stages as in Figure 3.3 \n", + "def plot_neural(x, y):\n", + " fig, ax = plt.subplots()\n", + " ax.plot(x.T,y.T)\n", + " ax.set_xlabel('Input'); ax.set_ylabel('Output')\n", + " ax.set_xlim([-1,1]);ax.set_ylim([-1,1])\n", + " ax.set_aspect(1.0)\n", + " plt.show()" + ], + "metadata": { + "id": "btrt7BX20gKD" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's define a networks. We'll just consider the inputs and outputs over the range [-1,1]. If you set the \"plot_all\" flat to True, you can see the details of how it was created." + ], + "metadata": { + "id": "LxBJCObC-NTY" + } + }, + { + "cell_type": "code", + "source": [ + "# Now lets define some parameters and run the first neural network\n", + "n1_theta_10 = 0.0 ; n1_theta_11 = -1.0\n", + "n1_theta_20 = 0 ; n1_theta_21 = 1.0\n", + "n1_theta_30 = -0.67 ; n1_theta_31 = 1.0\n", + "n1_phi_0 = 1.0; n1_phi_1 = -2.0; n1_phi_2 = -3.0; n1_phi_3 = 9.3\n", + "\n", + "# Define a range of input values\n", + "n1_in = np.arange(-1,1,0.01).reshape([1,-1])\n", + "\n", + "# We run the neural network for each of these input values\n", + "n1_out, *_ = shallow_1_1_3(n1_in, ReLU, n1_phi_0, n1_phi_1, n1_phi_2, n1_phi_3, n1_theta_10, n1_theta_11, n1_theta_20, n1_theta_21, n1_theta_30, n1_theta_31)\n", + "# And then plot it\n", + "plot_neural(n1_in, n1_out)" + ], + "metadata": { + "id": "JRebvurv22pT", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "outputId": "998dc656-6e74-4fcd-fc9e-cc7649da1417" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASMAAAEKCAYAAABZgzPTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3df5DcdZ3n8ed7ZjIZxoRk8gONIRKSoIjG48eAeu5ZtwqKugfuyipYKt7icd7q7e165wlllW6hlKJ3hbVq3ZIVBHct0OVuy2yJiyA/tFZR4h6/FekO4UhEupOQHz2TzGRm3vdHf7/hm870TE/392f361HVNd3f/vXpZPLO6/399vfzMXdHRCRrfVkPQEQEVIxEJCdUjEQkF1SMRCQXVIxEJBdUjEQkFzItRmZ2k5lVzOyxJvebmf2VmZXM7BEzOzty3+Vm9lRwuTy9UYtIErJORjcDF85x/zuA04LLlcD/AjCzFcBngdcD5wGfNbORREcqIonKtBi5+4+BvXM85GLgW173ALDczNYAbwfucve97v4CcBdzFzURybmBrAcwj7XAs5HbO4NtzbYfx8yupJ6qeMlLXnLO6aefnsxIJVb7xid59oVDnDxyAiPDg1kPR+axY/cYR2acA88+udvdV7fzGnkvRh1z9y3AFoDR0VHftm1bxiOS+UxNz3DB9T9mavcYv/fK1dzyJ+dlPSSZx+9ddw9nvWKEr73/7GfafY2s9xnNZxewLnL75GBbs+3SBbY+/Fue3j3GGWtO5J9Lu9k3Ppn1kGQOhyan2bXvEJtWL+nodfJejLYCHwqOqr0B2O/uzwF3Am8zs5Fgx/Xbgm1ScFPTM3z1nhJnrDmRL/zRZqZmnB8+/nzWw5I5bN9dwx02nvSSjl4n60P7twI/A15lZjvN7Aoz+6iZfTR4yB3AdqAE/A3wpwDuvhf4HPBgcLkm2CYFF6ai/3L+abzu5GWsW3EC33/0uayHJXMoVWoAbDqps2SU6T4jd79snvsd+FiT+24CbkpiXJKNaCp62xkvxcx45+Y13PiTp9k3Psly7cjOpXJ1jD6D9SsLnIxEoqKpyMwAeNfmNWrVcq5cqbFuxTBDi/o7eh0VI8mFxlQU2rxWrVrelSq1jndeg4qR5MRsqQg42qrpqFo+Tc84T+8e63h/EagYSQ40S0UhtWr59ezecSanZ9ioZCTdoFkqCqlVy69ytX4kbaOSkRTdfKkI1Krl2dHD+kpGUnTzpaKQWrV8KlVqrFqymGXDizp+LRUjyUwrqSikVi2fytUamzr85nVIxUgy02oqArVqeeTulCq1WHZeg4qRZGQhqSikVi1fqrUJDhyeiuWwPqgYSUYWkopCatXypVwZAzo/Jy2kYiSpaycVgVq1vCmFh/XVpklRtZOKQmrV8qNcqfGSwX7WLBuK5fVUjCRV7aaikFq1/ChXa2w8acmC/0NpRsVIUtVJKgK1ankS55E0UDGSFHWaikJq1bJXm5jiuf2HY9t5DSpGkqJOU1FIrVr2th/deR3PFx5BxUhSElcqArVqeRDXVLNRWc+BfaGZPRksX33VLPdfb2YPBZffmNm+yH3Tkfu2pjtyWai4UlFIrVq2ytUaA33GKR1ONRuV2RzYZtYPfB24gPoijA+a2VZ3fyJ8jLv/ReTx/xk4K/ISh9z9zLTGK+2LMxWFoq3ae89dN/8TJFalSo1XrBxmUX98eSbLZHQeUHL37e4+CdxGfTnrZi4Dbk1lZBKruFMRqFXLWlxTzUZlWYwWskT1KcCpwD2RzUNmts3MHjCzdyc3TOlEEqkopFYtG0emZ3hmz3is+4ugODuwLwVud/fpyLZT3H0UeD/wFTPbONsTzezKoGhtq1araYxVIpJIRSEdVcvGM3vGmZrxWL9jBNkWo4UsUX0pDS2au+8Kfm4H7uPY/UnRx21x91F3H129enWnY5YFSDIVgVq1rIRTzXZTMnoQOM3MTjWzQeoF57ijYmZ2OjBCfeXZcNuImS0Orq8C3gQ80fhcyVaSqSikVi194WH9DTF+xwgyLEbuPgV8HLgT+BXwXXd/3MyuMbOLIg+9FLgtWF029Gpgm5k9DNwLfDF6FE6yl3QqCqlVS1+5UuNlJw6xdKjzqWajsl7e+g7gjoZtn2m4/ZezPO+nwOZEBycdCVPRDR88J7FUBC+2aloCOz31qWbjbdGgODuwpUDSSkUhtWrpcXfK1bFYTwMJqRhJ7NLYVxSlVi09vztwmNpEfFPNRqkYSazSTkWgo2ppCqeajWPRxkYqRhKrtFNRSK1aOkqVg0A8izY2UjGS2GSRikJq1dJRqtZYOjTA6qWLY39tFSOJTVapCNSqpaVcGWNTjFPNRqkYSSyyTEUhtWrJK1XjnWo2SsVIYpFlKgqpVUvW/kNHqB6cSORIGqgYSQzykIpArVrSyjGvk9ZIxUg6lodUFFKrlpwkppqNUjGSjuQlFYXUqiWnXK0x2N/HupETEnl9FSPpSJ5SEahVS1K5UmP9qmEGYpxqNkrFSNqWt1QUUquWjFIlmRNkQypG0ra8paKQWrX4TUxN8//2jifyzeuQipG0Ja+pCNSqJWHH7nFmPJlz0kIqRtKWvKaikFq1eCV9WB9UjKQNeU5FIbVq8QoP66sYSa7kPRWBWrW4lSo11i4/gRMG+xN7DxUjWZAipKKQWrX4JDXVbFSmxcjMLjSzJ82sZGZXzXL/h82samYPBZePRO673MyeCi6Xpzvy3lWEVBRSqxaPmRmnnOAJsqHMJuQ3s37g68AF1FeTfdDMts6yysd33P3jDc9dAXwWGAUc+GXw3BdSGHrPKlIqAk3WH5dd+w5x+MhMVyej84CSu29390ngNuDiFp/7duAud98bFKC7gAsTGqcEipSKQmrVOpfUoo2NsixGa4FnI7d3BtsavcfMHjGz280sXIG21edqeeuYFC0VhdSqde7FI2nxrwgSlfcd2P8IrHf311FPP7cs9AW0vHU8ipiKQEfV4lCu1hgZXsTKJfFPNRuVZTHaBayL3D452HaUu+9x94ng5jeAc1p9rsSnqKkopFatM+FUs0nLshg9CJxmZqea2SD1Zay3Rh9gZmsiNy+ivgw21JfEfpuZjZjZCPC2YJskoKipKKRWrTNJTjUbldnRNHefMrOPUy8i/cBN7v64mV0DbHP3rcCfmdlFwBSwF/hw8Ny9ZvY56gUN4Bp335v6h+gBRU9FoKNqndg7NsnescmuT0a4+x3u/kp33+ju1wbbPhMUItz9and/jbv/K3f/fXf/deS5N7n7puDyzaw+Q7creioKqVVrTxrnpIXyvgNbMtQNqSikVq09SU81G6ViJE11SyoCHVVrV7lSY/FAH2uXJzPVbJSKkcyqm1JRSK3awpWqNTasXkJfX/L/GakYyay6KRWF1KotXNJTzUapGMlxujEVgVq1hTo0Oc2ufYcSnWo2SsVIjtONqSikVq1123fXcIeNJyV7GkhIxUiO0a2pKKRWrXVpHkkDFSNp0M2pCNSqLUS5OkafwfqVSkaSsm5PRSG1aq0pV2qsWzHM0KLkppqNUjGSo7o9FYXUqrWmXK2ltvMaVIwk0CupCNSqtWJ6xtm+eyzRddIaqRgJ0DupKKRWbW7P7h1ncmpGyUjS1UupKKRWbW5HT5BVMpI09VoqArVq8zl6WF/JSNLSi6kopFatuVKlxqoli1k2vCi191Qx6nG9mIpCatWaqy/amM73i0IqRj2sl1MRqFVrxt0pVdKZajZKxaiH9XIqCqlVO97u2iQHDk+ldhpIKO/LW3/CzJ4I1k37kZmdErlvOrLs9dbG58rcej0VhdSqHS/tc9JCmRWjyPLW7wDOAC4zszMaHvZ/gdFg3bTbgS9F7jvk7mcGl4tSGXQXUSqqU6t2vFKK815H5Xp5a3e/193Hg5sPUF8fTTqkVHQstWrHKldqDA/2s2bZUKrvW4TlrUNXAD+I3B4Klq1+wMze3exJWt76eEpFx1KrdqxysE5a2r8bhdiBbWYfAEaBL0c2n+Luo8D7ga+Y2cbZnqvlrY+lVHQ8tWrHSnOq2ahcL28NYGbnA58GLoosdY277wp+bgfuA85KcrDdQqlodmrV6moTUzy3/3DPFaNWlrc+C7iBeiGqRLaPmNni4Poq4E3AE6mNvKCUippTq1a3/ejO63S/8AgZFiN3nwLC5a1/BXw3XN46WNIa6m3ZEuDvGw7hvxrYZmYPA/cCX3R3FaN5KBU1p1atLqvD+gADqb9jhLvfAdzRsO0zkevnN3neT4HNyY6uuygVze9dm9dww/3b+eHjz/Pec9fN/4QuVK7WGOgzTklpqtmoQuzAls4pFc1PrVo9Gb1i5TCL+tMvDSpGPUCpqDVq1eqT8Kc5bUiUilEPUCpqXS8fVTsyPcOOlKeajVIx6nJKRQvTy63aM3vGmZpxJSNJhlLRwvRyqxZONZvFkTRQMepqSkXt6dVWLTysvyGD7xhBi8XIzN7UyjbJF6Wi9vRqq1au1HjZiUMsHUpvqtmoVpPRV1vcJjmhVNS+Xm3V6lPNZtOiwTzFyMzeaGb/FVgdTHQWXv4SSGfNW2mLUlFneq1Vc3fK1bFMTgMJzZeMBqmfjjEALI1cDgCXJDs0aZdSUed6rVV7/sAEtYn0p5qNmvN0EHe/H7jfzG5292dSGpN0KExFN3zwHKWiNoWt2o0/eZp945MsHx7MekiJCndeZ/UdI2h9n9HNZnZP4yXRkUlblIri00utWqlyEEh30cZGrZ4o+98i14eA9wBT8Q9HOqVUFJ9oq9btJ86Wq2MsHRpg9dLFmY2hpWLk7r9s2PTPZvaLBMYjHVAqilcvtWrhOmlZ/gfW6veMVkQuq8zs7cCyhMcmC6QjaPHrlVatlPFhfWi9Tfsl4IBRb8+epj5BvuSEUlEyeqFV23/oCNWDE8UoRu5+atIDkc5oX1EyeqFVK2e0TlqjVtu0oeDLjv/HzP63mf25maW7qJI0pVSUrG5v1bKcajaq1UP73wJeQ/0UkK8F1/+20zdvYXnrxWb2neD+n5vZ+sh9Vwfbnwz2YfUs7StKVrd/AbJcrTHY38e6kRMyHUer+4xe6+7RpafvNbOOJsCPLG99AfUFHB80s60NE+tfAbzg7pvM7FLgOuB9wTLYl1Ivii8H7jazV7r7dCdjKiKlouR1e6tWrtRYv2qYgQymmo1q9d3/xczeEN4ws9cD2zp873mXtw5u3xJcvx14q9X/678YuM3dJ9z9aaAUvF7PUSpKRze3auXqWOYtGrRejM4BfmpmO8xsB/Az4Fwze9TMHmnzvVtZ3vroY4KljfYDK1t8LtDdy1srFaWnW1u1ialpntkzlvnOa2i9Tbsw0VEkyN23AFsARkdHPePhxEpH0NLTra3ajt3jzHj2O6+h9WT0eXd/JnqJbmvzvVtZ3vroY8xsgPoXLfe0+NyuplSUvm5s1fJyWB9aL0avid4ICsM5Hb73vMtbB7cvD65fAtzj7h5svzQ42nYqcBrQU6enaF9R+rqxVct6qtmo+SZXu9rMDgKvM7MDZnYwuP088L1O3rjF5a1vBFaaWQn4BHBV8NzHge8CTwD/BHysl46kKRVloxtngCxVaqxdfgLDg5kuLg3MU4zc/QvuvhT4sruf6O5Lg8tKd7+60zd39zvc/ZXuvtHdrw22fcbdtwbXD7v7H7v7Jnc/z923R557bfC8V7n7DzodS5EoFWWn21q1rKeajWq1TfuBmb258ZLoyGRWSkXZ6qZWbWbGKVdrudhfBK0fTftk5PoQ9e/0/BJ4S+wjkjnpCFq2uumo2q59hzh8ZKZYycjd/13kcgHwWuCFZIcmjZSK8qFbWrWsF21s1O73v3cCr45zIDI/7SvKh25p1Y7Oe52DI2nQYptmZl+lPp8R1AvYWcC/JDUoOZ5SUX50S6tWro4xMryIlUuym2o2qtVk9ATwm+DyAPDf3f0DiY1KjqNUlC/d0KqVK/nZeQ3zf89owMy+BHwO+JPg8hXgYjPLZg3cHqRUlD/d0KrlYarZqPmS0ZeBFcCp7n62u58NbACWA/8j6cFJnVJR/hT9C5B7xybZOzZZqGL0B8B/cPeD4QZ3PwD8J+CdSQ5M6pSK8qvIrVqezkkLzVeMPDgXrHHjNC/u0JYEKRXlV5FbtbxMNRs1XzF6wsw+1LjRzD4A/DqZIUlIqSjfityqlSs1Fg/0sXZ5tlPNRs1XjD4GfMzM7jOz/xlc7gf+jHqrJglSKsq/orZqpWqNDauX0NeXn9+r+U6U3eXurweuAXYEl2uCk1Z7av6gtCkVFUNRW7U8nSAbavV0kHvc/avB5UdJD0qUioqiiK3a4SPT7HzhEJtytPMa2j8dRBKkVFQsRWvVytUa7rDxpHycBhJSMcohpaJiKVqrVq6OAfk6kgYqRrmjVFQ8RWvVSpUafQbrVyoZyRyUioqpSK1auVJj3Yphhhb1Zz2UY2RSjMxshZndZWZPBT9HZnnMmWb2MzN73MweMbP3Re672cyeNrOHgsuZ6X6CZCgVFVeRWrVytZa7ndeQXTK6CviRu58G/Ci43Wgc+JC7v4b6um1fMbPlkfs/6e5nBpeHkh9y8pSKiqsordr0jLN99xgbc7a/CLIrRtFlq28B3t34AHf/jbs/FVz/LVABVqc2wpQpFRVfEVq1Z/eOMzk1o2QU8VJ3D/Ps74A5//WZ2XnAIFCObL42aN+uN7Oms0MVZXlrpaLiK0KrdvQE2V5KRmZ2t5k9Nsvl4ujjghNxm550a2ZrgL8F/r27zwSbrwZOB86lPsXJp5o93923uPuou4+uXp3PYKVU1B2K0KodPUG2l5KRu5/v7q+d5fI94PmgyITFpjLba5jZicD3gU+7+wOR137O6yaAb1JfraSwlIq6R95btXK1xqoli1k2nL+5EbNq06LLVl/OLKvTBkte/wPwLXe/veG+sJAZ9f1NjyU62gQpFXWXvLdqpUotNxPwN8qqGH0RuMDMngLOD25jZqNm9o3gMe8F3gx8eJZD+N82s0eBR4FVwOfTHX58lIq6S55bNXenVMnfCbKhTBbYdvc9wFtn2b4N+Ehw/e+Av2vy/K5YPFKpqDu9a/Mabrh/Oz98/Hnee+66rIdz1O7aJAcOT+W2GOkb2BlSKupOeW3VXlwnTcVIIpSKuldeW7VSzlaQbaRilBGlou6Wx6Nq5UqN4cF+1iwbynoos1IxyoBSUffLY6tWrtYXbczrf34qRhlQKup+eWzVyjk+kgYqRqlTKuodeWrVxiam+O3+wypG8iKlot6Rp1btxUUb8/mFR1AxSpVSUW/JU6tWzvmRNFAxSpVSUe/JS6tWqtTo7zNesULJqOcpFfWmvLRqpUqNU1YOMziQ33/y+R1Zl1Eq6k15adXK1bFcThsSpWKUAqWi3pZ1q3ZkeoYdOZ1qNkrFKAVKRb0t61btmT3jTM24klGvUyqSrFu1IhxJAxWjxCkVCWTbqoVn62/I8XeMQMUoUUpFEsqyVStXa7zsxCGWDuVvqtkoFaMEKRVJKMtWLe/npIVUjBKiVCSNsmjV3J1ydSzXp4GEcru8dfC46cj811sj2081s5+bWcnMvhNM3p8rSkXSKItW7fkDE9Qm8jvVbFSel7cGOBRZwvqiyPbrgOvdfRPwAnBFssNdGKUimU0WrVrep5qNyu3y1s0EyxO9BQiXL1rQ89OgVCTNpN2qlSoHgfwf1of8L289FCxN/YCZhQVnJbDP3aeC2zuBtc3eKO3lrZWKZC5pt2rl6hhLhwZYvbTpCvC5kdhSRWZ2N/CyWe76dPSGu7uZNVve+hR332VmG4B7grXS9i9kHO6+BdgCMDo62nQZ7biEqeiGD56jVCTHCVu1G3/yNPvGJ1k+nOzuzvqijfmdajYq18tbu/uu4Od24D7gLGAPsNzMwkJ6MrArqc+xEEpF0oo0W7VStRiH9SHfy1uPmNni4Poq4E3AE+7uwL3AJXM9PwvaVyStSKtV23/oCNWDEypG82hleetXA9vM7GHqxeeL7v5EcN+ngE+YWYn6PqQbUx39LJSKpFVpHVV7carZYhSjPC9v/VNgc5PnbwfOS3KMC6V9RbIQaSyBXa4U4wTZkL6BHQOlIlmoNFq1UrXGYH8f60ZOSOw94qRiFAPtK5KFSqNVK1dqrF81zEB/Mf6ZF2OUOaZUJO1K+qhauTpWmBYNVIw6plQk7UqyVZuYmuaZPWOF2XkNKkYdUSqSTiTZqu3YPc6MF2fnNagYdUSpSDqVVKtWtMP6oGLUNqUiiUNSrVpRppqNUjFqk1KRxCGpVq1crbF2+QkMD2byVcK2qBi1QalI4pREq1YqyFSzUSpGbVAqkjjF3arNzDjlaq1Q+4tAxWjBlIokbnG3ar/df4jDR2aUjLqdUpEkIc5W7cWpZouz8xpUjBZEqUiSEmerVirYCbIhFaMFUCqSpMTZqpWrY4wML2LlkvxPNRulYtQipSJJWlytWrlSvJ3XoGLUMqUiSVpcrVqRppqNUjFqgVKRpCGOVm3v2CR7xyZVjLqVUpGkpdNWrYjnpIVyu7y1mf1+ZGnrh8zscLh2mpndbGZPR+47M6mxKhVJmjpt1Yo21WxUbpe3dvd7w6Wtqa8gOw78MPKQT0aWvn4oqYEqFUmaOm3VSpUaiwf6WLu8GFPNRhVleetLgB+4+3iio2qgVCRZ6KRVK1VrbFi9hL6+4v3HmfflrUOXArc2bLvWzB4xs+vD9dXiplQkWeikVSsX9EgaJFiMzOxuM3tslsvF0ccFizI2XXY6WHF2M3BnZPPVwOnAucAK6uuoNXv+lWa2zcy2VavVlsevVCRZabdVO3xkmp0vHCrcaSChXC9vHXgv8A/ufiTy2s953QTwTeZYQ83dt7j7qLuPrl69uuXxKxVJltpp1crVGl6wqWajcru8dcRlNLRokUJm1Pc3PRbn4JSKJGvttGrl6higYrRQrSxvjZmtB9YB9zc8/9tm9ijwKLAK+Hycg1Mqkqy106qVKjX6DNavVJvWMnff4+5vdffTgnZub7B9m7t/JPK4He6+1t1nGp7/FnffHLR9H3D3WlxjUyqSvFhoq1au1li3YpihRf0JjywZ+gZ2A6UiyYuFtmrlSo1NBfzmdUjFKEKpSPJkIa3a9IyzffcYGwu6vwhUjI6hVCR502qrtvOFcSanZpSMuoFSkeRRq63a0almTyrmzmtQMTpKqUjyqNVW7ehUs6uXpjW02KkYoVQk+dZKq1au1li1ZDHLhhelOLJ4qRihVCT51kqrVqrUCnsaSKjni5FSkeTdfK2auxdyBdlGPV+MlIqkCOZq1XbXJjlweErFqMiUiqQo5mrVXly0UcWosJSKpCjmatXCea+VjApKqUiKplmrVqrUGB7sZ82yoYxGFo+eLUZKRVI0zVq1crW+aGPRf497shgpFUkRNWvVyl1wJA16tBgpFUlRNbZqYxNT/Hb/4cJ/xwh6sBgpFUmRNbZq3bLzGnqwGCkVSZE1tmoqRgXloFQkhRdt1UqVGv19xitWFL9NG8h6AGnaPz7Jnt1j3PDBc5SKpLCirdrQoj5OWTnM4EDxc0Umn8DM/tjMHjezGTMbneNxF5rZk2ZWMrOrIttPNbOfB9u/Y2aDrbzv8wcnlIqk8KKt2sPP7i/0hGpRWZXTx4A/An7c7AFm1g98HXgHcAZwmZmdEdx9HXC9u28CXgCuaOVNJ6dmtK9IukLYqv3uwOFCTzUbldXqIL9y9yfnedh5QMndt7v7JHAbcHGwVtpbgNuDx91Cfe20eQ0t6lcqkq4QtmpA1ySjPO8zWgs8G7m9E3g9sBLY5+5Tke1rm72ImV0JXBncnOjr64t1wcecWAXsznoQCenWzxbb57rkujheJTavaveJiRUjM7sbeNksd306WOI6Fe6+BdgSjGmbuzfdR1VU3fq5oHs/Wzd/rnafm1gxcvfzO3yJXdRXkw2dHGzbAyw3s4EgHYXbRaTA8nw88EHgtODI2SBwKbDV3R24F7gkeNzlQGpJS0SSkdWh/T80s53AG4Hvm9mdwfaXm9kdAEHq+ThwJ/Ar4Lvu/njwEp8CPmFmJer7kG5s8a23xPgx8qRbPxd072fT52pg9aAhIpKtPLdpItJDVIxEJBe6uhh1etpJXpnZCjO7y8yeCn6ONHnctJk9FFy2pj3OVs33529mi4PTfkrBaUDr0x9le1r4bB82s2rk7+kjWYxzIczsJjOrmNms39mzur8KPvMjZnZ2Sy/s7l17AV5N/UtY9wGjTR7TD5SBDcAg8DBwRtZjn+dzfQm4Krh+FXBdk8fVsh5rC59l3j9/4E+Bvw6uXwp8J+txx/jZPgx8LeuxLvBzvRk4G3isyf3vBH4AGPAG4OetvG5XJyPv4LST5EfXkYupnwYDCzgdJqda+fOPft7bgbdaMU4wLOLv1rzc/cfA3jkecjHwLa97gPr3AtfM97pdXYxaNNtpJ01PL8mJl7p7OCv774BmJ9wNmdk2M3vAzPJasFr58z/6GK9/5WM/9a905F2rv1vvCdqZ281s3Sz3F01b/6byfG5aS/Jy2knc5vpc0Rvu7mbW7PsZp7j7LjPbANxjZo+6eznusUpH/hG41d0nzOw/Uk+Ab8l4TJkofDHy5E47ydRcn8vMnjezNe7+XBB/K01eY1fwc7uZ3QecRX0fRp608ucfPmanmQ0Ay6ifFpR38342d49+jm9Q3x9YdG39m1Kb1uS0k4zHNJ+t1E+DgSanw5jZiJktDq6vAt4EPJHaCFvXyp9/9PNeAtzjwZ7SnJv3szXsS7mI+tkGRbcV+FBwVO0NwP7IboXmst4zn/Be/z+k3q9OAM8DdwbbXw7c0bD3/zfUU8Onsx53C59rJfAj4CngbmBFsH0U+EZw/V8Dj1I/gvMocEXW457j8xz35w9cA1wUXB8C/h4oAb8ANmQ95hg/2xeAx4O/p3uB07Mecwuf6VbgOeBI8O/rCuCjwEeD+436xIjl4Hdv1iPZjRedDiIiuaA2TURyQcVIRHJBxUhEckHFSERyQcVIRHJBxUhSZWa1BF5zvZm9P+7XlXSpGEk3WA+oGBWcipFkwsz+rZndF5wc+msz+3Z4Jr6Z7TCzL5nZo2b2CzPbFGy/2cwuibxGmLK+CLbEJu4AAADiSURBVPybYD6gv0j/00gcVIwkS2cBf059+fIN1E9ZCe13983A14CvzPM6VwE/cfcz3f36REYqiVMxkiz9wt13uvsM8BD1dit0a+TnG9MemKRPxUiyNBG5Ps2xs0j4LNenCH5nzayP+uyJ0iVUjCSv3hf5+bPg+g7gnOD6RcCi4PpBYGlqI5NEFH4+I+laI2b2CPX0dFmw7W+A75nZw8A/AWPB9keA6WD7zdpvVEw6a19yx8x2UJ92YnfWY5H0qE0TkVxQMhKRXFAyEpFcUDESkVxQMRKRXFAxEpFcUDESkVz4/9a9Mxo3r6mpAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now we'll define the same neural network, but this time, we will use matrix form. When you get this right, it will draw the same plot as above." + ], + "metadata": { + "id": "XCJqo_AjfAra" + } + }, + { + "cell_type": "code", + "source": [ + "beta_0 = np.zeros((3,1))\n", + "Omega_0 = np.zeros((3,1))\n", + "beta_1 = np.zeros((1,1))\n", + "Omega_1 = np.zeros((1,3))\n", + "\n", + "# TODO Fill in the values of the beta and Omega matrices with the n1_theta and n1_phi parameters that define the network above\n", + "# !!! NOTE THAT MATRICES ARE CONVENTIONALLY INDEXED WITH a_11 IN THE TOP LEFT CORNER, BUT NDARRAYS START AT [0,0]\n", + "# To get you started I've filled in a couple:\n", + "beta_0[0,0] = n1_theta_10\n", + "Omega_0[0,0] = n1_theta_11\n", + "\n", + "\n", + "# Make sure that input data matrix has different inputs in its columns\n", + "n_data = n1_in.size\n", + "n_dim_in = 1\n", + "n1_in_mat = np.reshape(n1_in,(n_dim_in,n_data))\n", + "\n", + "# This runs the network for ALL of the inputs, x at once so we can draw graph\n", + "h1 = ReLU(np.matmul(beta_0,np.ones((1,n_data))) + np.matmul(Omega_0,n1_in_mat))\n", + "n1_out = np.matmul(beta_1,np.ones((1,n_data))) + np.matmul(Omega_1,h1)\n", + "\n", + "# Draw the network and check that it looks the same as the non-matrix case\n", + "plot_neural(n1_in, n1_out)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "id": "MR0AecZYfACR", + "outputId": "02f94f92-73c5-48e7-e7ce-695b9774a893" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASMAAAEKCAYAAABZgzPTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATTklEQVR4nO3dfZBddX3H8fdHYhLHqiQkg+EhkAy0PNROAlcepLUWEdFpE62phI4lKDY+0U5ldAjlDx3UCtoZGMWppBgDjBPQtI7rCE2BgDiWABsbE4hiloCaFCUSwAcUSPj2j/NbPCx7d8/u3nPP7+5+XjN37jm/c87d79lkP3POuffcryICM7OmvaTpAszMwGFkZplwGJlZFhxGZpYFh5GZZcFhZGZZaDSMJK2R9Kik+9osl6TPSRqQtFXSCaVlKyTtSI8V3avazOrQ9JHRWuCsEZa/BTg6PVYC/wYgaTbwMeBk4CTgY5Jm1VqpmdWq0TCKiDuBvSOsshS4LgqbgAMlzQPeDNwSEXsj4nHgFkYONTPL3LSmCxjFocBPS/O70li78ReRtJLiqIqXv/zlJx5zzDH1VGpmbN68+RcRMXc82+YeRhMWEauB1QCtViv6+/sbrshs8pL04/Fu2/Q1o9HsBg4vzR+WxtqNm1mPyj2M+oBz07tqpwBPRsQjwAbgTEmz0oXrM9OYmfWoRk/TJK0D3gDMkbSL4h2ylwJExBeBm4C3AgPAU8C707K9kj4B3Jte6tKIGOlCuJllrtEwiohzRlkewIfaLFsDrKmjLjPrvtxP08xsinAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFppub32WpAdS++pVwyy/QtKW9PiRpCdKy/aXlvV1t3Iz67TGvgNb0gHAF4A3UTRhvFdSX0RsH1wnIj5cWv8fgMWll/htRCzqVr1mVq8mj4xOAgYiYmdEPAPcQNHOup1zgHVdqczMuq7JMBpLi+ojgAXAxtLwTEn9kjZJelt9ZZpZN/RKe+vlwPqI2F8aOyIidktaCGyUtC0iHhy6oaSVwEqA+fPnd6daMxuzJo+MxtKiejlDTtEiYnd63gncwQuvJ5XXWx0RrYhozZ07d6I1m1lNmgyje4GjJS2QNJ0icF70rpikY4BZwF2lsVmSZqTpOcBpwPah25pZ72jsNC0i9km6ANgAHACsiYj7JV0K9EfEYDAtB25I3WUHHQtcLek5ikC9rPwunJn1Hr3wb3xya7Va0d/f33QZZpOWpM0R0RrPtv4EtpllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYaDSNJZ0l6QNKApFXDLD9P0h5JW9LjvaVlKyTtSI8V3a3czDqtse4gkg4AvgC8iaKb7L2S+obp8nFjRFwwZNvZwMeAFhDA5rTt410o3cxq0OSR0UnAQETsjIhngBuApRW3fTNwS0TsTQF0C3BWTXWaWRc0GUaHAj8tze9KY0O9Q9JWSeslDXagrbotklZK6pfUv2fPnk7UbWY1yP0C9jeBIyPiTyiOfq4d6wu4vbVZb2gyjHYDh5fmD0tjz4uIxyLi6TR7DXBi1W3NrLc0GUb3AkdLWiBpOkUb677yCpLmlWaXAD9I0xuAMyXNkjQLODONmVmPauzdtIjYJ+kCihA5AFgTEfdLuhToj4g+4B8lLQH2AXuB89K2eyV9giLQAC6NiL1d3wkz6xhFRNM1dE2r1Yr+/v6myzCbtCRtjojWeLbN/QK2mU0RDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8tC7u2tL5S0PfVNu03SEaVl+0ttr/uGbmtmvSX39tb/C7Qi4ilJHwA+A5ydlv02IhZ1tWgzq03W7a0j4vaIeCrNbqLoj2Zmk1AvtLcedD5wc2l+ZmpbvUnS29pt5PbWZr2hsdO0sZD0LqAF/Hlp+IiI2C1pIbBR0raIeHDothGxGlgNRauirhRsZmOWdXtrAElnAJcAS0qtromI3el5J3AHsLjOYs2sXrm3t14MXE0RRI+WxmdJmpGm5wCnAeUL32bWY3Jvb/1Z4A+Ar0kC+ElELAGOBa6W9BxFoF425F04M+sxbm9tZh3j9tZm1vMcRmaWBYeRmWXBYWRmWXAYmVkWHEZmloVKYSTptCpjZmbjVfXI6PMVx8zMxmXET2BLOhV4HTBX0oWlRa+k+NS0mVlHjHY7yHSK2zGmAa8ojf8SWFZXUWY29YwYRhHxbeDbktZGxI+7VJOZTUFVb5RdK+lFN7FFxOkdrsfMpqiqYfSR0vRM4B3Avs6XY2ZTVaUwiojNQ4a+K+meGuoxsymqUhhJml2afQlwIvCqWioysymp6mnaZiAAUZyePUTxBflmZh1R9TRtQd2FmNnUVvU0bSbwQeBPKY6QvgN8MSJ+V2NtZjaFVL0d5DrgeIpbQK5K09dP9IdXaG89Q9KNafndko4sLbs4jT8g6c0TrcXMmlX1mtEfR8RxpfnbJU3oC/Artrc+H3g8Io6StBy4HDhb0nEU3USOBw4BbpX0hxGxfyI1mVlzqobR9ySdEhGbACSdDEz0m+2fb2+dXnOwvXU5jJYCH0/T64GrVLQJWQrckPqoPSRpIL3eXSP9wJ17fsPZV4+4ipmN03GHvHJC21c9TTsR+B9JD0t6mOKP/rWStknaOs6fXaW99fPrRMQ+4EngoIrbAi9sb/3ss8+Os1Qzq1vVI6Ozaq2iRkPbW9/4vlMbrshs8vr4BLatGkafjIi/Kw9Iun7o2BhVaW89uM4uSdMoPmj5WMVtzayHVD1NO748k4LhxAn+7FHbW6f5FWl6GbAxiq6TfcDy9G7bAuBowLenmPWw0b5c7WLgn4GXSfolxSewAZ4hnfqMV8X21l8Crk8XqPdSBBZpva9SXOzeB3zI76SZ9bZK7a0lfToiLu5CPbVye2uzek2kvXXVa0Y3S3r90MGIuHM8P9TMbKiqYfTR0vRMis/0bAb85Wpm1hFVb5T9q/K8pMOBK2upyMympPE2cdwFHNvJQsxsaqt61/7nKe7WhyLAFgPfq6soM5t6ql4z2s7v+6Q9AayLiO/WU5KZTUWjfc5oGvAvwHuAn6Th+cAaSfdEhG/2MrOOGO2a0WeB2cCCiDghIk4AFgIHAv9ad3FmNnWMFkZ/Cfx9RPxqcCAifgl8AHhrnYWZ2dQyWhhFDPMR7XTrxegf3TYzq2i0MNou6dyhg5LeBfywnpLMbCoa7d20DwH/Kek9FJ+4BmgBLwPeXmdhZja1jBhGEbEbOFnS6fz+a0Ruiojbaq/MzKaUqreDbAQ21lyLmU1h470dxMysoxxGZpYFh5GZZcFhZGZZaCSMJM2WdIukHel51jDrLJJ0l6T7JW2VdHZp2VpJD0nakh6LursHZtZpTR0ZrQJui4ijgdvS/FBPAedGxPEUfduulHRgaflHI2JRemypv2Qzq1NTYbQUuDZNXwu8begKEfGjiNiRpv8PeBSY27UKzayrmgqjgyPikTT9M+DgkVaWdBIwHXiwNPypdPp2haQZI2z7fHvrPXv2TLhwM6tHbWEk6VZJ9w3zWFpeL92I2/amW0nzgOuBd0fEc2n4YuAY4LUUX3FyUbvtI2J1RLQiojV3rg+szHJV9Zsexywizmi3TNLPJc2LiEdS2DzaZr1XAt8CLomITaXXHjyqelrSl4GPdLB0M2tAU6dp5bbVK4BvDF0htbz+OnBdRKwfsmxeehbF9ab7aq3WzGrXVBhdBrxJ0g7gjDSPpJaka9I67wReD5w3zFv4X5G0DdgGzAE+2d3yzazTKrW3nizc3tqsXhNpb+1PYJtZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWhWzbW6f19pe+/7qvNL5A0t2SBiTdmL6838x6WM7trQF+W2phvaQ0fjlwRUQcBTwOnF9vuWZWt2zbW7eT2hOdDgy2LxrT9maWp9zbW89Mrak3SRoMnIOAJyJiX5rfBRza7ge5vbVZb6ito6ykW4FXD7PokvJMRISkdv2SjoiI3ZIWAhtTr7Qnx1JHRKwGVkPRqmgs25pZ92Td3joidqfnnZLuABYD/wEcKGlaOjo6DNjd8R0ws67Kub31LEkz0vQc4DRgexRdJ28Hlo20vZn1lpzbWx8L9Ev6PkX4XBYR29Oyi4ALJQ1QXEP6UlerN7OOc3trM+sYt7c2s57nMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyxk295a0l+UWltvkfS7wd5pktZKeqi0bFH398LMOinb9tYRcftga2uKDrJPAf9dWuWjpdbXW7pStZnVplfaWy8Dbo6Ip2qtyswak3t760HLgXVDxj4laaukKwb7q5lZ78q9vTWp4+xrgA2l4YspQmw6Revqi4BL22y/ElgJMH/+/DHsgZl1U9btrZN3Al+PiGdLrz14VPW0pC8DHxmhjtUUgUWr1Zo6TeLMeky27a1LzmHIKVoKMCSJ4nrTfTXUaGZdlHN7ayQdCRwOfHvI9l+RtA3YBswBPtmFms2sRrWdpo0kIh4D3jjMeD/w3tL8w8Chw6x3ep31mVn3+RPYZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWWgkjCT9jaT7JT0nqTXCemdJekDSgKRVpfEFku5O4zdKmt6dys2sLk0dGd0H/DVwZ7sVJB0AfAF4C3AccI6k49Liy4ErIuIo4HHg/HrLNbO6NRJGEfGDiHhglNVOAgYiYmdEPAPcACxNvdJOB9an9a6l6J1mZj2skVZFFR0K/LQ0vws4GTgIeCIi9pXGX9TOaFC5vTVFB9rJ2PBxDvCLpouoyWTdt8m6X3803g1rCyNJtwKvHmbRJRExUgfZjiq3t5bUHxFtr1H1qsm6XzB5920y79d4t60tjCLijAm+xG6KbrKDDktjjwEHSpqWjo4Gx82sh+X81v69wNHpnbPpwHKgLyICuB1YltZbAXTtSMvM6tHUW/tvl7QLOBX4lqQNafwQSTcBpKOeC4ANwA+Ar0bE/eklLgIulDRAcQ3pSxV/9OoO7kZOJut+weTdN+/XECoONMzMmpXzaZqZTSEOIzPLwqQOo4nedpIrSbMl3SJpR3qe1Wa9/ZK2pEdft+usarTfv6QZ6bafgXQb0JHdr3J8KuzbeZL2lP6d3ttEnWMhaY2kR9t9Zk+Fz6V93irphEovHBGT9gEcS/EhrDuAVpt1DgAeBBYC04HvA8c1Xfso+/UZYFWaXgVc3ma9Xzdda4V9GfX3D3wQ+GKaXg7c2HTdHdy384Crmq51jPv1euAE4L42y98K3AwIOAW4u8rrTuojo5jAbSf1VzchSylug4Hevx2myu+/vL/rgTem24Jy14v/t0YVEXcCe0dYZSlwXRQ2UXwucN5orzupw6ii4W47aXt7SSYOjohH0vTPgIPbrDdTUr+kTZJyDawqv//n14niIx9PUnykI3dV/2+9I53OrJd0+DDLe824/qZyvjetklxuO+m0kfarPBMRIand5zOOiIjdkhYCGyVti4gHO12rTcg3gXUR8bSk91EcAZ7ecE2N6PkwivpuO2nUSPsl6eeS5kXEI+nw99E2r7E7Pe+UdAewmOIaRk6q/P4H19klaRrwKorbgnI36r5FRHk/rqG4HtjrxvU35dO0NredNFzTaPooboOBNrfDSJolaUaangOcBmzvWoXVVfn9l/d3GbAx0pXSzI26b0OupSyhuNug1/UB56Z31U4BnixdVmiv6SvzNV/1fzvF+erTwM+BDWn8EOCmIVf/f0Rx1HBJ03VX2K+DgNuAHcCtwOw03gKuSdOvA7ZRvIOzDTi/6bpH2J8X/f6BS4ElaXom8DVgALgHWNh0zR3ct08D96d/p9uBY5quucI+rQMeAZ5Nf1/nA+8H3p+Wi+KLER9M//eGfSd76MO3g5hZFnyaZmZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYWVdJ+nUNr3mkpL/t9OtadzmMbDI4EnAY9TiHkTVC0hsk3ZFuDv2hpK8M3okv6WFJn5G0TdI9ko5K42slLSu9xuBR1mXAn6XvA/pw9/fGOsFhZE1aDPwTRfvyhRS3rAx6MiJeA1wFXDnK66wCvhMRiyLiiloqtdo5jKxJ90TEroh4DthCcbo1aF3p+dRuF2bd5zCyJj1dmt7PC79FIoaZ3kf6PyvpJRTfnmiThMPIcnV26fmuNP0wcGKaXgK8NE3/CnhF1yqzWvT89xnZpDVL0laKo6dz0ti/A9+Q9H3gv4DfpPGtwP40vtbXjXqT79q37Eh6mOJrJ37RdC3WPT5NM7Ms+MjIzLLgIyMzy4LDyMyy4DAysyw4jMwsCw4jM8vC/wOu8sqPbOGqhwAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now we'll feed the output of the first network into the second one." + ], + "metadata": { + "id": "qOcj2Rof-o20" + } + }, + { + "cell_type": "code", + "source": [ + "# Now lets define some parameters and run the second neural network\n", + "n2_theta_10 = -0.6 ; n2_theta_11 = -1.0\n", + "n2_theta_20 = 0.2 ; n2_theta_21 = 1.0\n", + "n2_theta_30 = -0.5 ; n2_theta_31 = 1.0\n", + "n2_phi_0 = 0.5; n2_phi_1 = -1.0; n2_phi_2 = -1.5; n2_phi_3 = 2.0\n", + "\n", + "# Define a range of input values\n", + "n2_in = np.arange(-1,1,0.01)\n", + "\n", + "# We run the second neural network on the output of the first network\n", + "n2_out, *_ = \\\n", + " shallow_1_1_3(n1_out, ReLU, n2_phi_0, n2_phi_1, n2_phi_2, n2_phi_3, n2_theta_10, n2_theta_11, n2_theta_20, n2_theta_21, n2_theta_30, n2_theta_31)\n", + "# And then plot it\n", + "plot_neural(n1_in, n2_out)" + ], + "metadata": { + "id": "ZRjWu8i9239X", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "outputId": "a6284139-2940-476c-963c-015649963a58" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASMAAAEKCAYAAABZgzPTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATUElEQVR4nO3dfZBddX3H8fdHYhLHqiQkg+EhkAy0PNROAldAaa1FBHTaRGsqoWMJik19oJ3K6BDKTJ1BraDO4PgwlRRigHECmtZxHaEpEBBHCbCxMYEoZgmo2aIgAXxAgwnf/nF+q8dl7+7Z3Xvu+d3dz2vmzj3nd865+z2bzWfOOfee+1VEYGbWtBc0XYCZGTiMzCwTDiMzy4LDyMyy4DAysyw4jMwsC42GkaR1kh6TdH+b5ZL0KUkDkrZLOqm0bJWkXemxqntVm1kdmj4yWg+cM8ryNwDHpsdq4N8BJM0FPgicCpwCfFDSnForNbNaNRpGEXEXsHeUVZYD10dhC3CwpAXA2cCtEbE3Ip4EbmX0UDOzzM1ouoAxHA78qDS/J421G38eSaspjqp48YtffPJxxx1XT6VmxtatW38aEfMnsm3uYTRpEbEWWAvQarWiv7+/4YrMpi5JP5jotk1fMxrLIHBkaf6INNZu3Mx6VO5h1Aecn95VOw14OiIeBTYBZ0maky5cn5XGzKxHNXqaJmkD8FpgnqQ9FO+QvRAgIj4H3Ay8ERgAngHenpbtlfQh4L70UpdHxGgXws0sc42GUUScN8byAN7bZtk6YF0ddZlZ9+V+mmZm04TDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMstB0e+tzJD2Y2levGWH5VZK2pcf3JT1VWnagtKyvu5WbWac19h3Ykg4CPgu8nqIJ432S+iJi59A6EfG+0vr/CCwtvcSvImJJt+o1s3o1eWR0CjAQEbsj4lngRop21u2cB2zoSmVm1nVNhtF4WlQfBSwCNpeGZ0vql7RF0pvqK9PMuqFX2luvBDZGxIHS2FERMShpMbBZ0o6IeGj4hpJWA6sBFi5c2J1qzWzcmjwyGk+L6pUMO0WLiMH0vBu4k9+/nlReb21EtCKiNX/+/MnWbGY1aTKM7gOOlbRI0kyKwHneu2KSjgPmAHeXxuZImpWm5wGnAzuHb2tmvaOx07SI2C/pImATcBCwLiIekHQ50B8RQ8G0ErgxdZcdcjxwtaTnKAL1ivK7cGbWe/T7/8entlarFf39/U2XYTZlSdoaEa2JbOtPYJtZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWWh0TCSdI6kByUNSFozwvILJD0uaVt6vLO0bJWkXemxqruVm1mnNdYdRNJBwGeB11N0k71PUt8IXT5uioiLhm07F/gg0AIC2Jq2fbILpZtZDZo8MjoFGIiI3RHxLHAjsLzitmcDt0bE3hRAtwLn1FSnmXVBk2F0OPCj0vyeNDbcWyRtl7RR0lAH2qrbImm1pH5J/Y8//ngn6jazGuR+AfurwNER8ScURz/XjfcF3N7arDc0GUaDwJGl+SPS2G9FxBMRsS/NXgOcXHVbM+stTYbRfcCxkhZJmknRxrqvvIKkBaXZZcB30/Qm4CxJcyTNAc5KY2bWoxp7Ny0i9ku6iCJEDgLWRcQDki4H+iOiD/gnScuA/cBe4IK07V5JH6IINIDLI2Jv13fCzDpGEdF0DV3TarWiv7+/6TLMpixJWyOiNZFtc7+AbWbThMPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyykHt764sl7Ux9026XdFRp2YFS2+u+4duaWW/Jvb31/wKtiHhG0ruBjwHnpmW/ioglXS3azGqTdXvriLgjIp5Js1so+qOZ2RTUC+2th1wI3FKan53aVm+R9KZ2G7m9tVlvaOw0bTwkvQ1oAX9eGj4qIgYlLQY2S9oREQ8N3zYi1gJroWhV1JWCzWzcsm5vDSDpTOAyYFmp1TURMZiedwN3AkvrLNbM6pV7e+ulwNUUQfRYaXyOpFlpeh5wOlC+8G1mPSb39tYfB/4A+JIkgB9GxDLgeOBqSc9RBOoVw96FG9Hux3/JuVffXdMemU1vJxz20klt3+g1o4i4Gbh52Ni/lqbPbLPdt4BX1FudmXWTIqbPNd1WqxX9/f1Nl2E2ZUnaGhGtiWzr20HMLAsOIzPLgsPIzLLgMDKzLDiMzCwLlcJI0ulVxszMJqrqkdGnK46ZmU3IqB96lPQq4NXAfEkXlxa9lOJT02ZmHTHWJ7BnUtyOMQN4SWn8Z8CKuooys+ln1DCKiK8DX5e0PiJ+0KWazGwaqnpv2npJz7tvJCLO6HA9ZjZNVQ2j95emZwNvAfZ3vhwzm64qhVFEbB029E1J99ZQj5lNU5XCSNLc0uwLgJOBl9VSkZlNS1VP07YCAYji9Oxhii/INzPriKqnaYvqLsTMpreqp2mzgfcAf0pxhPQN4HMR8esaazOzaaTq7SDXAydS3ALymTR9w2R/eIX21rMk3ZSW3yPp6NKyS9P4g5LOnmwtZtasqteM/jgiTijN3yFpUt04Kra3vhB4MiKOkbQSuBI4V9IJFN1ETgQOA26T9IcRcWAyNZlZc6oeGX1b0mlDM5JOBSb7ZdJjtrdO89el6Y3A61S0CVkO3BgR+yLiYWAgvZ6Z9aiqYXQy8C1Jj0h6BLgbeKWkHZK2T/BnV2lv/dt1ImI/8DRwSMVtAbe3NusVVU/Tzqm1ihq5vbVZb6gaRh+OiL8rD0i6YfjYOFVpbz20zh5JMyg+aPlExW3NrIdUPU07sTyTguHkSf7sMdtbp/lVaXoFsDmKRm99wMr0btsi4FjAt6eY9bCxvlztUuBfgBdJ+hnFJ7ABniWd+kxUxfbW1wI3SBoA9lIEFmm9LwI7KT4R/l6/k2bW2yp1lJX00Yi4tAv11ModZc3qNZmOslWvGd0i6TXDByPiron8UDOz4aqG0QdK07MpPtOzFfCXq5lZR1S9UfavyvOSjgQ+WUtFZjYtTbSJ4x7g+E4WYmbTW9W79j9Ncbc+FAG2FPh2XUWZ2fRT9ZrRTn7XJ+0pYENEfLOeksxsOhrrc0YzgH8D3gH8MA0vBNZJujciflNzfWY2TYx1zejjwFxgUUScFBEnAYuBg4FP1F2cmU0fY4XRXwJ/HxE/HxqIiJ8B7wbeWGdhZja9jBVGESN8RDvdeuE74M2sY8YKo52Szh8+KOltwPfqKcnMpqOx3k17L/Bfkt5B8YlrgBbwIuDNdRZmZtPLqGEUEYPAqZLO4HdfI3JzRNxee2VmNq1UvR1kM7C55lrMbBqb6O0gZmYd5TAysyw4jMwsCw4jM8tCI2Ekaa6kWyXtSs9zRlhniaS7JT0gabukc0vL1kt6WNK29FjS3T0ws05r6shoDXB7RBwL3J7mh3sGOD8iTqTo2/ZJSQeXln8gIpakx7b6SzazOjUVRuW21dcBbxq+QkR8PyJ2pen/Ax4D5netQjPrqqbC6NCIeDRN/xg4dLSVJZ0CzAQeKg1/JJ2+XSVp1ijbur21WQ+oLYwk3Sbp/hEey8vrpRtx2950K2kBcAPw9oh4Lg1fChwHvJLiK04uabd9RKyNiFZEtObP94GVWa6qftPjuEXEme2WSfqJpAUR8WgKm8farPdS4GvAZRGxpfTaQ0dV+yR9Hnh/B0s3swY0dZpWblu9CvjK8BVSy+svA9dHxMZhyxakZ1Fcb7q/1mrNrHZNhdEVwOsl7QLOTPNIakm6Jq3zVuA1wAUjvIX/BUk7gB3APODD3S3fzDqtUnvrqcLtrc3qNZn21v4EtpllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZyLa9dVrvQOn7r/tK44sk3SNpQNJN6cv7zayH5dzeGuBXpRbWy0rjVwJXRcQxwJPAhfWWa2Z1y7a9dTupPdEZwFD7onFtb2Z5yr299ezUmnqLpKHAOQR4KiL2p/k9wOHtfpDbW5v1hto6ykq6DXj5CIsuK89EREhq1y/pqIgYlLQY2Jx6pT09njoiYi2wFopWRePZ1sy6J+v21hExmJ53S7oTWAr8J3CwpBnp6OgIYLDjO2BmXZVze+s5kmal6XnA6cDOKLpO3gGsGG17M+stObe3Ph7ol/QdivC5IiJ2pmWXABdLGqC4hnRtV6s3s45ze2sz6xi3tzaznucwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLGTb3lrSX5RaW2+T9Ouh3mmS1kt6uLRsSff3wsw6Kdv21hFxx1Bra4oOss8A/1Na5QOl1tfbulK1mdWmV9pbrwBuiYhnaq3KzBqTe3vrISuBDcPGPiJpu6SrhvqrmVnvyr29Nanj7CuATaXhSylCbCZF6+pLgMvbbL8aWA2wcOHCceyBmXVT1u2tk7cCX46I35Ree+ioap+kzwPvH6WOtRSBRavVmj5N4sx6TLbtrUvOY9gpWgowJInietP9NdRoZl2Uc3trJB0NHAl8fdj2X5C0A9gBzAM+3IWazaxGtZ2mjSYingBeN8J4P/DO0vwjwOEjrHdGnfWZWff5E9hmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZaCSMJP2NpAckPSepNcp650h6UNKApDWl8UWS7knjN0ma2Z3KzawuTR0Z3Q/8NXBXuxUkHQR8FngDcAJwnqQT0uIrgasi4hjgSeDCess1s7o1EkYR8d2IeHCM1U4BBiJid0Q8C9wILE+90s4ANqb1rqPonWZmPayRVkUVHQ78qDS/BzgVOAR4KiL2l8af185oSLm9NUUH2qnY8HEe8NOmi6jJVN23qbpffzTRDWsLI0m3AS8fYdFlETFaB9mOKre3ltQfEW2vUfWqqbpfMHX3bSrv10S3rS2MIuLMSb7EIEU32SFHpLEngIMlzUhHR0PjZtbDcn5r/z7g2PTO2UxgJdAXEQHcAaxI660CunakZWb1aOqt/TdL2gO8CviapE1p/DBJNwOko56LgE3Ad4EvRsQD6SUuAS6WNEBxDenaij96bQd3IydTdb9g6u6b92sYFQcaZmbNyvk0zcymEYeRmWVhSofRZG87yZWkuZJulbQrPc9ps94BSdvSo6/bdVY11u9f0qx0289Aug3o6O5XOTEV9u0CSY+X/p3e2USd4yFpnaTH2n1mT4VPpX3eLumkSi8cEVP2ARxP8SGsO4FWm3UOAh4CFgMzge8AJzRd+xj79TFgTZpeA1zZZr1fNF1rhX0Z8/cPvAf4XJpeCdzUdN0d3LcLgM80Xes49+s1wEnA/W2WvxG4BRBwGnBPlded0kdGMYnbTuqvblKWU9wGA71/O0yV3395fzcCr0u3BeWuF/+2xhQRdwF7R1llOXB9FLZQfC5wwVivO6XDqKKRbjtpe3tJJg6NiEfT9I+BQ9usN1tSv6QtknINrCq//9+uE8VHPp6m+EhH7qr+bb0lnc5slHTkCMt7zYT+T+V8b1oludx20mmj7Vd5JiJCUrvPZxwVEYOSFgObJe2IiIc6XatNyleBDRGxT9I/UBwBntFwTY3o+TCK+m47adRo+yXpJ5IWRMSj6fD3sTavMZied0u6E1hKcQ0jJ1V+/0Pr7JE0A3gZxW1BuRtz3yKivB/XUFwP7HUT+j/l07Q2t500XNNY+ihug4E2t8NImiNpVpqeB5wO7OxahdVV+f2X93cFsDnSldLMjblvw66lLKO426DX9QHnp3fVTgOeLl1WaK/pK/M1X/V/M8X56j7gJ8CmNH4YcPOwq//fpzhquKzpuivs1yHA7cAu4DZgbhpvAdek6VcDOyjewdkBXNh03aPsz/N+/8DlwLI0PRv4EjAA3AssbrrmDu7bR4EH0r/THcBxTddcYZ82AI8Cv0n/vy4E3gW8Ky0XxRcjPpT+9kZ8J3v4w7eDmFkWfJpmZllwGJlZFhxGZpYFh5GZZcFhZGZZcBhZV0n6RQ2vebSkv+3061p3OYxsKjgacBj1OIeRNULSayXdmW4O/Z6kLwzdiS/pEUkfk7RD0r2Sjknj6yWtKL3G0FHWFcCfpe8Del/398Y6wWFkTVoK/DNF+/LFFLesDHk6Il4BfAb45Bivswb4RkQsiYiraqnUaucwsibdGxF7IuI5YBvF6daQDaXnV3W7MOs+h5E1aV9p+gC//y0SMcL0ftLfrKQXUHx7ok0RDiPL1bml57vT9CPAyWl6GfDCNP1z4CVdq8xq0fPfZ2RT1hxJ2ymOns5LY/8BfEXSd4D/Bn6ZxrcDB9L4el836k2+a9+yI+kRiq+d+GnTtVj3+DTNzLLgIyMzy4KPjMwsCw4jM8uCw8jMsuAwMrMsOIzMLAv/D+iy2Qfu0NXlAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "source": [ + "beta_0 = np.zeros((3,1))\n", + "Omega_0 = np.zeros((3,1))\n", + "beta_1 = np.zeros((3,1))\n", + "Omega_1 = np.zeros((3,3))\n", + "beta_2 = np.zeros((1,1))\n", + "Omega_2 = np.zeros((1,3))\n", + "\n", + "# TODO Fill in the values of the beta and Omega matrices for with the n1_theta, n1_phi, n2_theta, and n2_phi parameters \n", + "# that define the composition of the two networks above (see eqn 4.5 for Omega1 and beta1 albeit in different notation)\n", + "# !!! NOTE THAT MATRICES ARE CONVENTIONALLY INDEXED WITH a_11 IN THE TOP LEFT CORNER, BUT NDARRAYS START AT [0,0] SO EVERYTHING IS OFFSET\n", + "# To get you started I've filled in a few:\n", + "beta_0[0,0] = n1_theta_10\n", + "Omega_0[0,0] = n1_theta_11\n", + "beta_1[0,0] = n2_theta_10 + n2_theta_11 * n1_phi_0\n", + "Omega_1[0,0] = n2_theta_11 * n1_phi_1\n", + "\n", + "\n", + "\n", + "# Make sure that input data matrix has different inputs in its columns\n", + "n_data = n1_in.size\n", + "n_dim_in = 1\n", + "n1_in_mat = np.reshape(n1_in,(n_dim_in,n_data))\n", + "\n", + "# This runs the network for ALL of the inputs, x at once so we can draw graph (hence extra np.ones term)\n", + "h1 = ReLU(np.matmul(beta_0,np.ones((1,n_data))) + np.matmul(Omega_0,n1_in_mat))\n", + "h2 = ReLU(np.matmul(beta_1,np.ones((1,n_data))) + np.matmul(Omega_1,h1))\n", + "n1_out = np.matmul(beta_2,np.ones((1,n_data))) + np.matmul(Omega_2,h2)\n", + "\n", + "# Draw the network and check that it looks the same as the non-matrix version\n", + "plot_neural(n1_in, n1_out)" + ], + "metadata": { + "id": "ZB2HTalOE40X", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "outputId": "a033b2c2-cc2b-48c7-9875-2c9ae7ee5dbe" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASMAAAEKCAYAAABZgzPTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATTklEQVR4nO3dfZBddX3H8fdHYhLHqiQkg+EhkAy0PNROAlcepLUWEdFpE62phI4lKDY+0U5ldAjlDx3UCtoZGMWppBgDjBPQtI7rCE2BgDiWABsbE4hiloCaFCUSwAcUSPj2j/NbPCx7d8/u3nPP7+5+XjN37jm/c87d79lkP3POuffcryICM7OmvaTpAszMwGFkZplwGJlZFhxGZpYFh5GZZcFhZGZZaDSMJK2R9Kik+9osl6TPSRqQtFXSCaVlKyTtSI8V3avazOrQ9JHRWuCsEZa/BTg6PVYC/wYgaTbwMeBk4CTgY5Jm1VqpmdWq0TCKiDuBvSOsshS4LgqbgAMlzQPeDNwSEXsj4nHgFkYONTPL3LSmCxjFocBPS/O70li78ReRtJLiqIqXv/zlJx5zzDH1VGpmbN68+RcRMXc82+YeRhMWEauB1QCtViv6+/sbrshs8pL04/Fu2/Q1o9HsBg4vzR+WxtqNm1mPyj2M+oBz07tqpwBPRsQjwAbgTEmz0oXrM9OYmfWoRk/TJK0D3gDMkbSL4h2ylwJExBeBm4C3AgPAU8C707K9kj4B3Jte6tKIGOlCuJllrtEwiohzRlkewIfaLFsDrKmjLjPrvtxP08xsinAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFppub32WpAdS++pVwyy/QtKW9PiRpCdKy/aXlvV1t3Iz67TGvgNb0gHAF4A3UTRhvFdSX0RsH1wnIj5cWv8fgMWll/htRCzqVr1mVq8mj4xOAgYiYmdEPAPcQNHOup1zgHVdqczMuq7JMBpLi+ojgAXAxtLwTEn9kjZJelt9ZZpZN/RKe+vlwPqI2F8aOyIidktaCGyUtC0iHhy6oaSVwEqA+fPnd6daMxuzJo+MxtKiejlDTtEiYnd63gncwQuvJ5XXWx0RrYhozZ07d6I1m1lNmgyje4GjJS2QNJ0icF70rpikY4BZwF2lsVmSZqTpOcBpwPah25pZ72jsNC0i9km6ANgAHACsiYj7JV0K9EfEYDAtB25I3WUHHQtcLek5ikC9rPwunJn1Hr3wb3xya7Va0d/f33QZZpOWpM0R0RrPtv4EtpllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYaDSNJZ0l6QNKApFXDLD9P0h5JW9LjvaVlKyTtSI8V3a3czDqtse4gkg4AvgC8iaKb7L2S+obp8nFjRFwwZNvZwMeAFhDA5rTt410o3cxq0OSR0UnAQETsjIhngBuApRW3fTNwS0TsTQF0C3BWTXWaWRc0GUaHAj8tze9KY0O9Q9JWSeslDXagrbotklZK6pfUv2fPnk7UbWY1yP0C9jeBIyPiTyiOfq4d6wu4vbVZb2gyjHYDh5fmD0tjz4uIxyLi6TR7DXBi1W3NrLc0GUb3AkdLWiBpOkUb677yCpLmlWaXAD9I0xuAMyXNkjQLODONmVmPauzdtIjYJ+kCihA5AFgTEfdLuhToj4g+4B8lLQH2AXuB89K2eyV9giLQAC6NiL1d3wkz6xhFRNM1dE2r1Yr+/v6myzCbtCRtjojWeLbN/QK2mU0RDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8tC7u2tL5S0PfVNu03SEaVl+0ttr/uGbmtmvSX39tb/C7Qi4ilJHwA+A5ydlv02IhZ1tWgzq03W7a0j4vaIeCrNbqLoj2Zmk1AvtLcedD5wc2l+ZmpbvUnS29pt5PbWZr2hsdO0sZD0LqAF/Hlp+IiI2C1pIbBR0raIeHDothGxGlgNRauirhRsZmOWdXtrAElnAJcAS0qtromI3el5J3AHsLjOYs2sXrm3t14MXE0RRI+WxmdJmpGm5wCnAeUL32bWY3Jvb/1Z4A+Ar0kC+ElELAGOBa6W9BxFoF425F04M+sxbm9tZh3j9tZm1vMcRmaWBYeRmWXBYWRmWXAYmVkWHEZmloVKYSTptCpjZmbjVfXI6PMVx8zMxmXET2BLOhV4HTBX0oWlRa+k+NS0mVlHjHY7yHSK2zGmAa8ojf8SWFZXUWY29YwYRhHxbeDbktZGxI+7VJOZTUFVb5RdK+lFN7FFxOkdrsfMpqiqYfSR0vRM4B3Avs6XY2ZTVaUwiojNQ4a+K+meGuoxsymqUhhJml2afQlwIvCqWioysymp6mnaZiAAUZyePUTxBflmZh1R9TRtQd2FmNnUVvU0bSbwQeBPKY6QvgN8MSJ+V2NtZjaFVL0d5DrgeIpbQK5K09dP9IdXaG89Q9KNafndko4sLbs4jT8g6c0TrcXMmlX1mtEfR8RxpfnbJU3oC/Artrc+H3g8Io6StBy4HDhb0nEU3USOBw4BbpX0hxGxfyI1mVlzqobR9ySdEhGbACSdDEz0m+2fb2+dXnOwvXU5jJYCH0/T64GrVLQJWQrckPqoPSRpIL3eXSP9wJ17fsPZV4+4ipmN03GHvHJC21c9TTsR+B9JD0t6mOKP/rWStknaOs6fXaW99fPrRMQ+4EngoIrbAi9sb/3ss8+Os1Qzq1vVI6Ozaq2iRkPbW9/4vlMbrshs8vr4BLatGkafjIi/Kw9Iun7o2BhVaW89uM4uSdMoPmj5WMVtzayHVD1NO748k4LhxAn+7FHbW6f5FWl6GbAxiq6TfcDy9G7bAuBowLenmPWw0b5c7WLgn4GXSfolxSewAZ4hnfqMV8X21l8Crk8XqPdSBBZpva9SXOzeB3zI76SZ9bZK7a0lfToiLu5CPbVye2uzek2kvXXVa0Y3S3r90MGIuHM8P9TMbKiqYfTR0vRMis/0bAb85Wpm1hFVb5T9q/K8pMOBK2upyMympPE2cdwFHNvJQsxsaqt61/7nKe7WhyLAFgPfq6soM5t6ql4z2s7v+6Q9AayLiO/WU5KZTUWjfc5oGvAvwHuAn6Th+cAaSfdEhG/2MrOOGO2a0WeB2cCCiDghIk4AFgIHAv9ad3FmNnWMFkZ/Cfx9RPxqcCAifgl8AHhrnYWZ2dQyWhhFDPMR7XTrxegf3TYzq2i0MNou6dyhg5LeBfywnpLMbCoa7d20DwH/Kek9FJ+4BmgBLwPeXmdhZja1jBhGEbEbOFnS6fz+a0Ruiojbaq/MzKaUqreDbAQ21lyLmU1h470dxMysoxxGZpYFh5GZZcFhZGZZaCSMJM2WdIukHel51jDrLJJ0l6T7JW2VdHZp2VpJD0nakh6LursHZtZpTR0ZrQJui4ijgdvS/FBPAedGxPEUfduulHRgaflHI2JRemypv2Qzq1NTYbQUuDZNXwu8begKEfGjiNiRpv8PeBSY27UKzayrmgqjgyPikTT9M+DgkVaWdBIwHXiwNPypdPp2haQZI2z7fHvrPXv2TLhwM6tHbWEk6VZJ9w3zWFpeL92I2/amW0nzgOuBd0fEc2n4YuAY4LUUX3FyUbvtI2J1RLQiojV3rg+szHJV9Zsexywizmi3TNLPJc2LiEdS2DzaZr1XAt8CLomITaXXHjyqelrSl4GPdLB0M2tAU6dp5bbVK4BvDF0htbz+OnBdRKwfsmxeehbF9ab7aq3WzGrXVBhdBrxJ0g7gjDSPpJaka9I67wReD5w3zFv4X5G0DdgGzAE+2d3yzazTKrW3nizc3tqsXhNpb+1PYJtZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWhWzbW6f19pe+/7qvNL5A0t2SBiTdmL6838x6WM7trQF+W2phvaQ0fjlwRUQcBTwOnF9vuWZWt2zbW7eT2hOdDgy2LxrT9maWp9zbW89Mrak3SRoMnIOAJyJiX5rfBRza7ge5vbVZb6ito6ykW4FXD7PokvJMRISkdv2SjoiI3ZIWAhtTr7Qnx1JHRKwGVkPRqmgs25pZ92Td3joidqfnnZLuABYD/wEcKGlaOjo6DNjd8R0ws67Kub31LEkz0vQc4DRgexRdJ28Hlo20vZn1lpzbWx8L9Ev6PkX4XBYR29Oyi4ALJQ1QXEP6UlerN7OOc3trM+sYt7c2s57nMDKzLDiMzCwLDiMzy4LDyMyy4DAysyw4jMwsCw4jM8uCw8jMsuAwMrMsOIzMLAsOIzPLgsPIzLLgMDKzLDiMzCwLDiMzy4LDyMyy4DAysyxk295a0l+UWltvkfS7wd5pktZKeqi0bFH398LMOinb9tYRcftga2uKDrJPAf9dWuWjpdbXW7pStZnVplfaWy8Dbo6Ip2qtyswak3t760HLgXVDxj4laaukKwb7q5lZ78q9vTWp4+xrgA2l4YspQmw6Revqi4BL22y/ElgJMH/+/DHsgZl1U9btrZN3Al+PiGdLrz14VPW0pC8DHxmhjtUUgUWr1Zo6TeLMeky27a1LzmHIKVoKMCSJ4nrTfTXUaGZdlHN7ayQdCRwOfHvI9l+RtA3YBswBPtmFms2sRrWdpo0kIh4D3jjMeD/w3tL8w8Chw6x3ep31mVn3+RPYZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYmVkWHEZmlgWHkZllwWFkZllwGJlZFhxGZpYFh5GZZcFhZGZZcBiZWRYcRmaWBYeRmWXBYWRmWWgkjCT9jaT7JT0nqTXCemdJekDSgKRVpfEFku5O4zdKmt6dys2sLk0dGd0H/DVwZ7sVJB0AfAF4C3AccI6k49Liy4ErIuIo4HHg/HrLNbO6NRJGEfGDiHhglNVOAgYiYmdEPAPcACxNvdJOB9an9a6l6J1mZj2skVZFFR0K/LQ0vws4GTgIeCIi9pXGX9TOaFC5vTVFB9rJ2PBxDvCLpouoyWTdt8m6X3803g1rCyNJtwKvHmbRJRExUgfZjiq3t5bUHxFtr1H1qsm6XzB5920y79d4t60tjCLijAm+xG6KbrKDDktjjwEHSpqWjo4Gx82sh+X81v69wNHpnbPpwHKgLyICuB1YltZbAXTtSMvM6tHUW/tvl7QLOBX4lqQNafwQSTcBpKOeC4ANwA+Ar0bE/eklLgIulDRAcQ3pSxV/9OoO7kZOJut+weTdN+/XECoONMzMmpXzaZqZTSEOIzPLwqQOo4nedpIrSbMl3SJpR3qe1Wa9/ZK2pEdft+usarTfv6QZ6bafgXQb0JHdr3J8KuzbeZL2lP6d3ttEnWMhaY2kR9t9Zk+Fz6V93irphEovHBGT9gEcS/EhrDuAVpt1DgAeBBYC04HvA8c1Xfso+/UZYFWaXgVc3ma9Xzdda4V9GfX3D3wQ+GKaXg7c2HTdHdy384Crmq51jPv1euAE4L42y98K3AwIOAW4u8rrTuojo5jAbSf1VzchSylug4Hevx2myu+/vL/rgTem24Jy14v/t0YVEXcCe0dYZSlwXRQ2UXwucN5orzupw6ii4W47aXt7SSYOjohH0vTPgIPbrDdTUr+kTZJyDawqv//n14niIx9PUnykI3dV/2+9I53OrJd0+DDLe824/qZyvjetklxuO+m0kfarPBMRIand5zOOiIjdkhYCGyVti4gHO12rTcg3gXUR8bSk91EcAZ7ecE2N6PkwivpuO2nUSPsl6eeS5kXEI+nw99E2r7E7Pe+UdAewmOIaRk6q/P4H19klaRrwKorbgnI36r5FRHk/rqG4HtjrxvU35dO0NredNFzTaPooboOBNrfDSJolaUaangOcBmzvWoXVVfn9l/d3GbAx0pXSzI26b0OupSyhuNug1/UB56Z31U4BnixdVmiv6SvzNV/1fzvF+erTwM+BDWn8EOCmIVf/f0Rx1HBJ03VX2K+DgNuAHcCtwOw03gKuSdOvA7ZRvIOzDTi/6bpH2J8X/f6BS4ElaXom8DVgALgHWNh0zR3ct08D96d/p9uBY5quucI+rQMeAZ5Nf1/nA+8H3p+Wi+KLER9M//eGfSd76MO3g5hZFnyaZmZZcBiZWRYcRmaWBYeRmWXBYWRmWXAYWVdJ+nUNr3mkpL/t9OtadzmMbDI4EnAY9TiHkTVC0hsk3ZFuDv2hpK8M3okv6WFJn5G0TdI9ko5K42slLSu9xuBR1mXAn6XvA/pw9/fGOsFhZE1aDPwTRfvyhRS3rAx6MiJeA1wFXDnK66wCvhMRiyLiiloqtdo5jKxJ90TEroh4DthCcbo1aF3p+dRuF2bd5zCyJj1dmt7PC79FIoaZ3kf6PyvpJRTfnmiThMPIcnV26fmuNP0wcGKaXgK8NE3/CnhF1yqzWvT89xnZpDVL0laKo6dz0ti/A9+Q9H3gv4DfpPGtwP40vtbXjXqT79q37Eh6mOJrJ37RdC3WPT5NM7Ms+MjIzLLgIyMzy4LDyMyy4DAysyw4jMwsCw4jM8vC/wOu8sqPbOGqhwAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now let's make a deep network with 3 hidden layers. It will have d_i=4 inputs, d_1=5 neurons in the first layer, d_2=2 neurons in the second layer and d_3=4 neurons in the third layer, and d_o = 1 output. Consults figure 4.6 for guidance." + ], + "metadata": { + "id": "0VANqxH2kyS4" + } + }, + { + "cell_type": "code", + "source": [ + "# define sizes\n", + "D_i=4; D_1=5; D_2=2; D_3=1; D_o=1\n", + "# We'll choose the inputs and parameters of this network randomly using np.random.normal\n", + "# For example, we'll set the input using\n", + "n_data = 10;\n", + "x = np.random.normal(size=(D_i, n_data))\n", + "# TODO initialize the parameters randomly but with the correct sizes\n", + "# Replace the lines below\n", + "beta_0 = np.random.normal(size=(1,1))\n", + "Omega_0 = np.random.normal(size=(1,1))\n", + "beta_1 = np.random.normal(size=(1,1))\n", + "Omega_1 = np.random.normal(size=(1,1))\n", + "beta_2 = np.random.normal(size=(1,1))\n", + "Omega_2 = np.random.normal(size=(1,1))\n", + "beta_3 = np.random.normal(size=(1,1))\n", + "Omega_3 = np.random.normal(size=(1,1))\n", + "\n", + "# If you set the above sizes to the correct values then, the following code will run \n", + "h1 = ReLU(np.matmul(beta_0,np.ones((1,n_data))) + np.matmul(Omega_0,x));\n", + "h2 = ReLU(np.matmul(beta_1,np.ones((1,n_data))) + np.matmul(Omega_1,h1));\n", + "h3 = ReLU(np.matmul(beta_2,np.ones((1,n_data))) + np.matmul(Omega_2,h2));\n", + "y = np.matmul(beta_3,np.ones((1,n_data))) + np.matmul(Omega_3,h3)\n", + "\n", + "if h1.shape[0] is not D_1 or h1.shape[1] is not n_data:\n", + " print(\"h1 is wrong shape\")\n", + "if h2.shape[0] is not D_2 or h1.shape[1] is not n_data:\n", + " print(\"h2 is wrong shape\")\n", + "if h3.shape[0] is not D_3 or h1.shape[1] is not n_data:\n", + " print(\"h3 is wrong shape\")\n", + "if y.shape[0] is not D_o or h1.shape[1] is not n_data:\n", + " print(\"Output is wrong shape\")\n", + "\n", + "# Print the inputs and outputs\n", + "print(x)\n", + "print(y)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 240 + }, + "id": "RdBVAc_Rj22-", + "outputId": "28e1c6c3-2d77-4df9-f887-d3e85535d521" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "error", + "ename": "ValueError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;31m# If you set these to the correct values then, the following code will run\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mh1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mReLU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta_0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mOmega_0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0mh2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mReLU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mOmega_1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mh1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mh3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mReLU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta_2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mOmega_2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mh2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 4 is different from 1)" + ] + } + ] + } + ] +} \ No newline at end of file