{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2022 Google LLC\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", "\n", " https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Perturbed optimizers" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/perturbed_optimizers/perturbed_optimizers.ipynb)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we review a universal method to transform any optimizer $y^*$ into a differentiable approximation $y_\\varepsilon^*$ using perturbations, following the method of [Berthet et al. (2020)](https://arxiv.org/abs/2002.08676). 
JAXopt provides an implementation, which we illustrate here on some examples.\n", "\n", "Concretely, for an optimizer function $y^*$ defined by\n", "\n", "$$y^*(\\theta) = \\mathop{\\mathrm{arg\\,max}}_{y\\in \\mathcal{C}} \\langle y, \\theta \\rangle\\, ,$$\n", "\n", "we consider, for a random $Z$ drawn from a distribution with continuous positive density $\\mu$,\n", "\n", "$$y_\\varepsilon^*(\\theta) = \\mathbf{E}[\\mathop{\\mathrm{arg\\,max}}_{y\\in \\mathcal{C}} \\langle y, \\theta + \\varepsilon Z \\rangle]$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "9WIWwRdSU51j" }, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# activate TPUs if available\n", "try:\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", "except KeyError:\n", " print(\"TPU not found, continuing without it.\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "S6tLyyy9VCEw" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import jaxopt\n", "import time\n", "\n", "from jaxopt import perturbations" ] }, { "cell_type": "markdown", "metadata": { "id": "EmYn_jNUFfw2" }, "source": [ "# Argmax one-hot" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We consider an optimizer, such as the following `argmax_one_hot` function. It transforms a real-valued vector into a binary vector with a 1 at the position of the largest coefficient and 0 elsewhere. It corresponds to $y^*$ for $\\mathcal{C}$ being the unit simplex. We run it on an example input `values`." 
] }, { "cell_type": "markdown", "metadata": { "id": "84N-wAJ8GDK2" }, "source": [ "## One-hot function" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "kMZnzhX4FjGj" }, "outputs": [], "source": [ "def argmax_one_hot(x, axis=-1):\n", " return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iynCk8734Wiz", "outputId": "31e902b6-2dfa-4340-b11e-955b422bd313" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 1. 0. 0. 0.]\n" ] } ], "source": [ "values = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])\n", "\n", "one_hot_vec = argmax_one_hot(values)\n", "print(one_hot_vec)" ] }, { "cell_type": "markdown", "metadata": { "id": "6rbNt-6zGb-J" }, "source": [ "## One-hot with perturbations" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Our implementation transforms the `argmax_one_hot` function into a perturbed one that we call `pert_one_hot`. In this case we use Gumbel noise for the perturbation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "7hQz6zuPwkpZ" }, "outputs": [], "source": [ "N_SAMPLES = 100_000\n", "SIGMA = 0.5\n", "GUMBEL = perturbations.Gumbel()\n", "\n", "rng = jax.random.PRNGKey(1)\n", "pert_one_hot = perturbations.make_perturbed_argmax(argmax_fun=argmax_one_hot,\n", " num_samples=N_SAMPLES,\n", " sigma=SIGMA,\n", " noise=GUMBEL)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this particular case, with Gumbel noise, the perturbed argmax is equal to the usual [softmax function](https://en.wikipedia.org/wiki/Softmax_function) applied to `values / SIGMA`. 
This is not always the case: in general, there is no closed form for $y_\\varepsilon^*$." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f2gDpghJYZ33", "outputId": "adbe4b3c-4ec9-4f18-ea75-07619fd84cb3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "computation with 100000 samples, sigma = 0.5\n", "perturbed argmax = [0.0055 0.81842 0.01212 0.16157 0.00239]\n", "softmax = [0.00549293 0.8152234 0.01222475 0.16459078 0.00246813]\n", "norm of softmax = 8.32e-01\n", "norm of difference = 4.40e-03\n" ] } ], "source": [ "rngs = jax.random.split(rng, 2)\n", "\n", "rng = rngs[0]\n", "\n", "pert_argmax = pert_one_hot(values, rng)\n", "print(f'computation with {N_SAMPLES} samples, sigma = {SIGMA}')\n", "print(f'perturbed argmax = {pert_argmax}')\n", "soft_max = jax.nn.softmax(values/SIGMA)\n", "print(f'softmax = {soft_max}')\n", "print(f'norm of softmax = {jnp.linalg.norm(soft_max):.2e}')\n", "print(f'norm of difference = {jnp.linalg.norm(pert_argmax - soft_max):.2e}')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "2U7rhtEAGpMV" }, "source": [ "## Gradients for one-hot with perturbations" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The perturbed optimizer $y_\\varepsilon^*$ is differentiable, and its gradient can be computed automatically by stochastic estimation, using `jax.grad`.\n", "\n", "We create a scalar loss `loss_simplex` of the perturbed optimizer $y^*_\\varepsilon$\n", "\n", "$$\\ell_\\text{simplex}(y_{\\text{pred}} = y_\\varepsilon^*; y_{\\text{true}})$$\n", "\n", "For `values` equal to a vector $\\theta$, we can compute gradients of\n", "\n", "$$\\ell(\\theta) = \\ell_\\text{simplex}(y_\\varepsilon^*(\\theta); y_{\\text{true}})$$\n", "with respect to `values`, automatically." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7H1LD4QhGtFI", "outputId": "fda1ce54-9e6c-4736-f3ff-05258129ee27" }, "outputs": [ { "data": { "text/plain": [ "Array(0.5865911, dtype=float32)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Example loss function\n", "\n", "def loss_simplex(values, rng):\n", " n = values.shape[0]\n", " v_true = jnp.arange(n) + 2\n", " y_true = v_true / jnp.sum(v_true)\n", " y_pred = pert_one_hot(values, rng)\n", " return jnp.sum((y_true - y_pred) ** 2)\n", "\n", "loss_simplex(values, rngs[1])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can compute the gradient of $\\ell$ directly\n", "\n", "$$\\nabla_\\theta \\ell(\\theta) = \\partial_\\theta y^*_\\varepsilon(\\theta) \\cdot \\nabla_1 \\ell_{\\text{simplex}}(y^*_\\varepsilon(\\theta); y_{\\text{true}})$$\n", "\n", "The computation of the Jacobian $\\partial_\\theta y^*_\\varepsilon(\\theta)$ is implemented automatically, using an estimation method given by [Berthet et al. (2020)](https://arxiv.org/abs/2002.08676), [Prop. 3.1]." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tjQatCE3GtFJ", "outputId": "b2454c6b-e3a9-4ed7-ca31-c49a24ce4594" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.02052322 0.46736273 -0.02747887 -0.39873555 -0.00571656]\n" ] } ], "source": [ "# Gradient of the loss w.r.t. input values\n", "\n", "gradient = jax.grad(loss_simplex)(values, rngs[1])\n", "print(gradient)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We illustrate the use of this method by running 200 steps of gradient descent on $\\theta_t$ so that it minimizes this loss." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "MuNE2RX0GtFJ" }, "outputs": [], "source": [ "# Run 200 steps of gradient descent on the values so that the perturbed one-hot matches the target\n", "\n", "steps = 200\n", "values_t = values\n", "eta = 0.5\n", "\n", "grad_func = jax.jit(jax.grad(loss_simplex))\n", "\n", "for t in range(steps):\n", " rngs = jax.random.split(rngs[1], 2)\n", " values_t = values_t - eta * grad_func(values_t, rngs[1])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "29TWHiH0GtFJ", "outputId": "e53faab6-ebae-4658-d0ce-48f2b06b6a8a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial values = [-0.6 1.9 -0.2 1.1 -1. ]\n", "initial one-hot = [0. 1. 0. 0. 0.]\n", "initial diff. one-hot = [0.0056 0.81554997 0.01239 0.16398999 0.00247 ]\n", "\n", "values after GD = [-0.07073233 0.13270897 0.2768847 0.38671777 0.4782017 ]\n", "one-hot after GD = [0. 0. 0. 0. 1.]\n", "diff. one-hot after GD = [0.09843 0.15089999 0.19826 0.25197 0.30043998]\n", "target diff. one-hot = [0.1 0.15 0.2 0.25 0.3 ]\n" ] } ], "source": [ "rngs = jax.random.split(rngs[1], 2)\n", "\n", "n = values.shape[0]\n", "v_true = jnp.arange(n) + 2\n", "y_true = v_true / jnp.sum(v_true)\n", "\n", "print(f'initial values = {values}')\n", "print(f'initial one-hot = {argmax_one_hot(values)}')\n", "print(f'initial diff. one-hot = {pert_one_hot(values, rngs[0])}')\n", "print()\n", "print(f'values after GD = {values_t}')\n", "print(f'one-hot after GD = {argmax_one_hot(values_t)}')\n", "print(f'diff. one-hot after GD = {pert_one_hot(values_t, rngs[1])}')\n", "print(f'target diff. 
one-hot = {y_true}')" ] }, { "cell_type": "markdown", "metadata": { "id": "4Vyh_a1bZT-s" }, "source": [ "# Differentiable ranking" ] }, { "cell_type": "markdown", "metadata": { "id": "QmVAjbJxFzUA" }, "source": [ "## Ranking function" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We consider an optimizer, such as the following `ranking` function. It transforms a real-valued vector of size $n$ into a vector with coefficients being a permutation of $\\{0,\\ldots, n-1\\}$ corresponding to the order of the coefficients of the original vector. It corresponds to $y^*$ for $\\mathcal{C}$ being the permutahedron. We run it on an example input `values`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "-NKbR6TlZUTG" }, "outputs": [], "source": [ "# Function outputting a vector of ranks\n", "\n", "def ranking(values):\n", " return jnp.argsort(jnp.argsort(values))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iU69uMAoZncY", "outputId": "70d4a0ec-c335-48c1-d04d-b5fa32e951c4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]\n", "ranking = [2 0 4 5 3 1]\n" ] } ], "source": [ "# Example on random values\n", "\n", "n = 6\n", "\n", "rng = jax.random.PRNGKey(0)\n", "values = jax.random.normal(rng, (n,))\n", "\n", "print(f'values = {values}')\n", "print(f'ranking = {ranking(values)}')" ] }, { "cell_type": "markdown", "metadata": { "id": "5j1Vgfz_bb9u" }, "source": [ "## Ranking with perturbations" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As above, our implementation transforms this function into a perturbed one that we call `pert_ranking`. In this case we use Gumbel noise for the perturbation." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "Equ3_gDPbf5n" }, "outputs": [], "source": [ "N_SAMPLES = 100\n", "SIGMA = 0.2\n", "GUMBEL = perturbations.Gumbel()\n", "\n", "pert_ranking = perturbations.make_perturbed_argmax(ranking,\n", " num_samples=N_SAMPLES,\n", " sigma=SIGMA,\n", " noise=GUMBEL)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vMj-Dnudby_a", "outputId": "827458ad-de4b-4b85-be72-0373c2a4b7f8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]\n", "diff_ranks = [2.37 0.02 3.85 4.96 2.4099998 1.39 ]\n" ] } ], "source": [ "# Expectation of the perturbed ranks on these values\n", "\n", "rngs = jax.random.split(rng, 2)\n", "\n", "diff_ranks = pert_ranking(values, rngs[0])\n", "print(f'values = {values}')\n", "\n", "print(f'diff_ranks = {diff_ranks}')" ] }, { "cell_type": "markdown", "metadata": { "id": "aH6Ew85koQvU" }, "source": [ "## Gradients for ranking with perturbations" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As above, the perturbed optimizer $y_\\varepsilon^*$ is differentiable, and its gradient can be computed automatically by stochastic estimation, using `jax.grad`.\n", "\n", "We showcase this on a loss of $y_\\varepsilon^*(\\theta)$ that can be directly differentiated w.r.t. the `values` equal to $\\theta$." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O-T8y6N8cHzF", "outputId": "ad780191-cac8-40a9-ecfc-1b3442732aa5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "28.9336\n" ] } ], "source": [ "# Example loss function\n", "\n", "def loss_example(values, rng):\n", " n = values.shape[0]\n", " y_true = ranking(jnp.arange(n))\n", " y_pred = pert_ranking(values, rng)\n", " return jnp.sum((y_true - y_pred) ** 2)\n", "\n", "print(loss_example(values, rngs[1]))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v7nzNwP-e68q", "outputId": "1fc9ff56-c223-49f1-9acc-00a959e44c26" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 14.139462 -2.6558158 -19.498537 16.295418 -2.338868 -21.901724 ]\n" ] } ], "source": [ "# Gradient of the objective w.r.t input values\n", "\n", "gradient = jax.grad(loss_example)(values, rngs[1])\n", "print(gradient)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As above, we showcase this example on gradient descent to minimize this loss." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "0UObBP3QfCqq" }, "outputs": [], "source": [ "steps = 20\n", "values_t = values\n", "eta = 0.1\n", "\n", "grad_func = jax.jit(jax.grad(loss_example))\n", "\n", "for t in range(steps):\n", " rngs = jax.random.split(rngs[1], 2)\n", " values_t = values_t - eta * grad_func(values_t, rngs[1])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p4iNxMoQmZRa", "outputId": "16f75357-192d-4b5d-8467-20e5ab639e9b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial values = [ 0.18784384 -1.2833426 0.6494181 1.2490593 0.24447003 -0.11744965]\n", "initial ranks = [2 0 4 5 3 1]\n", "initial diff. ranks = [2.44 0. 
3.79 4.98 2.51 1.28]\n", "\n", "values after GD = [-2.9923885 -1.9453204 -1.259742 -0.69805354 0.33311206 1.7650208 ]\n", "ranks after GD = [0 1 2 3 4 5]\n", "diff. ranks after GD = [0. 1.01 2.05 2.95 3.99 5. ]\n", "target diff. ranks = [0 1 2 3 4 5]\n" ] } ], "source": [ "rngs = jax.random.split(rngs[1], 2)\n", "\n", "y_true = ranking(jnp.arange(n))\n", "\n", "print(f'initial values = {values}')\n", "print(f'initial ranks = {ranking(values)}')\n", "print(f'initial diff. ranks = {pert_ranking(values, rngs[0])}')\n", "print()\n", "print(f'values after GD = {values_t}')\n", "print(f'ranks after GD = {ranking(values_t)}')\n", "print(f'diff. ranks after GD = {pert_ranking(values_t, rngs[1])}')\n", "print(f'target diff. ranks = {y_true}')" ] } ], "metadata": { "colab": { "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }