Update CM20315_Training_II.ipynb

This commit is contained in:
Pietro Monticone
2023-11-30 16:44:42 +01:00
parent 685d910bbc
commit 9b13823ca8

View File

@@ -114,7 +114,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Initialize the parmaeters and draw the model\n", "# Initialize the parameters and draw the model\n",
"phi = np.zeros((2,1))\n", "phi = np.zeros((2,1))\n",
"phi[0] = 0.6 # Intercept\n", "phi[0] = 0.6 # Intercept\n",
"phi[1] = -0.2 # Slope\n", "phi[1] = -0.2 # Slope\n",
@@ -314,7 +314,7 @@
" return compute_loss(data[0,:], data[1,:], model, phi_start+ gradient * dist_prop)\n", " return compute_loss(data[0,:], data[1,:], model, phi_start+ gradient * dist_prop)\n",
"\n", "\n",
"def line_search(data, model, phi, gradient, thresh=.00001, max_dist = 0.1, max_iter = 15, verbose=False):\n", "def line_search(data, model, phi, gradient, thresh=.00001, max_dist = 0.1, max_iter = 15, verbose=False):\n",
" # Initialize four points along the rnage we are going to search\n", " # Initialize four points along the range we are going to search\n",
" a = 0\n", " a = 0\n",
" b = 0.33 * max_dist\n", " b = 0.33 * max_dist\n",
" c = 0.66 * max_dist\n", " c = 0.66 * max_dist\n",
@@ -345,7 +345,7 @@
" # Rule #2 If point b is less than point c then\n", " # Rule #2 If point b is less than point c then\n",
" # then point d becomes point c, and\n", " # then point d becomes point c, and\n",
" # point b becomes 1/3 between a and new d\n", " # point b becomes 1/3 between a and new d\n",
" # point c beocome 2/3 between a and new d \n", " # point c becomes 2/3 between a and new d \n",
" if lossb < lossc:\n", " if lossb < lossc:\n",
" d = c\n", " d = c\n",
" b = a+ (d-a)/3\n", " b = a+ (d-a)/3\n",
@@ -355,7 +355,7 @@
" # Rule #2 If point c is less than point b then\n", " # Rule #2 If point c is less than point b then\n",
" # then point a becomes point b, and\n", " # then point a becomes point b, and\n",
" # point b becomes 1/3 between new a and d\n", " # point b becomes 1/3 between new a and d\n",
" # point c beocome 2/3 between new a and d \n", " # point c becomes 2/3 between new a and d \n",
" a = b\n", " a = b\n",
" b = a+ (d-a)/3\n", " b = a+ (d-a)/3\n",
" c = a+ 2*(d-a)/3\n", " c = a+ 2*(d-a)/3\n",