Created using Colaboratory

udlbook
2023-10-24 13:10:47 +01:00
parent 0af1ea979c
commit 33a94fb59f

@@ -4,7 +4,7 @@
   "metadata": {
     "colab": {
       "provenance": [],
-      "authorship_tag": "ABX9TyPX88BLalmJTle9GSAZMJcz",
+      "authorship_tag": "ABX9TyPDNteVt0SjrR97OiS386NZ",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -306,7 +306,8 @@
     "source": [
       "# Return the negative log likelihood of the data under the model\n",
       "def compute_negative_log_likelihood(y_train, mu, sigma):\n",
-      "  # TODO -- compute the likelihood of the data -- don't use the likelihood function above -- compute the negative sum of the log probabilities\n",
+      "  # TODO -- compute the negative log likelihood of the data without using a product\n",
+      "  # In other words, compute minus one times the sum of the log probabilities\n",
       "  # Equation 5.4 in the notes\n",
       "  # You will need np.sum(), np.log()\n",
       "  # Replace the line below\n",
@@ -352,7 +353,7 @@
   {
     "cell_type": "code",
     "source": [
-      "# Return the squared distance between the predicted\n",
+      "# Return the squared distance between the observed data (y_train) and the prediction of the model (y_pred)\n",
       "def compute_sum_of_squares(y_train, y_pred):\n",
       "  # TODO -- compute the sum of squared distances between the training data and the model prediction\n",
       "  # Eqn 5.10 in the notes. Make sure that you understand this, and ask questions if you don't\n",
@@ -372,9 +373,9 @@
     "source": [
       "# Let's test this again\n",
       "beta_0, omega_0, beta_1, omega_1 = get_parameters()\n",
-      "# Use our neural network to predict the mean of the Gaussian\n",
-      "y_pred = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
-      "# Compute the log likelihood\n",
+      "# Use our neural network to predict the mean of the Gaussian, which is our best prediction of y\n",
+      "y_pred = mu_pred = shallow_nn(x_train, beta_0, omega_0, beta_1, omega_1)\n",
+      "# Compute the sum of squares\n",
       "sum_of_squares = compute_sum_of_squares(y_train, y_pred)\n",
       "# Let's double check we get the right answer before proceeding\n",
       "print(\"Correct answer = %9.9f, Your answer = %9.9f\"%(2.020992572,sum_of_squares))"
@@ -554,7 +555,7 @@
   {
     "cell_type": "markdown",
     "source": [
-      "Obviously, to fit the full neural model we would vary all of the 10 parameters of the network in the $\\boldsymbol\\beta_{0},\\boldsymbol\\omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\omega_{1}$ (and maybe $\\sigma$) until we find the combination that have the maximum likelihood / minimum negative log likelihood / least squares.<br><br>\n",
+      "Obviously, to fit the full neural model we would vary all 10 parameters of the network in $\\boldsymbol\\beta_{0},\\boldsymbol\\omega_{0},\\boldsymbol\\beta_{1},\\boldsymbol\\omega_{1}$ (and maybe $\\sigma$) until we find the combination that has the maximum likelihood / minimum negative log likelihood / least squares.<br><br>\n",
       "\n",
       "Here we varied just one parameter at a time, as it is easier to see what is going on. This is known as **coordinate descent**.\n"
     ],
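To make the **coordinate descent** remark concrete, here is an illustrative sketch of the idea: hold every parameter fixed except one, try small moves along that coordinate, keep whichever move lowers the loss, then cycle through the coordinates. The `loss_fn`, step size, and sweep count below are assumptions for illustration, not code from the notebook.

```python
import numpy as np

# Illustrative coordinate descent (not the notebook's implementation).
# loss_fn maps a parameter vector to a scalar; params is the starting point.
def coordinate_descent(loss_fn, params, step=0.1, n_sweeps=50):
    params = np.array(params, dtype=float)
    for _ in range(n_sweeps):
        for i in range(len(params)):        # vary one coordinate at a time
            for delta in (step, -step):     # try a small move each way
                candidate = params.copy()
                candidate[i] += delta
                if loss_fn(candidate) < loss_fn(params):
                    params = candidate      # keep the move if the loss drops
    return params

# Toy usage: both coordinates walk toward the minimum at zero, one axis at a time
print(coordinate_descent(lambda p: np.sum(p**2), [3.0, -2.0]))
```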