Files
udlbook/notebooks/ShallowNN/LinearRegions_Answers.ipynb
2025-11-27 16:25:42 -05:00

287 lines
22 KiB
Plaintext

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMaSR3kNIQB0FP4AHM8xnkB",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/notebooks/ShallowNN/LinearRegions_Answers.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"<img src=\"data:image/svg+xml;base64,<?xml version="1.0" encoding="UTF-8"?>
<svg width="764.45" height="63.759" version="1.1" viewBox="0 0 764.45 63.759" xmlns="http://www.w3.org/2000/svg">
 <g transform="matrix(.73548 0 0 .73548 0 3.388)" stroke-width="1.3597">
  <rect x="5e-7" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="10.330567" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="10.330567" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">I </tspan></text>
  <rect x="96" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="106.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="106.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">C </tspan></text>
  <rect x="180" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="190.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="190.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">L </tspan></text>
  <rect x="264" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="274.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="274.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">I </tspan></text>
  <rect x="348" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="358.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="358.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">M </tspan></text>
  <rect x="432" y="5e-7" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
  <text x="442.33057" y="65.761719" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="442.33057" y="65.761719" fill="#ffffff" font-family="Courier" stroke-width="1.3597">B </tspan></text>
  <g transform="translate(2 1.5376)">
   <rect x="624" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="634.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="634.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">T </tspan></text>
   <rect x="708" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="718.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="718.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">R </tspan></text>
   <rect x="792" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="802.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="802.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">E</tspan></text>
   <rect x="876" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="886.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="886.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">E </tspan></text>
   <rect x="960" y="-1.5376" width="77.387" height="77.387" ry="11.35" fill="#4469d8" opacity=".98" stroke-width="0"/>
   <text x="970.33057" y="64.224167" fill="#ffffff" font-family="Courier" font-size="93.483px" stroke-width="1.3597" style="line-height:1.25" xml:space="preserve"><tspan x="970.33057" y="64.224167" fill="#ffffff" font-family="Courier" stroke-width="1.3597">S </tspan></text>
  </g>
  <g transform="matrix(1.0499 0 0 1.0499 -28.092 -.27293)" fill="#4469d8" stroke="#fdffff">
   <rect x="528" y="-1.5376" width="77.387" height="77.387" ry="11.35" opacity=".98" stroke-width="5.18"/>
   <g transform="matrix(.74592 0 0 .74367 530.84 1.6744)" stroke-width="5.2162" featureKey="inlineSymbolFeature-0">
    <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m47.659 81.427c0.358-7.981 1.333-15.917 1.152-23.917-0.01-0.425-0.544-0.843-0.94-0.54-2.356 1.801-4.811 3.219-7.664 4.104-3.649 1.132-7.703-2.328-5.814-5.981 0.758-1.466 2.146-2.708 3.447-3.672 0.467-0.346 0.358-1.176-0.315-1.165-3.154 0.054-10.835 1.149-10.042-4.386 0.481-3.365 6.29-5.458 8.917-6.84 0.333-0.175 0.435-0.73 0.127-0.981-6.663-5.431-3.069-14.647 5.731-12.788 0.272 0.058 0.563-0.033 0.706-0.287 2.235-3.995 4.276-8.063 7.106-11.688-0.356-0.147-0.712-0.294-1.067-0.442 0.294 3.116 2.036 5.269 4.337 7.272 2.459 2.142 7.634 4.27 8.085 7.845 0.481 3.821-6.549 4.356-6.054 7.588 0.33 2.147 1.354 3.423 3.021 4.74 1.052 0.831 1.968 1.405 3.017 2.329 1.818 2.036 1.596 4.223-0.667 6.561-1.486 0.252-2.927 0.138-4.32-0.341-0.556-0.144-0.945 0.435-0.706 0.918 1.412 2.842 3.23 5.449 3.529 8.707 0.821 8.969-7.237 1.748-8.13 0.875-0.813-0.793-1.6-1.561-2.486-2.27-0.623-0.498-1.514 0.38-0.885 0.884 3.399 2.717 6.507 7.782 11.132 4.42 4.323-3.142-0.524-10.114-2.08-13.246-0.235 0.306-0.471 0.612-0.706 0.918 3.9 1.01 8.231 0.447 7.941-4.452-0.117-1.973-1.259-3.644-2.8-4.778-1.468-1.081-6.729-4.234-3.68-6.41 1.261-0.899 2.453-1.826 3.548-2.929 2.294-2.311 1.726-4.94-0.326-7.105-3.535-3.732-9.97-5.682-10.521-11.525-0.044-0.47-0.692-0.921-1.067-0.442-1.267 1.622-6.265 11.724-7.841 11.391-2.234-0.472-4.485 0.06-6.418 1.186-4.105 2.391-3.919 7.903-1.738 11.448 0.122 0.199 1.517 2.084 1.782 1.944-1.682 0.885-3.351 1.737-4.951 2.768-1.664 1.072-4.177 3.262-3.904 5.54 0.671 5.619 7.144 4.902 11.409 4.829-0.105-0.388-0.21-0.776-0.315-1.165-3.56 2.636-8.58 11.381-0.562 12.174 2.34 0.231 4.247-0.259 6.423-1.142 0.883-0.358 1.698-0.845 2.525-1.311 0.775-0.437 1.976-2.122 2.008-0.692 0.166 7.357-0.865 14.714-1.194 22.056-0.036 0.804 1.214 0.801 1.25-2e-3z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m22.301 83.156c-0.441-6.032-1.072-12.618 0.266-18.564 0.138-0.613-0.578-1.042-1.045-0.608-1.743 1.625-3.443 2.831-5.732 3.604-6.34-3.393-7.913-6.373-4.717-8.939 0.988-0.856 2.034-1.633 3.139-2.329 0.287-0.191 0.397-0.544 0.225-0.855-0.658-1.178-1.392-2.163-2.251-3.191-4.397-5.264-0.382-9.414 4.759-10.875 0.271-0.077 0.455-0.322 0.459-0.603 0.036-2.864 0.313-5.642 1.094-8.407 1.865-6.606 10.255-9.181 13.143-1.487 0.28 0.748 1.489 0.424 1.205-0.332-2.517-6.706-9.574-7.649-13.918-2.003-2.305 2.996-2.61 7.466-2.759 11.084-0.035 0.85-3.839 2.269-4.496 2.694-1.034 0.669-2.219 2.098-2.45 3.312-0.808 4.233 1.103 6.056 3.512 9.323 0.405 0.548-5.327 5.252-5.317 7.279 0.016 3.468 2.455 5.64 5.605 6.645 3.404 1.086 7.127-1.932 9.386-4.037-0.349-0.203-0.697-0.405-1.045-0.608-1.368 6.079-0.762 12.734-0.311 18.896 0.056 0.8 1.306 0.806 1.248 1e-3z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m21.424 64.741c1.983 2.707 4.981 4.199 8.349 3.637 3.594-0.6 5.191-4.13 5.291-7.411 0.024-0.807-1.226-0.804-1.25 0-0.202 6.67-7.523 8.313-11.31 3.143-0.472-0.643-1.557-0.02-1.08 0.631z" fill="#4469d8" stroke="#fdffff" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m74.661 80.878c2.869-5.406 3.251-12.191 2.679-18.182-0.036-0.381-0.375-0.742-0.791-0.603-1.482 0.496-9.677 1.84-5.634-4.557 0.251-0.397-0.075-0.952-0.54-0.94-4.913 0.123-9.233-0.937-9.57-6.683-0.047-0.801-1.297-0.806-1.25 0 0.201 3.426 1.375 5.828 4.622 7.214 1.514 0.646 3.278 0.7 4.894 0.751-0.658-0.021-0.338 3.074-0.216 3.489 0.625 2.13 4.101 2.773 5.896 2.466 2.606-0.446 1.551 3.288 1.477 5.177-0.15 3.833-0.832 7.82-2.646 11.236-0.378 0.713 0.701 1.345 1.079 0.632z" fill="#4469d8" stroke="#fdffff" stroke-width="5.2162"/>
     </g>
     <g fill="#4469d8" stroke="#fdffff" stroke-width="5.2162">
      <path d="m76.881 63.299c3.341-0.618 7.425-1.372 7.423-5.67 0-1.473-0.141-3.462-1.403-4.486 0.524 0.425 2.703-1.287 3.381-1.885 5.097-4.499 1.607-12.585-4.301-13.85-0.222-0.047 2.216-4.5 2.515-5.157 0.832-1.834 0.614-3.634-8e-3 -5.472-1.133-3.347-6.327-9.06-10.153-9.283-1.411-0.082-2.449-0.077-3.515 0.881-1.212 1.09 0.842 3.98-1.963 2.484-4.82-2.573-5.125 2.25-7.856 4.852-0.584 0.557 0.301 1.439 0.885 0.884 1.199-1.143 0.961-0.736 1.574-2.026 2.202-4.641 4.768-2.589 7.178-1.388 0.334 0.167 0.839 0.047 0.918-0.374 0.208-1.098 0.205-1.025 0.186-2.169 2.787-1.84 5.084-1.596 6.891 0.731 0.745 0.715 1.449 1.469 2.113 2.261 4.874 5.507 2.097 8.833-0.535 13.968-0.228 0.445 0.06 0.897 0.54 0.94 8.368 0.749 8.684 11.983 0.698 13.757-0.432 0.096-0.64 0.75-0.276 1.044 4.99 4.046-0.386 7.969-4.622 8.753-0.794 0.147-0.458 1.351 0.33 1.205z" fill="#4469d8" stroke="#fdffff" stroke-linejoin="round" stroke-width="5.2162"/>
     </g>
    </g>
   </g>
  </g>
 </g>
</svg>
\">"
],
"metadata": {
"id": "aRLqypzfFfJJ"
}
},
{
"cell_type": "markdown",
"source": [
"# **Linear regions answers**\n",
"\n",
"The purpose of this notebook is to compute the maximum possible number of linear regions for shallow networks with different numbers of hidden units and parameters.\n",
"\n",
"You can save a local copy of this notebook in your Google account and work through it in Colab (recommended) or you can download the notebook and run it locally using Jupyter notebook or similar.\n",
"\n",
"Contact me at iclimbtreesmail@gmail.com if you find any mistakes or have any suggestions."
],
"metadata": {
"id": "DCTC8fQ6cp-n"
}
},
{
"cell_type": "code",
"source": [
"# Imports math library\n",
"import numpy as np\n",
"# Imports plotting library\n",
"import matplotlib.pyplot as plt\n",
"# Imports math library\n",
"import math"
],
"metadata": {
"id": "W3C1ZA1gcpq_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The number of regions $N$ created by a shallow neural network with $D_i$ inputs and $D$ hidden units is given by Zaslavsky's formula:\n",
"\n",
"\\begin{equation}N = \\sum_{j=0}^{D_{i}}\\binom{D}{j}=\\sum_{j=0}^{D_{i}} \\frac{D!}{(D-j)!j!} \\end{equation} <br>\n",
"\n"
],
"metadata": {
"id": "TbfanfXBe84L"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4UQ2n0RWcgOb"
},
"outputs": [],
"source": [
"def number_regions(Di, D):\n",
"\n",
" N= 0\n",
" for j in range (Di+1):\n",
" N = N+ math.comb(D, j)\n",
"\n",
" return N"
]
},
{
"cell_type": "code",
"source": [
"# Calculate the number of regions for 2D input (Di=2) and 3 hidden units (D=3) as in figure 3.8j\n",
"N = number_regions(2, 3)\n",
"print(f\"Di=2, D=3, Number of regions = {int(N)}, True value = 7\")"
],
"metadata": {
"id": "AqSUfuJDigN9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Calculate the number of regions for 10D input (Di=2) and 50 hidden units (D=50)\n",
"N = number_regions(10, 50)\n",
"print(f\"Di=10, D=50, Number of regions = {int(N)}, True value = 13432735556\")"
],
"metadata": {
"id": "krNKPV9gjCu-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Calculate the number of regions for 5D input (Di=5) and 50 hidden units (D=50)\n",
"N = number_regions(5, 50)\n",
"print(f\"Di=5, D=40, Number of regions = {int(N)}\")"
],
"metadata": {
"id": "7g4j7OXaSniQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Calculate the number of regions for 10D input (Di=10) and 50 hidden units (D=50)\n",
"N = number_regions(10, 50)\n",
"print(f\"Di=10, D=50, Number of regions = {int(N)}\")"
],
"metadata": {
"id": "OfgVQzHBfgI7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"This works but there may be a complication. If the number of hidden units $D$ is fewer than the number of input dimensions $D_i$ , the formula may fail, depending on how you implemented it (it will work with Math.comb(). In this degenerate case, there are just $2^D$ regions."
],
"metadata": {
"id": "rk1a2LqGkO9u"
}
},
{
"cell_type": "code",
"source": [
"# Test calculation when $D_i < D$\n",
"# If it fails, then rewrite your function number_regions() to cope with this degenerate case\n",
"try:\n",
" D_i = 10\n",
" D = 8\n",
" N = number_regions(D_i, D)\n",
" print(f\"Di={int(D_i)}, D={int(D)}, Number of regions = {int(N)}, True value = 256\")\n",
"except Exception as error:\n",
" print(\"An exception occurred:\", error)\n"
],
"metadata": {
"id": "uq5IeAZTkIMg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Now let's plot the graph relating the number of hidden units to the number of regions\n",
"# for different numbers of input dimensions\n",
"dims = np.array([1,5,10,50,100])\n",
"regions = np.zeros((dims.shape[0], 1000))\n",
"for c_dim in range(dims.shape[0]):\n",
" D_i = dims[c_dim]\n",
" print (f\"Counting regions for {D_i} input dimensions\")\n",
" for D in range(1000):\n",
" regions[c_dim, D] = number_regions(np.min([D_i,D]), D)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.semilogy(regions[0,:],'k-')\n",
"ax.semilogy(regions[1,:],'b-')\n",
"ax.semilogy(regions[2,:],'m-')\n",
"ax.semilogy(regions[3,:],'c-')\n",
"ax.semilogy(regions[4,:],'y-')\n",
"ax.legend(['$D_i$=1', '$D_i$=5', '$D_i$=10', '$D_i$=50', '$D_i$=100'])\n",
"ax.set_xlabel(\"Number of hidden units, D\")\n",
"ax.set_ylabel(\"Number of regions, N\")\n",
"plt.xlim([0,1000])\n",
"plt.ylim([1e1,1e150])\n",
"plt.show()"
],
"metadata": {
"id": "5XnEOp0Bj_QK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Finally, let's compute and plot the number of regions as a function of the number of parameters\n",
"# First let's write a function that computes the number of parameters as a function of the input dimension and number of hidden units (assuming just one output)\n",
"\n",
"def number_parameters(D_i, D):\n",
" N = D_i * D + D + D + 1\n",
"\n",
" return N ;"
],
"metadata": {
"id": "Pav1OsCnpm6P"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Now let's test the code\n",
"N = number_parameters(10, 8)\n",
"print(f\"Di=10, D=8, Number of parameters = {int(N)}, True value = 97\")"
],
"metadata": {
"id": "VbhDmZ1gwkQj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Now let's plot the graph from figure 3.9a (takes ~1min)\n",
"dims = np.array([1,5,10,50,100])\n",
"regions = np.zeros((dims.shape[0], 200))\n",
"params = np.zeros((dims.shape[0], 200))\n",
"\n",
"# We'll compute the five lines separately this time to make it faster\n",
"for c_dim in range(dims.shape[0]):\n",
" D_i = dims[c_dim]\n",
" print (f\"Counting regions for {D_i} input dimensions\")\n",
" for c_hidden in range(1, 200):\n",
" # Iterate over different ranges of number hidden variables for different input sizes\n",
" D = int(c_hidden * 500 / D_i)\n",
" params[c_dim, c_hidden] = D_i * D +D + D +1\n",
" regions[c_dim, c_hidden] = number_regions(np.min([D_i,D]), D)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.semilogy(params[0,:], regions[0,:],'k-')\n",
"ax.semilogy(params[1,:], regions[1,:],'b-')\n",
"ax.semilogy(params[2,:], regions[2,:],'m-')\n",
"ax.semilogy(params[3,:], regions[3,:],'c-')\n",
"ax.semilogy(params[4,:], regions[4,:],'y-')\n",
"ax.legend(['$D_i$=1', '$D_i$=5', '$D_i$=10', '$D_i$=50', '$D_i$=100'])\n",
"ax.set_xlabel(\"Number of parameters, D\")\n",
"ax.set_ylabel(\"Number of regions, N\")\n",
"plt.xlim([0,100000])\n",
"plt.ylim([1e1,1e150])\n",
"plt.show()\n"
],
"metadata": {
"id": "AH4nA50Au8-a"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "daA1dAPsLD0h"
},
"execution_count": null,
"outputs": []
}
]
}