{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# **Notebook 6.1: Line search**\n", "\n", "This notebook investigates how to find the minimum of a 1D function using line search, as described in Figure 6.10.\n", "\n", "Work through the cells below, running each cell in turn. In various places you will see the word \"TODO\". Follow the instructions at these places and either make predictions about what is going to happen or write code to complete the functions.\n", "\n", "Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n" ], "metadata": { "id": "el8l05WQEO46" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xhmIOLiZELV_" }, "outputs": [], "source": [ "# Import libraries\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "source": [ "# Let's create a simple 1D function: a constant minus two Gaussian-shaped dips\n", "def loss_function(phi):\n", "    return 1 - 0.5 * np.exp(-(phi-0.65)*(phi-0.65)/0.1) - 0.45 * np.exp(-(phi-0.35)*(phi-0.35)/0.02)\n", "\n", "def draw_function(loss_function, a=None, b=None, c=None, d=None):\n", "    # Plot the function over the range [0,1)\n", "    phi_plot = np.arange(0,1,0.01)\n", "    fig,ax = plt.subplots()\n", "    ax.plot(phi_plot,loss_function(phi_plot),'r-')\n", "    ax.set_xlim(0,1); ax.set_ylim(0,1)\n", "    ax.set_xlabel(r'$\\phi$'); ax.set_ylabel(r'$L[\\phi]$')\n", "    # If all four search points are given, shade the bracket [a,d] and mark each point\n", "    if a is not None and b is not None and c is not None and d is not None:\n", "        ax.axvspan(a, d, facecolor='k', alpha=0.2)\n", "        ax.plot([a,a],[0,1],'b-')\n", "        ax.plot([b,b],[0,1],'b-')\n", "        ax.plot([c,c],[0,1],'b-')\n", "        ax.plot([d,d],[0,1],'b-')\n", "    plt.show()\n" ], "metadata": { "id": "qFRe9POHF2le" }, "execution_count": 
null, "outputs": [] }, { "cell_type": "code", "source": [ "# Draw this function\n", "draw_function(loss_function)" ], "metadata": { "id": "TXx1Tpd1Tl-I" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's create a line search procedure to find the minimum in the range [0,1]." ], "metadata": { "id": "QU5mdGvpTtEG" } }, { "cell_type": "code", "source": [ "def line_search(loss_function, thresh=.0001, max_iter = 10, draw_flag = False):\n", "\n", "    # Initialize four points along the range we are going to search\n", "    a = 0\n", "    b = 0.33\n", "    c = 0.66\n", "    d = 1.0\n", "    n_iter = 0\n", "\n", "    # While we haven't found the minimum closely enough\n", "    while np.abs(b-c) > thresh and n_iter < max_iter:\n", "        # Increment iteration counter (just to prevent an infinite loop)\n", "        n_iter = n_iter+1\n", "\n", "        # Calculate the loss at all four points\n", "        lossa = loss_function(a)\n", "        lossb = loss_function(b)\n", "        lossc = loss_function(c)\n", "        lossd = loss_function(d)\n", "\n", "        if draw_flag:\n", "            draw_function(loss_function, a,b,c,d)\n", "\n", "        print('Iter %d, a=%3.3f, b=%3.3f, c=%3.3f, d=%3.3f'%(n_iter, a,b,c,d))\n", "\n", "        # Rule #1 If the HEIGHT at point A is less than the HEIGHT at points B, C, and D, then move B, C, and D\n", "        # so that they are half as far from A as they started,\n", "        # i.e. bring them closer to the original point\n", "        # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", "        if (0):\n", "            continue\n", "\n", "\n", "        # Rule #2 If the HEIGHT at point b is less than the HEIGHT at point c then\n", "        #    point d becomes point c, and\n", "        #    point b becomes 1/3 of the way between a and the new d\n", "        #    point c becomes 2/3 of the way between a and the new d\n", "        # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", "        if (0):\n", "            continue\n", "\n", "        # Rule #3 If the HEIGHT at point c is less than the HEIGHT at point b then\n", "        #    point a becomes point b, and\n", "        #    point b becomes 1/3 of the way between the new a and d\n", "        #    point c becomes 2/3 of the way between the new a and d\n", "        # TODO REPLACE THE BLOCK OF CODE BELOW WITH THIS RULE\n", "        if (0):\n", "            continue\n", "\n", "    # TODO -- FINAL SOLUTION IS AVERAGE OF B and C\n", "    # REPLACE THIS LINE\n", "    soln = 1\n", "\n", "\n", "    return soln" ], "metadata": { "id": "K-NTHpAAHlCl" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "soln = line_search(loss_function, draw_flag=True)\n", "print('Soln = %3.3f, loss = %3.3f'%(soln,loss_function(soln)))" ], "metadata": { "id": "YVq6rmaWRD2M" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "tOLd0gtdRLLS" }, "execution_count": null, "outputs": [] } ] }