udlbook/Notebooks/Chap20/20_4_Adversarial_Attacks.ipynb
2023-10-25 10:23:10 +01:00

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyP9amtzXsNWqkmiPUQgxzKV",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap20/20_4_Adversarial_Attacks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Notebook 20.4: Adversarial attacks**\n",
"\n",
"This notebook uses the network for classification of MNIST from Notebook 10.5. The code is adapted from https://nextjournal.com/gkoehler/pytorch-mnist, and uses the fast gradient sign attack of [Goodfellow et al. (2015)](https://arxiv.org/abs/1412.6572). Having trained the network, we search for adversarial examples: inputs that look very similar to examples of class A, but are mistakenly classified as class B. We do this by starting with a correctly classified example and perturbing it according to the gradient of the loss with respect to the input, so that the network's output changes.\n",
"\n",
"Work through the cells below, running each cell in turn. In various places you will see the words \"TO DO\". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.\n",
"\n",
"Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.\n"
],
"metadata": {
"id": "t9vk9Elugvmi"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torchvision\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"import random"
],
"metadata": {
"id": "YrXWAH7sUWvU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run this once to load the train and test data straight into a dataloader class\n",
"# that will provide the batches\n",
"batch_size_train = 64\n",
"batch_size_test = 1000\n",
"train_loader = torch.utils.data.DataLoader(\n",
" torchvision.datasets.MNIST('/files/', train=True, download=True,\n",
" transform=torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize(\n",
" (0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=batch_size_train, shuffle=True)\n",
"\n",
"test_loader = torch.utils.data.DataLoader(\n",
" torchvision.datasets.MNIST('/files/', train=False, download=True,\n",
" transform=torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize(\n",
" (0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=batch_size_test, shuffle=True)"
],
"metadata": {
"id": "wScBGXXFVadm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Let's draw some of the test data\n",
"examples = enumerate(test_loader)\n",
"batch_idx, (example_data, example_targets) = next(examples)\n",
"\n",
"fig = plt.figure()\n",
"for i in range(6):\n",
" plt.subplot(2,3,i+1)\n",
" plt.tight_layout()\n",
" plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
" plt.title(\"Ground Truth: {}\".format(example_targets[i]))\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"plt.show()"
],
"metadata": {
"id": "8bKADvLHbiV5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Define the network. Rather than using the sequential structure, we use the more typical approach of defining a class for the network. The layers (and hence the parameters) are created in the constructor, and the `forward` method specifies how the input passes through them. This format makes it easy to add residual connections."
],
"metadata": {
"id": "_sFvRDGrl4qe"
}
},
{
"cell_type": "code",
"source": [
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
"        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
"        self.drop = nn.Dropout2d()\n",
"        self.fc1 = nn.Linear(320, 50)\n",
"        self.fc2 = nn.Linear(50, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # 1x28x28 -> 10x24x24 -> 10x12x12\n",
"        x = self.conv1(x)\n",
"        x = F.max_pool2d(x,2)\n",
"        x = F.relu(x)\n",
"        # 10x12x12 -> 20x8x8 -> 20x4x4\n",
"        x = self.conv2(x)\n",
"        x = self.drop(x)\n",
"        x = F.max_pool2d(x,2)\n",
"        x = F.relu(x)\n",
"        # Flatten to 20*4*4 = 320 features per example\n",
"        x = x.flatten(1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.fc2(x)\n",
"        # Log probabilities over the 10 classes (used with F.nll_loss)\n",
"        x = F.log_softmax(x, dim=1)\n",
"        return x"
],
"metadata": {
"id": "EQkvw2KOPVl7"
},
"execution_count": null,
"outputs": []
},
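{
"cell_type": "markdown",
"source": [
"As an illustration of the point about residual connections, here is a minimal sketch of a hypothetical `ResidualBlock` module written in the same class-based style. It is not part of the original notebook and is not used below; the block name, the channel count, and the $3\\times 3$ kernels with `padding=1` (which keep the spatial size fixed so that the input can be added to the output) are assumptions chosen for the example.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class ResidualBlock(nn.Module):\n",
"    # Hypothetical block for illustration only; not used elsewhere in this notebook\n",
"    def __init__(self, channels):\n",
"        super().__init__()\n",
"        # padding=1 with a 3x3 kernel keeps the height and width unchanged\n",
"        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)\n",
"        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)\n",
"\n",
"    def forward(self, x):\n",
"        h = F.relu(self.conv1(x))\n",
"        h = self.conv2(h)\n",
"        # Residual (skip) connection: add the unmodified input to the block output\n",
"        return F.relu(x + h)\n",
"\n",
"block = ResidualBlock(10)\n",
"x = torch.randn(2, 10, 12, 12)\n",
"y = block(x)\n",
"print(y.shape)  # torch.Size([2, 10, 12, 12])\n",
"```\n",
"\n",
"The only change relative to a plain stack of layers is the `x + h` in `forward`: the block computes a correction that is added to its input rather than a complete replacement for it.\n"
],
"metadata": {}
},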
{
"cell_type": "code",
"source": [
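"# He initialization of weights (fully connected layers only;\n",
"# the convolutional layers keep PyTorch's default initialization)\n",
"def weights_init(layer_in):\n",
"    if isinstance(layer_in, nn.Linear):\n",
"        nn.init.kaiming_uniform_(layer_in.weight)\n",
"        layer_in.bias.data.fill_(0.0)"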
],
"metadata": {
"id": "qWZtkCZcU_dg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Create network\n",
"model = Net()\n",
"# Initialize model weights\n",
"model.apply(weights_init)\n",
"# Define optimizer\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)"
],
"metadata": {
"id": "FslroPJJffrh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Main training routine\n",
"def train(epoch):\n",
"    model.train()\n",
"    # Loop over each batch of training data\n",
"    for batch_idx, (data, target) in enumerate(train_loader):\n",
"        optimizer.zero_grad()\n",
"        output = model(data)\n",
"        loss = F.nll_loss(output, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        # Print progress every 10 batches\n",
"        if batch_idx % 10 == 0:\n",
"            print('Train Epoch: {} [{}/{}]\\tLoss: {:.6f}'.format(\n",
"                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()))"
],
"metadata": {
"id": "xKQd9PzkQ766"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run on test data\n",
"def test():\n",
"    model.eval()\n",
"    test_loss = 0\n",
"    correct = 0\n",
"    with torch.no_grad():\n",
"        for data, target in test_loader:\n",
"            output = model(data)\n",
"            # Sum the loss over the batch (size_average=False is deprecated)\n",
"            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
"            pred = output.argmax(dim=1, keepdim=True)\n",
"            correct += pred.eq(target.view_as(pred)).sum().item()\n",
"    test_loss /= len(test_loader.dataset)\n",
"    print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
"        test_loss, correct, len(test_loader.dataset),\n",
"        100. * correct / len(test_loader.dataset)))"
],
"metadata": {
"id": "Byn-f7qWRLxX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Get initial performance\n",
"test()\n",
"# Train for three epochs\n",
"n_epochs = 3\n",
"for epoch in range(1, n_epochs + 1):\n",
" train(epoch)\n",
" test()"
],
"metadata": {
"id": "YgLaex1pfhqz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run network on data we got before and show predictions\n",
"output = model(example_data)\n",
"\n",
"fig = plt.figure()\n",
"for i in range(6):\n",
" plt.subplot(2,3,i+1)\n",
" plt.tight_layout()\n",
" plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n",
" plt.title(\"Prediction: {}\".format(\n",
" output.data.max(1, keepdim=True)[1][i].item()))\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"plt.show()"
],
"metadata": {
"id": "o7fRUAy9Se1B"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"This is the code that does the adversarial attack. It is adapted from [here](https://pytorch.org/tutorials/beginner/fgsm_tutorial.html). It is an example of the fast gradient sign method (FGSM), which modifies the data by:\n",
"\n",
"* Calculating the derivative $\\partial L/\\partial \\mathbf{x}$ of the loss $L$ with respect to the input data $\\mathbf{x}$.\n",
"* Finding the sign of the gradient at each point (making a tensor the same size as $\\mathbf{x}$ with a one where the gradient was positive and minus one where it was negative).\n",
"* Multiplying this tensor by $\\epsilon$ and adding it back to the original data.\n"
],
"metadata": {
"id": "EabuoMdP32Hd"
}
},
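{
"cell_type": "markdown",
"source": [
"The three steps above amount to the update $\\mathbf{x}' = \\mathbf{x} + \\epsilon\\,\\mathrm{sign}(\\partial L/\\partial \\mathbf{x})$. Below is a minimal sketch of that update, assuming a small randomly initialized linear classifier in place of the trained MNIST network; the names `toy_model` and `fgsm_step` are hypothetical and the sketch is independent of the `fgsm_attack` function you are asked to write.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical stand-in for the trained network: a linear layer with log-softmax outputs\n",
"toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))\n",
"\n",
"def fgsm_step(model, x, y, epsilon):\n",
"    # Work on a copy of the input that tracks gradients\n",
"    x = x.clone().detach().requires_grad_(True)\n",
"    loss = F.nll_loss(model(x), y)\n",
"    # Gradient of the loss with respect to the input only (not the weights)\n",
"    dLdx = torch.autograd.grad(loss, x)[0]\n",
"    # Move every pixel by +epsilon or -epsilon in the direction that increases the loss\n",
"    return (x + epsilon * torch.sign(dLdx)).detach()\n",
"\n",
"x = torch.randn(1, 1, 28, 28)\n",
"y = torch.tensor([3])\n",
"x_prime = fgsm_step(toy_model, x, y, epsilon=0.1)\n",
"with torch.no_grad():\n",
"    loss_before = F.nll_loss(toy_model(x), y).item()\n",
"    loss_after = F.nll_loss(toy_model(x_prime), y).item()\n",
"print(loss_before, loss_after)\n",
"```\n",
"\n",
"Because the loss of this linear toy model is convex in its input, a single signed step is guaranteed to increase it; for the trained CNN an increase is typical but not guaranteed.\n"
],
"metadata": {}
},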
{
"cell_type": "code",
"source": [
"# FGSM attack code.\n",
"def fgsm_attack(x, epsilon, dLdx):\n",
" # TODO -- write this function\n",
" # Get the sign of the gradient\n",
" # Add epsilon times the sign of the gradient to x\n",
" # Replace this line\n",
" x_modified = torch.zeros_like(x)\n",
"\n",
" # Return the perturbed image\n",
" return x_modified"
],
"metadata": {
"id": "gAX7tnld46q1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"no_examples = 3\n",
"epsilon = 0.5\n",
"for i in range(no_examples):\n",
" # Reset gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # Get the i'th data example\n",
" x = example_data[i,:,:,:]\n",
" # Add an extra dimension back to the beginning\n",
" x = x[None, :, :, :]\n",
" x.requires_grad = True\n",
" # Get the i'th target\n",
" y = torch.ones(1, dtype=torch.long) * example_targets[i]\n",
"\n",
" # Run the model\n",
" output = model(x)\n",
" # Compute the loss\n",
" loss = F.nll_loss(output, y)\n",
" # Back propagate\n",
" loss.backward()\n",
"\n",
" # Collect the gradient of the loss with respect to the input\n",
" dLdx = x.grad.data\n",
"\n",
" # Call FGSM Attack\n",
" x_prime = fgsm_attack(x, epsilon, dLdx)\n",
"\n",
" # Re-classify the perturbed image\n",
" output_prime = model(x_prime)\n",
"\n",
" x = x.detach().numpy()\n",
" fig = plt.figure()\n",
" plt.subplot(1,2,1)\n",
" plt.tight_layout()\n",
" plt.imshow(x[0][0], cmap='gray', interpolation='none')\n",
" plt.title(\"Original Prediction: {}\".format(\n",
" output.data.max(1, keepdim=True)[1][0].item()))\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
" plt.subplot(1,2,2)\n",
" plt.tight_layout()\n",
" plt.imshow(x_prime[0][0].detach().numpy(), cmap='gray', interpolation='none')\n",
" plt.title(\"Perturbed Prediction: {}\".format(\n",
" output_prime.data.max(1, keepdim=True)[1][0].item()))\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
"plt.show()"
],
"metadata": {
"id": "AuNTYWboufbm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Although we have only added a small amount of noise, the model is fooled into thinking that these images come from a different class.\n",
"\n",
"TODO -- Modify the attack so that it perturbs the data iteratively, i.e., take a small step of size epsilon, recalculate the gradient, and take another small step according to the new gradient signs."
],
"metadata": {
"id": "vFXWK826HPQ8"
}
}
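,
{
"cell_type": "markdown",
"source": [
"One possible way to approach the iterative version is sketched below, again using a hypothetical randomly initialized linear classifier instead of the trained network. The function name `iterative_fgsm`, the per-step size `alpha`, the number of steps, and the clipping that keeps every pixel within `epsilon` of the original image are all assumptions rather than part of the exercise; remove the clipping line to follow the exercise literally.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical stand-in classifier: random linear layer with log-softmax outputs\n",
"toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))\n",
"\n",
"def iterative_fgsm(model, x, y, alpha, epsilon, n_steps):\n",
"    x_orig = x.clone().detach()\n",
"    x_adv = x_orig.clone()\n",
"    for _ in range(n_steps):\n",
"        x_adv.requires_grad_(True)\n",
"        loss = F.nll_loss(model(x_adv), y)\n",
"        # Recompute the input gradient at the current perturbed point\n",
"        dLdx = torch.autograd.grad(loss, x_adv)[0]\n",
"        # Take a small step of size alpha along the sign of the new gradient\n",
"        x_adv = x_adv.detach() + alpha * torch.sign(dLdx)\n",
"        # Optional assumption: keep every pixel within epsilon of the original image\n",
"        x_adv = torch.max(torch.min(x_adv, x_orig + epsilon), x_orig - epsilon)\n",
"    return x_adv\n",
"\n",
"x = torch.randn(1, 1, 28, 28)\n",
"y = torch.tensor([3])\n",
"x_adv = iterative_fgsm(toy_model, x, y, alpha=0.02, epsilon=0.1, n_steps=10)\n",
"with torch.no_grad():\n",
"    loss_before = F.nll_loss(toy_model(x), y).item()\n",
"    loss_after = F.nll_loss(toy_model(x_adv), y).item()\n",
"print(loss_before, loss_after)\n",
"```\n",
"\n",
"Taking several small steps lets the perturbation follow the gradient as it changes, whereas single-step FGSM commits to the gradient signs at the original input.\n"
],
"metadata": {}
}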
]
}