{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2022 Google LLC\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Adversarial training\n", "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/deep_learning/adversarial_training.ipynb)\n", "\n", "\n", "The following code trains a convolutional neural network (CNN) to be robust\n", "with respect to the projected gradient descent (PGD) method.\n", "\n", "The Projected Gradient Descent Method (PGD) is a simple yet effective method to\n", "generate adversarial images. At each iteration, it adds a small perturbation\n", "in the direction of the sign of the gradient with respect to the input followed\n", "by a projection onto the infinity ball. The gradient sign ensures this\n", "perturbation locally maximizes the objective, while the projection ensures this\n", "perturbation stays on the boundary of the infinity ball.\n", "\n", "## References\n", "\n", " Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining\n", " and harnessing adversarial examples.\", https://arxiv.org/abs/1412.6572\n", "\n", " Madry, Aleksander, et al. \"Towards deep learning models resistant to\n", " adversarial attacks.\", https://arxiv.org/pdf/1706.06083.pdf" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt flax" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import datetime\n", "import collections\n", "\n", "# activate TPUs if available\n", "try:\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", "except KeyError:\n", " print(\"TPU not found, continuing without it.\")\n", "from flax import linen as nn\n", "import jax\n", "from jax import numpy as jnp\n", "from jaxopt import loss\n", "from jaxopt import OptaxSolver\n", "from jaxopt import tree_util\n", "\n", "from matplotlib import pyplot as plt\n", "plt.rcParams.update({'font.size': 22})\n", "\n", "import optax\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show on which platform JAX is running. The code below should take around 3 min to run on GPU but might take longer on CPUs." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX running on GPU\n" ] } ], "source": [ "print(\"JAX running on\", jax.devices()[0].platform.upper())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "Flags = collections.namedtuple(\n", " \"Flags\",\n", " [\n", " \"l2reg\", # amount of L2 regularization in the objective\n", " \"learning_rate\", # learning rate for the Adam optimizer\n", " \"epochs\", # number of passes over the dataset\n", " \"dataset\", # one of \"mnist\", \"kmnist\", \"emnist\", \"fashion_mnist\", \"cifar10\", \"cifar100\"\n", " \"epsilon\", # Adversarial perturbations lie within the infinity-ball of radius epsilon.\n", " \"train_batch_size\", # Batch size at train time\n", " \"test_batch_size\" # Batch size at test time\n", " ])\n", "\n", "FLAGS = Flags(\n", " l2reg=0.0001,\n", " learning_rate=0.001,\n", " epochs=10,\n", " dataset=\"mnist\",\n", " epsilon=0.01,\n", " train_batch_size=128,\n", " test_batch_size=128)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def load_dataset(split, *, is_training, batch_size):\n", " \"\"\"Load dataset using tensorflow_datasets.\"\"\"\n", " version = 3\n", " ds, ds_info = tfds.load(\n", " f\"{FLAGS.dataset}:{version}.*.*\",\n", " as_supervised=True, # remove useless keys\n", " split=split,\n", " with_info=True)\n", " ds = ds.cache().repeat()\n", " if is_training:\n", " ds = ds.shuffle(10 * batch_size, seed=0)\n", " ds = ds.batch(batch_size)\n", " return iter(tfds.as_numpy(ds)), ds_info\n", "\n", "\n", "class CNN(nn.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", " num_classes: int\n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = x.reshape((x.shape[0], -1)) # flatten\n", " x = nn.Dense(features=256)(x)\n", " x = nn.relu(x)\n", " x = nn.Dense(features=self.num_classes)(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make\n", "# it unavailable to JAX.\n", "tf.config.experimental.set_visible_devices([], \"GPU\")\n", "train_ds, ds_info = load_dataset(\"train\", is_training=True,\n", " batch_size=FLAGS.train_batch_size)\n", "test_ds, _ = load_dataset(\"test\", is_training=False,\n", " batch_size=FLAGS.test_batch_size)\n", "input_shape = (1,) + ds_info.features[\"image\"].shape\n", "num_classes = ds_info.features[\"label\"].num_classes\n", "iter_per_epoch_train = ds_info.splits['train'].num_examples // FLAGS.train_batch_size\n", "iter_per_epoch_test = ds_info.splits['test'].num_examples // FLAGS.test_batch_size\n", "\n", "\n", "net = CNN(num_classes)\n", "\n", "@jax.jit\n", "def accuracy(params, data):\n", " inputs, labels = data\n", " logits = net.apply({\"params\": params}, inputs)\n", " return jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n", "\n", "logistic_loss = jax.vmap(loss.multiclass_logistic_loss)\n", "\n", "@jax.jit\n", "def loss_fun(params, l2reg, data):\n", " \"\"\"Compute the loss of the network.\"\"\"\n", " inputs, labels = data\n", " x = inputs.astype(jnp.float32)\n", " logits = net.apply({\"params\": params}, x)\n", " sqnorm = tree_util.tree_l2_norm(params, squared=True)\n", " loss_value = jnp.mean(logistic_loss(labels, logits))\n", " return loss_value + 0.5 * l2reg * sqnorm\n", "\n", "@jax.jit\n", "def pgd_attack(image, label, params, epsilon=0.1, maxiter=10):\n", " \"\"\"PGD attack on the L-infinity ball with radius epsilon.\n", "\n", " Args:\n", " image: array-like, input data for the CNN\n", " label: integer, class label corresponding to image\n", " params: tree, parameters of the model to attack\n", " epsilon: float, radius of the L-infinity ball.\n", " maxiter: int, number of iterations of this algorithm.\n", "\n", " Returns:\n", " perturbed_image: Adversarial image on the boundary of the L-infinity ball\n", " of radius epsilon and centered at image.\n", "\n", " Notes:\n", " PGD attack is described in (Madry et al. 2017),\n", " https://arxiv.org/pdf/1706.06083.pdf\n", " \"\"\"\n", " image_perturbation = jnp.zeros_like(image)\n", " def adversarial_loss(perturbation):\n", " return loss_fun(params, 0, (image + perturbation, label))\n", "\n", " grad_adversarial = jax.grad(adversarial_loss)\n", " for _ in range(maxiter):\n", " # compute gradient of the loss wrt to the image\n", " sign_grad = jnp.sign(grad_adversarial(image_perturbation))\n", "\n", " # heuristic step-size 2 eps / maxiter\n", " image_perturbation += (2 * epsilon / maxiter) * sign_grad\n", " # projection step onto the L-infinity ball centered at image\n", " image_perturbation = jnp.clip(image_perturbation, - epsilon, epsilon)\n", "\n", " # clip the image to ensure pixels are between 0 and 1\n", " return jnp.clip(image + image_perturbation, 0, 1)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/jaxopt/lib/python3.10/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 out of 10\n", "Accuracy on train set: 0.982\n", "Accuracy on test set: 0.982\n", "Adversarial accuracy on train set: 0.979\n", "Adversarial accuracy on test set: 0.979\n", "Time elapsed: 0:00:14\n", "\n", "Epoch 1 out of 10\n", "Accuracy on train set: 0.987\n", "Accuracy on test set: 0.986\n", "Adversarial accuracy on train set: 0.984\n", "Adversarial accuracy on test set: 0.982\n", "Time elapsed: 0:00:21\n", "\n", "Epoch 2 out of 10\n", "Accuracy on train set: 0.989\n", "Accuracy on test set: 0.987\n", "Adversarial accuracy on train set: 0.986\n", "Adversarial accuracy on test set: 0.984\n", "Time elapsed: 0:00:29\n", "\n", "Epoch 3 out of 10\n", "Accuracy on train set: 0.992\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.990\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:36\n", "\n", "Epoch 4 out of 10\n", "Accuracy on train set: 0.994\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.992\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:43\n", "\n", "Epoch 5 out of 10\n", "Accuracy on train set: 0.994\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.992\n", "Adversarial accuracy on test set: 0.987\n", "Time elapsed: 0:00:51\n", "\n", "Epoch 6 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.991\n", "Adversarial accuracy on train set: 0.994\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:58\n", "\n", "Epoch 7 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:01:05\n", "\n", "Epoch 8 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.989\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.987\n", "Time elapsed: 0:01:13\n", "\n", "Epoch 9 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:01:20\n", "\n" ] } ], "source": [ "# Initialize solver and parameters.\n", "solver = OptaxSolver(\n", " opt=optax.adam(FLAGS.learning_rate),\n", " fun=loss_fun,\n", " maxiter=FLAGS.epochs * iter_per_epoch_train)\n", "key = jax.random.PRNGKey(0)\n", "params = net.init(key, jnp.zeros(input_shape))[\"params\"]\n", "\n", "state = solver.init_state(params)\n", "start = datetime.datetime.now().replace(microsecond=0)\n", "jitted_update = jax.jit(solver.update)\n", "\n", "accuracy_train = []\n", "accuracy_test = []\n", "adversarial_accuracy_train = []\n", "adversarial_accuracy_test = []\n", "for it in range(solver.maxiter):\n", " # training loop\n", " images, labels = next(train_ds)\n", " # convert images to float as attack requires to take gradients wrt to them\n", " images = images.astype(jnp.float32) / 255\n", "\n", " adversarial_images_train = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " # train on adversarial images\n", " params, state = jitted_update(\n", " params=params,\n", " state=state,\n", " l2reg=FLAGS.l2reg,\n", " data=(adversarial_images_train, labels))\n", "\n", " # Once per epoch evaluate the model on the train and test sets.\n", " if state.iter_num % iter_per_epoch_train == iter_per_epoch_train - 1:\n", "\n", " # compute train set accuracy, both on clean and adversarial images\n", " adversarial_accuracy_train_sample = 0.\n", " accuracy_train_sample = 0.\n", " for _ in range(iter_per_epoch_train):\n", " images, labels = next(train_ds)\n", " images = images.astype(jnp.float32) / 255\n", " accuracy_train_sample += jnp.mean(accuracy(params, (images, labels))) / iter_per_epoch_train\n", " adversarial_images_train = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " adversarial_accuracy_train_sample += jnp.mean(\n", " accuracy(params, (adversarial_images_train, labels))) / iter_per_epoch_train\n", " accuracy_train.append(accuracy_train_sample)\n", " adversarial_accuracy_train.append(adversarial_accuracy_train_sample)\n", "\n", " # compute train set accuracy, both on clean and adversarial images\n", " adversarial_accuracy_test_sample = 0.\n", " accuracy_test_sample = 0.\n", " for _ in range(iter_per_epoch_test):\n", " images, labels = next(test_ds)\n", " images = images.astype(jnp.float32) / 255\n", " accuracy_test_sample += jnp.mean(accuracy(params, (images, labels))) / iter_per_epoch_test\n", " adversarial_images_test = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " adversarial_accuracy_test_sample += jnp.mean(\n", " accuracy(params, (adversarial_images_test, labels))) / iter_per_epoch_test\n", " accuracy_test.append(accuracy_test_sample)\n", " adversarial_accuracy_test.append(adversarial_accuracy_test_sample)\n", "\n", "\n", " time_elapsed = (datetime.datetime.now().replace(microsecond=0) - start)\n", " print(f\"Epoch {it // iter_per_epoch_train} out of {FLAGS.epochs}\")\n", " print(f\"Accuracy on train set: {accuracy_train[-1]:.3f}\")\n", " print(f\"Accuracy on test set: {accuracy_test[-1]:.3f}\")\n", " print(\n", " f\"Adversarial accuracy on train set: {adversarial_accuracy_train[-1]:.3f}\"\n", " )\n", " print(\n", " f\"Adversarial accuracy on test set: {adversarial_accuracy_test[-1]:.3f}\"\n", " )\n", " print(f\"Time elapsed: {time_elapsed}\")\n", " print()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "\n", "plt.suptitle(\"Adversarial training on \" + f\"{FLAGS.dataset}\".upper())\n", "axes[0].plot(accuracy_train, lw=3, label=\"train set.\" , marker='<', markersize=10)\n", "axes[0].plot(accuracy_test, lw=3, label=\"test set.\", marker='d', markersize=10)\n", "axes[0].grid()\n", "axes[0].set_ylabel('accuracy on clean images')\n", "\n", "axes[1].plot(\n", " adversarial_accuracy_train,\n", " lw=3,\n", " label=\"adversarial accuracy on train set.\", marker='^', markersize=10)\n", "axes[1].plot(\n", " adversarial_accuracy_test,\n", " lw=3,\n", " label=\"adversarial accuracy on test set.\", marker='>', markersize=10)\n", "axes[1].grid()\n", "axes[0].legend(frameon=False, ncol=2, loc='upper center', bbox_to_anchor=(0.8, -0.1))\n", "axes[0].set_xlabel('epochs')\n", "axes[1].set_ylabel('accuracy on adversarial images')\n", "plt.subplots_adjust( wspace=0.5 )\n", "\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Find a test set image that is correctly classified but not its adversarial perturbation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def find_adversarial_imgs():\n", " for _ in range(iter_per_epoch_test):\n", " images, labels = next(test_ds)\n", " images = images.astype(jnp.float32) / 255\n", " logits = net.apply({\"params\": params}, images)\n", " labels_clean = jnp.argmax(logits, axis=-1)\n", "\n", " adversarial_images = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " labels_adversarial = jnp.argmax(net.apply({\"params\": params}, adversarial_images), axis=-1)\n", " idx_misclassified = jnp.where(labels_clean != labels_adversarial)[0]\n", " if len(idx_misclassified) == 0:\n", " continue\n", " else:\n", " for i in idx_misclassified:\n", " img_clean = images[i]\n", " prediction_clean = labels_clean[i]\n", " if prediction_clean != labels[i]:\n", " # the clean image predicts the wrong label, skip\n", " continue\n", " img_adversarial = adversarial_images[i]\n", " prediction_adversarial = labels_adversarial[i]\n", " # we found our image\n", " return img_clean, prediction_clean, img_adversarial, prediction_adversarial\n", "\n", " raise ValueError(\"No mismatch between clean and adversarial prediction found\")\n", "\n", "img_clean, prediction_clean, img_adversarial, prediction_adversarial = \\\n", " find_adversarial_imgs()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, axes = plt.subplots(nrows=1, ncols=3, figsize=(6*3, 6))\n", "\n", "axes[0].set_title('Clean image \\n Prediction %s' % int(prediction_clean))\n", "axes[0].imshow(img_clean, cmap=plt.cm.get_cmap('Greys'), vmax=1, vmin=0)\n", "axes[1].set_title('Adversarial image \\n Prediction %s' % prediction_adversarial)\n", "axes[1].imshow(img_adversarial, cmap=plt.cm.get_cmap('Greys'), vmax=1, vmin=0)\n", "axes[2].set_title(r'|Adversarial - clean| $\\times$ %.0f' % (1/FLAGS.epsilon))\n", "axes[2].imshow(jnp.abs(img_clean - img_adversarial) / FLAGS.epsilon, cmap=plt.cm.get_cmap('Greys'), vmax=1, vmin=0)\n", "for i in range(3):\n", " axes[i].set_xticks(())\n", " axes[i].set_yticks(())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 4 }