Rename Training_I.ipynb to CM20315/CM20315_Training_I.ipynb

This commit is contained in:
udlbook
2023-07-23 18:41:48 -04:00
committed by GitHub
parent e82701b1fd
commit 80fff5c1ce


@@ -0,0 +1,190 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMUqLjI8VIQXHOYx0I37OmR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/CM20315/CM20315_Training_I.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Training\n",
"\n",
"We now have a model and a loss function which we can use to judge how good that model is. It's time to put the \"learning\" into machine learning.\n",
"\n",
"Learning involves finding the parameters that minimize the loss. That might seem like it's not too hard, but modern models can have billions of parameters; the number of possible parameter combinations is astronomical, so there's no way we can make progress with exhaustive search.\n",
"\n",
"We'll build this up in stages. In this practical, we'll just consider 1D search using a bracketing approach. In part II, we'll extend this to fitting the linear regression model (which has a convex loss function). Then in part III, we'll consider non-convex loss functions.\n"
],
"metadata": {
"id": "el8l05WQEO46"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xhmIOLiZELV_"
},
"outputs": [],
"source": [
"# import libraries\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"source": [
"# Let's create a simple 1D function\n",
"def loss_function(phi):\n",
" return 1 - 0.5 * np.exp(-(phi-0.65)*(phi-0.65)/0.1) - 0.45 * np.exp(-(phi-0.35)*(phi-0.35)/0.02)\n",
"\n",
"def draw_function(loss_function,a=None, b=None, c=None, d=None):\n",
" # Plot the function \n",
" phi_plot = np.arange(0,1,0.01);\n",
" fig,ax = plt.subplots()\n",
" ax.plot(phi_plot,loss_function(phi_plot),'r-')\n",
" ax.set_xlim(0,1); ax.set_ylim(0,1)\n",
" ax.set_xlabel(r'$\\phi$'); ax.set_ylabel(r'$L[\\phi]$')\n",
" if a is not None and b is not None and c is not None and d is not None:\n",
" plt.axvspan(a, d, facecolor='k', alpha=0.2)\n",
" ax.plot([a,a],[0,1],'b-')\n",
" ax.plot([b,b],[0,1],'b-')\n",
" ax.plot([c,c],[0,1],'b-')\n",
" ax.plot([d,d],[0,1],'b-')\n",
" plt.show()\n"
],
"metadata": {
"id": "qFRe9POHF2le"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Draw this function\n",
"draw_function(loss_function)"
],
"metadata": {
"id": "TXx1Tpd1Tl-I"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's create a line search procedure to find the minimum in the range [0, 1]."
],
"metadata": {
"id": "QU5mdGvpTtEG"
}
},
{
"cell_type": "code",
"source": [
"def line_search(loss_function, thresh=.0001, max_iter = 10, draw_flag = False):\n",
"\n",
" # Initialize four points along the range we are going to search\n",
" a = 0\n",
" b = 0.33\n",
" c = 0.66\n",
" d = 1.0\n",
" n_iter = 0\n",
" \n",
" # While we haven't found the minimum closely enough\n",
" while np.abs(b-c) > thresh and n_iter < max_iter:\n",
" # Increment iteration counter (just to prevent an infinite loop)\n",
" n_iter = n_iter+1\n",
" # Calculate all four points\n",
" lossa = loss_function(a)\n",
" lossb = loss_function(b)\n",
" lossc = loss_function(c)\n",
" lossd = loss_function(d)\n",
"\n",
" if draw_flag:\n",
" draw_function(loss_function, a,b,c,d)\n",
"\n",
" print('Iter %d, a=%3.3f, b=%3.3f, c=%3.3f, d=%3.3f'%(n_iter, a,b,c,d))\n",
"\n",
" # Rule #1: If point A is less than each of points B, C, and D, then halve the values of B, C, and D\n",
" # i.e. bring them closer to the original point\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",
" if (0):\n",
" continue;\n",
"\n",
" # Rule #2 If point b is less than point c then\n",
" # then point d becomes point c, and\n",
" # point b becomes 1/3 between a and new d\n",
" # point c becomes 2/3 between a and new d\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",
" if (0):\n",
" continue;\n",
"\n",
" # Rule #3 If point c is less than point b then\n",
" # then point a becomes point b, and\n",
" # point b becomes 1/3 between new a and d\n",
" # point c becomes 2/3 between new a and d\n",
" # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n",
" if(0):\n",
" continue\n",
"\n",
"\n",
" # TODO -- FINAL SOLUTION IS AVERAGE OF B and C\n",
" # REPLACE THIS LINE\n",
" soln = 1\n",
" \n",
" return soln"
],
"metadata": {
"id": "K-NTHpAAHlCl"
},
"execution_count": null,
"outputs": []
},
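{
"cell_type": "markdown",
"source": [
"For reference, here is one possible way the three rules above could be implemented (an illustrative sketch under our own reading of the comments, not necessarily the intended solution): rule 1 shrinks the bracket toward $a$, while rules 2 and 3 discard the outer third of the interval and reposition $b$ and $c$ at 1/3 and 2/3 of the new bracket."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# One possible completed line search (illustrative sketch; assumes np imported above)\n",
"def line_search_solution(loss_function, thresh=.0001, max_iter=10):\n",
"  a, b, c, d = 0.0, 0.33, 0.66, 1.0\n",
"  n_iter = 0\n",
"  # While we haven't found the minimum closely enough\n",
"  while np.abs(b - c) > thresh and n_iter < max_iter:\n",
"    n_iter = n_iter + 1\n",
"    lossa = loss_function(a); lossb = loss_function(b)\n",
"    lossc = loss_function(c); lossd = loss_function(d)\n",
"    # Rule 1: a is the lowest point, so pull b, c, and d halfway toward a\n",
"    if lossa < lossb and lossa < lossc and lossa < lossd:\n",
"      b = b / 2; c = c / 2; d = d / 2\n",
"      continue\n",
"    # Rule 2: minimum lies left of c, so d moves to c\n",
"    if lossb < lossc:\n",
"      d = c\n",
"      b = a + (d - a) / 3; c = a + 2 * (d - a) / 3\n",
"      continue\n",
"    # Rule 3: minimum lies right of b, so a moves to b\n",
"    a = b\n",
"    b = a + (d - a) / 3; c = a + 2 * (d - a) / 3\n",
"  # Final estimate is the midpoint of the inner bracket\n",
"  return (b + c) / 2"
],
"metadata": {},
"execution_count": null,
"outputs": []
},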
{
"cell_type": "code",
"source": [
"soln = line_search(loss_function, draw_flag=True)\n",
"print('Soln = %3.3f, loss = %3.3f'%(soln,loss_function(soln)))"
],
"metadata": {
"id": "YVq6rmaWRD2M"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "tOLd0gtdRLLS"
},
"execution_count": null,
"outputs": []
}
]
}