Created using Colaboratory
This commit is contained in:
419
Notebooks/Chap15/15_1_GAN_Toy_Example.ipynb
Normal file
419
Notebooks/Chap15/15_1_GAN_Toy_Example.ipynb
Normal file
@@ -0,0 +1,419 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"authorship_tag": "ABX9TyM0StKV3FIZ3MZqfflqC0Rv",
|
||||
"include_colab_link": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap15/15_1_GAN_Toy_Example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# **Notebook 15.1: GAN Toy example**\n",
|
||||
"\n",
|
||||
"This notebook investigates the GAN toy example as illustred in figure 15.1 in the book.\n",
|
||||
"\n",
|
||||
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
|
||||
"\n",
|
||||
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "t9vk9Elugvmi"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OLComQyvCIJ7"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Get a batch of real data. Our goal is to make data that looks like this.\n",
|
||||
"def get_real_data_batch(n_sample):\n",
|
||||
" np.random.seed(0)\n",
|
||||
" x_true = np.random.normal(size=(1,n_sample)) + 7.5\n",
|
||||
" return x_true"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "y_OkVWmam4Qx"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Define our generator. This takes a standard normally-distributed latent variable $z$ and adds a scalar $\\theta$ to this, where $\\theta$ is the single parameter of this generative model according to:\n",
|
||||
"\n",
|
||||
"\\begin{equation}\n",
|
||||
"x_i = z_i + \\theta.\n",
|
||||
"\\end{equation}\n",
|
||||
"\n",
|
||||
"Obviously this model can generate the family of Gaussian distributions with unit variance, but different means."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "RFpL0uCXoTpV"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# This is our generator -- takes the single parameter theta\n",
|
||||
"# of the generative model and generates n samples\n",
|
||||
"def generator(z, theta):\n",
|
||||
" x_gen = z + theta\n",
|
||||
" return x_gen"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "OtLQvf3Enfyw"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Now, we define our disriminator. This is a simple logistic regression model (1D linear model passed through sigmoid) that returns the probability that the data is real"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Xrzd8aehYAYR"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Define our discriminative model\n",
|
||||
"\n",
|
||||
"# Logistic sigmoid, maps from [-infty,infty] to [0,1]\n",
|
||||
"def sig(data_in):\n",
|
||||
" return 1.0 / (1.0+np.exp(-data_in))\n",
|
||||
"\n",
|
||||
"# Discriminator computes y\n",
|
||||
"def discriminator(x, phi0, phi1):\n",
|
||||
" return sig(phi0 + phi1 * x)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "vHBgAFZMsnaC"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Draws a figure like Figure 15.1a\n",
|
||||
"def draw_data_model(x_real, x_syn, phi0=None, phi1=None):\n",
|
||||
" fix, ax = plt.subplots();\n",
|
||||
"\n",
|
||||
" for x in x_syn:\n",
|
||||
" ax.plot([x,x],[0,0.33],color='#f47a60')\n",
|
||||
" for x in x_real:\n",
|
||||
" ax.plot([x,x],[0,0.33],color='#7fe7dc')\n",
|
||||
"\n",
|
||||
" if phi0 is not None:\n",
|
||||
" x_model = np.arange(0,10,0.01)\n",
|
||||
" y_model = discriminator(x_model, phi0, phi1)\n",
|
||||
" ax.plot(x_model, y_model,color='#dddddd')\n",
|
||||
" ax.set_xlim([0,10])\n",
|
||||
" ax.set_ylim([0,1])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" plt.show()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "V1FiDBhepcQJ"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Get data batch\n",
|
||||
"x_real = get_real_data_batch(10)\n",
|
||||
"\n",
|
||||
"# Initialize generator and synthesize a batch of examples\n",
|
||||
"theta = 3.0\n",
|
||||
"np.random.seed(1)\n",
|
||||
"z = np.random.normal(size=(1,10))\n",
|
||||
"x_syn = generator(z, theta)\n",
|
||||
"\n",
|
||||
"# Initialize discriminator model\n",
|
||||
"phi0 = -2\n",
|
||||
"phi1 = 1\n",
|
||||
"\n",
|
||||
"draw_data_model(x_real, x_syn, phi0, phi1)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "U8pFb497x36n"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You can see that the synthesized (orange) samples don't look much like the real (cyan) ones, and the initial model to discriminate them (gray line represents probability of being real) is pretty bad as well.\n",
|
||||
"\n",
|
||||
"Let's deal with the discriminator first. Let's define the loss"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "SNDV1G5PYhcQ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Discriminator loss\n",
|
||||
"def compute_discriminator_loss(x_real, x_syn, phi0, phi1):\n",
|
||||
"\n",
|
||||
" # TODO -- compute the loss for the discriminator\n",
|
||||
" # Run the real data and the synthetic data through the discriminator\n",
|
||||
" # Then use the standard binary cross entropy loss with the y=1 for the real samples\n",
|
||||
" # and y=0 for the synthesized ones.\n",
|
||||
" # Replace this line\n",
|
||||
" loss = 0.0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" return loss"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Bc3VwCabYcfg"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Test the loss\n",
|
||||
"loss = compute_discriminator_loss(x_real, x_syn, phi0, phi1)\n",
|
||||
"print(\"True Loss = 13.814757170851447, Your loss=\", loss )"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "MiqM3GXSbn0z"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Gradient of loss (cheating, using finite differences)\n",
|
||||
"def compute_discriminator_gradient(x_real, x_syn, phi0, phi1):\n",
|
||||
" delta = 0.0001;\n",
|
||||
" loss1 = compute_discriminator_loss(x_real, x_syn, phi0, phi1)\n",
|
||||
" loss2 = compute_discriminator_loss(x_real, x_syn, phi0+delta, phi1)\n",
|
||||
" loss3 = compute_discriminator_loss(x_real, x_syn, phi0, phi1+delta)\n",
|
||||
" dl_dphi0 = (loss2-loss1) / delta\n",
|
||||
" dl_dphi1 = (loss3-loss1) / delta\n",
|
||||
"\n",
|
||||
" return dl_dphi0, dl_dphi1\n",
|
||||
"\n",
|
||||
"# This routine performs gradient descent with the discriminator\n",
|
||||
"def update_discriminator(x_real, x_syn, n_iter, phi0, phi1):\n",
|
||||
"\n",
|
||||
" # Define learning rate\n",
|
||||
" alpha = 0.01\n",
|
||||
"\n",
|
||||
" # Get derivatives\n",
|
||||
" print(\"Initial discriminator loss = \", compute_discriminator_loss(x_real, x_syn, phi0, phi1))\n",
|
||||
" for iter in range(n_iter):\n",
|
||||
" # Get gradient\n",
|
||||
" dl_dphi0, dl_dphi1 = compute_discriminator_gradient(x_real, x_syn, phi0, phi1)\n",
|
||||
" # Take a gradient step downhill\n",
|
||||
" phi0 = phi0 - alpha * dl_dphi0 ;\n",
|
||||
" phi1 = phi1 - alpha * dl_dphi1 ;\n",
|
||||
"\n",
|
||||
" print(\"Final Discriminator Loss= \", compute_discriminator_loss(x_real, x_syn, phi0, phi1))\n",
|
||||
"\n",
|
||||
" return phi0, phi1"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "zAxUPo3p0CIW"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Let's update the discriminator (sigmoid curve)\n",
|
||||
"n_iter = 100\n",
|
||||
"print(\"Initial parameters (phi0,phi1)\", phi0, phi1)\n",
|
||||
"phi0, phi1 = update_discriminator(x_real, x_syn, n_iter, phi0, phi1)\n",
|
||||
"print(\"Final parameters (phi0,phi1)\", phi0, phi1)\n",
|
||||
"draw_data_model(x_real, x_syn, phi0, phi1)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FE_DeweeAbMc"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Now let's update the generator"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "pRv9myh0d3Xm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def compute_generator_loss(z, theta, phi0, phi1):\n",
|
||||
" # TODO -- Run the generator on the latent variables z with the parameters theta\n",
|
||||
" # to generate new data x_syn\n",
|
||||
" # Then run the discriminator on the new data to get the probability of being real\n",
|
||||
" # The loss is the total negative log probability of being synthesized (i.e. of not being real)\n",
|
||||
" # Replace this code\n",
|
||||
" loss = 1\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" return loss"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "5uiLrFBvJFAr"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Test generator loss to check you have it correct\n",
|
||||
"loss = compute_generator_loss(z, theta, -2, 1)\n",
|
||||
"print(\"True Loss = 13.78437035945412, Your loss=\", loss )"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "cqnU3dGPd6NK"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def compute_generator_gradient(z, theta, phi0, phi1):\n",
|
||||
" delta = 0.0001\n",
|
||||
" loss1 = compute_generator_loss(z,theta, phi0, phi1) ;\n",
|
||||
" loss2 = compute_generator_loss(z,theta+delta, phi0, phi1) ;\n",
|
||||
" dl_dtheta = (loss2-loss1)/ delta\n",
|
||||
" return dl_dtheta\n",
|
||||
"\n",
|
||||
"def update_generator(z, theta, n_iter, phi0, phi1):\n",
|
||||
" # Define learning rate\n",
|
||||
" alpha = 0.02\n",
|
||||
"\n",
|
||||
" # Get derivatives\n",
|
||||
" print(\"Initial generator loss = \", compute_generator_loss(z, theta, phi0, phi1))\n",
|
||||
" for iter in range(n_iter):\n",
|
||||
" # Get gradient\n",
|
||||
" dl_dtheta = compute_generator_gradient(x_real, x_syn, phi0, phi1)\n",
|
||||
" # Take a gradient step (uphill, since we are trying to make synthesized data less well classified by discriminator)\n",
|
||||
" theta = theta + alpha * dl_dtheta ;\n",
|
||||
"\n",
|
||||
" print(\"Final generator loss = \", compute_generator_loss(z, theta, phi0, phi1))\n",
|
||||
" return theta\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "P1Lqy922dqal"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"n_iter = 10\n",
|
||||
"theta = 3.0\n",
|
||||
"print(\"Theta before\", theta)\n",
|
||||
"theta = update_generator(z, theta, n_iter, phi0, phi1)\n",
|
||||
"print(\"Theta after\", theta)\n",
|
||||
"\n",
|
||||
"x_syn = generator(z,theta)\n",
|
||||
"draw_data_model(x_real, x_syn, phi0, phi1)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Q6kUkMO1P8V0"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Now let's define a full GAN loop\n",
|
||||
"\n",
|
||||
"# Initialize the parameters\n",
|
||||
"theta = 3\n",
|
||||
"phi0 = -2\n",
|
||||
"phi1 = 1\n",
|
||||
"\n",
|
||||
"# Number of iterations for updating generator and discriminator\n",
|
||||
"n_iter_discrim = 300\n",
|
||||
"n_iter_gen = 3\n",
|
||||
"\n",
|
||||
"print(\"Final parameters (phi0,phi1)\", phi0, phi1)\n",
|
||||
"for c_gan_iter in range(5):\n",
|
||||
"\n",
|
||||
" # Run generator to product syntehsized data\n",
|
||||
" x_syn = generator(z, theta)\n",
|
||||
" draw_data_model(x_real, x_syn, phi0, phi1)\n",
|
||||
"\n",
|
||||
" # Update the discriminator\n",
|
||||
" print(\"Updating discriminator\")\n",
|
||||
" phi0, phi1 = update_discriminator(x_real, x_syn, n_iter_discrim, phi0, phi1)\n",
|
||||
" draw_data_model(x_real, x_syn, phi0, phi1)\n",
|
||||
"\n",
|
||||
" # Update the generator\n",
|
||||
" print(\"Updating generator\")\n",
|
||||
" theta = update_generator(z, theta, n_iter_gen, phi0, phi1)\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "pcbdK2agTO-y"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"You can see that the synthesized data (orange) is becoming closer to the true data (cyan). However, this is extremely unstable -- as you will find if you mess around with the number of iterations of each optimization and the total iterations overall."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "loMx0TQUgBs7"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user