{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2022 Google LLC\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "CKB49u-HsfEb" }, "source": [ "# Dataset distillation\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/implicit_diff/dataset_distillation.ipynb)\n", "\n", "Dataset distillation ([Maclaurin et al. 2015](https://arxiv.org/pdf/1502.03492.pdf), [Wang et al. 2020](https://arxiv.org/pdf/1811.10959.pdf)) aims to learn a small synthetic\n", "training dataset such that a model trained on this learned dataset achieves a\n", "small loss on the original training set."
] }, { "cell_type": "markdown", "metadata": { "id": "T_1ezvj0ut0L" }, "source": [ "**Bi-level formulation**\n", "\n", "Dataset distillation can be written formally as a bi-level problem. The\n", "inner problem fits a multiclass logistic regression model $x^\\star(\\theta) \\in\n", "\\mathbb{R}^{p \\times k}$ on the distilled images $\\theta \\in\n", "\\mathbb{R}^{k \\times p}$ (one image of $p$ pixels for each of the $k$ classes), while the outer problem minimizes the\n", "loss achieved by $x^\\star(\\theta)$ on the training set:\n", "\n", "$$\\underbrace{\\min_{\\theta \\in \\mathbb{R}^{k \\times p}} f(x^\\star(\\theta), X_{\\text{tr}}; y_{\\text{tr}})}_{\\text{outer problem}} ~\\text{ subject to }~ x^\\star(\\theta) \\in \\underbrace{\\text{argmin}_{x \\in \\mathbb{R}^{p \\times k}} f(x, \\theta; [k]) + \\text{l2reg} \\|x\\|^2\\,}_{\\text{inner problem}}$$\n", "\n", "where $f(W, X; y) := \\ell(y, XW)$, $\\ell$ denotes the multiclass\n", "logistic regression loss, $X_{\\text{tr}}, y_{\\text{tr}}$ are the samples and\n", "labels of the training set, $[k] := (1, \\ldots, k)$ are the labels of the distilled images, and $\\text{l2reg} = 10^{-1}$ is a regularization\n", "parameter that we found improved convergence."
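, "\n", "\n", "As a sketch of how the outer objective can be differentiated, consider implicit differentiation of the inner optimality condition. This is one standard approach; it is assumed here rather than taken from the formulation above. Write $g(x, \\theta) := f(x, \\theta; [k]) + \\text{l2reg} \\|x\\|^2$ for the inner objective and $F(\\theta) := f(x^\\star(\\theta), X_{\\text{tr}}; y_{\\text{tr}})$ for the outer one, treating $x$ and $\\theta$ as flattened vectors. Since $\\ell$ is convex in the logits, the logits are linear in $x$, and $\\text{l2reg} > 0$, the function $g$ is strongly convex in $x$. Hence $x^\\star(\\theta)$ is unique and satisfies $\\nabla_x g(x^\\star(\\theta), \\theta) = 0$. Differentiating this identity with respect to $\\theta$ and applying the chain rule to $F$ gives\n", "\n", "$$\\partial x^\\star(\\theta) = -\\left[\\nabla^2_{xx} g(x^\\star(\\theta), \\theta)\\right]^{-1} \\nabla^2_{x\\theta} g(x^\\star(\\theta), \\theta), \\qquad \\nabla F(\\theta) = \\partial x^\\star(\\theta)^\\top \\nabla_1 f(x^\\star(\\theta), X_{\\text{tr}}; y_{\\text{tr}}),$$\n", "\n", "where $\\nabla^2_{x\\theta} g$ is the Jacobian of $\\nabla_x g$ with respect to $\\theta$ and $\\nabla_1 f$ is the gradient of $f$ with respect to its first argument. Because $\\nabla^2_{xx} g$ is positive definite, the inverse exists. In practice one typically solves the corresponding linear system instead of forming the inverse explicitly."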
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "iQvA16DP8zhC" }, "outputs": [], "source": [ "#@title Imports\n", "%%capture\n", "%pip install jaxopt flax" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "7lXrLlDi9FiC" }, "outputs": [], "source": [ "import itertools\n", "import tensorflow_datasets as tfds\n", "from matplotlib import pyplot as plt\n", "\n", "# Activate TPUs if available (Colab only).\n", "try:\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", "except (KeyError, ImportError):\n", " print(\"TPU not found, continuing without it.\")\n", "\n", "import jax\n", "from jax import numpy as jnp\n", "\n", "from jaxopt import GradientDescent\n", "from jaxopt import objective\n", "\n", "# Place all computations on the CPU by default.\n", "jax.config.update(\"jax_platform_name\", \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 208, "referenced_widgets": [ "c334015f1ea947bb989e477ebb15a686", "acb448d7153e4e46b54521f98afca562", "844222f868204ef9a33eb47f76d87db8", "f9e172ac2805409ea0e164622648d414", "1f67bb4c53124dbaa2e5ae2bfd0c298f", "aa8282fe603043b18ae4c6ee03a0ae65", "bd2824a64f5a4d1b9bee416077653861", "c2a3133a84454a61904be05661aa6982", "d5a08e82176b410d820189a4c653feb3", "59a557f2a61b4c7785f8e7f3cdf55928", "a956598eafb84694a6b21bbc8a375fb3" ] }, "id": "EQCtC92k9iXJ", "outputId": "b9b5ded7-3cac-4193-a2a6-c0cbcabd711e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your\n", "local data directory. 
If you'd instead prefer to read directly from our public\n", "GCS bucket (recommended if you're running on GCP), you can instead pass\n", "`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c334015f1ea947bb989e477ebb15a686", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dl Completed...: 0%| | 0/4 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the learnt images\n", "fig, axarr = plt.subplots(2, 5, figsize=(10 * 5, 2 * 10))\n", "plt.suptitle(\"Distilled images\", fontsize=40)\n", "\n", "for k, (i, j) in enumerate(itertools.product(range(2), range(5))):\n", " img_i = distilled_images[k].reshape((28, 28))\n", " axarr[i, j].imshow(\n", " img_i / jnp.abs(img_i).max(), cmap=plt.cm.gray_r, vmin=-1, vmax=1)\n", " axarr[i, j].set_xticks(())\n", " axarr[i, j].set_yticks(())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0bM6pDvjYbCr" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "name": "plot_dataset_distillation.ipynb", "provenance": [], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } }, "widgets": { "application/vnd.jupyter.widget-state+json": { "1f67bb4c53124dbaa2e5ae2bfd0c298f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", 
"align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "59a557f2a61b4c7785f8e7f3cdf55928": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "844222f868204ef9a33eb47f76d87db8": { "model_module": "@jupyter-widgets/controls", 
"model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c2a3133a84454a61904be05661aa6982", "max": 4, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d5a08e82176b410d820189a4c653feb3", "value": 4 } }, "a956598eafb84694a6b21bbc8a375fb3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "aa8282fe603043b18ae4c6ee03a0ae65": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, 
"object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "acb448d7153e4e46b54521f98afca562": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_aa8282fe603043b18ae4c6ee03a0ae65", "placeholder": "​", "style": "IPY_MODEL_bd2824a64f5a4d1b9bee416077653861", "value": "Dl Completed...: 100%" } }, "bd2824a64f5a4d1b9bee416077653861": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c2a3133a84454a61904be05661aa6982": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, 
"grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c334015f1ea947bb989e477ebb15a686": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_acb448d7153e4e46b54521f98afca562", "IPY_MODEL_844222f868204ef9a33eb47f76d87db8", "IPY_MODEL_f9e172ac2805409ea0e164622648d414" ], "layout": "IPY_MODEL_1f67bb4c53124dbaa2e5ae2bfd0c298f" } }, "d5a08e82176b410d820189a4c653feb3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "f9e172ac2805409ea0e164622648d414": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_59a557f2a61b4c7785f8e7f3cdf55928", "placeholder": "​", "style": "IPY_MODEL_a956598eafb84694a6b21bbc8a375fb3", "value": " 4/4 [00:00<00:00, 3.45 file/s]" } } } } }, "nbformat": 4, "nbformat_minor": 0 }