{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "SDYtl1XHEEjK" }, "outputs": [], "source": [ "# Copyright 2022 Google LLC.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "nCt2V9cCR-CF" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/distributed/custom_loop_pmap_example.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "N3RErasEJpMo" }, "source": [ "`jax.pmap` example using JAXopt.\n", "=========================================\n", "The purpose of this example is to illustrate how JAXopt solvers can be easily\n", "used for distributed training thanks to `jax.pmap`. In this case, we begin by\n", "implementing data parallel training of a multi-class logistic regression model\n", "on synthetic data. General aspects to pay attention to include:\n", "* How to use `jax.lax` reduction operators such as `jax.lax.pmean` or\n", " `jax.lax.psum` in JAXopt solvers by using custom `value_and_grad` functions.\n", "* How `jax.pmap` can be used to transform the solver's `update` method to easily\n", " write custom data-parallel training loops.\n", "\n", "To obtain the best performance on Google Colab we recommend:\n", "1. 
In `Change runtime type` under the menu `Runtime`, select `TPU` for the `Hardware accelerator` option.\n", "2. Connect to the runtime and run all cells.\n", "\n", "NOTE: this example can be easily adapted to support TPU pod slices (e.g. `--accelerator_type v3-32`) as well as hosts with one or more GPUs attached." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "LD0juyi9JmIZ" }, "outputs": [], "source": [ "#@markdown The number of optimization steps to perform:\n", "MAXITER = 100 #@param{type:\"integer\"}\n", "#@markdown The number of samples in the (synthetic) dataset:\n", "NUM_SAMPLES = 50000 #@param{type:\"integer\"}\n", "#@markdown The number of features in the (synthetic) dataset:\n", "NUM_FEATURES = 784 #@param{type:\"integer\"}\n", "#@markdown The number of classes in the (synthetic) dataset:\n", "NUM_CLASSES = 10 #@param{type:\"integer\"}\n", "#@markdown The stepsize for the optimizer (set to 0.0 to use line search):\n", "STEPSIZE = 0.0 #@param{type:\"number\"}\n", "#@markdown The line search approach (either `'zoom'` or `'backtracking'`), ignored if `STEPSIZE > 0.0`:\n", "LINESEARCH = 'zoom' #@param{type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "9cSpF3jQMezP" }, "source": [ "# Imports and TPU setup" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "AkeC41JXNTv0" }, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt flax" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "6UypR4hUNbve" }, "outputs": [], "source": [ "import functools\n", "import time\n", "from typing import Any, Callable, Tuple, Union\n", "\n", "# Activate TPUs if available.\n", "try:\n", "  import jax.tools.colab_tpu\n", "  jax.tools.colab_tpu.setup_tpu()\n", "except (ImportError, KeyError):\n", "  print(\"TPU not found, continuing without it.\")\n", "\n", "from flax import jax_utils\n", "from flax.training import common_utils\n",
"\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "import jaxopt\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "from sklearn import datasets" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iMxI2ZU0JnJO", "outputId": "97d0bdd8-5b35-4766-ff0d-f020de40f79c" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "8ZReXwUIO4Tt" }, "source": [ "# Type aliases" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "5A_PBZ7OO5rO" }, "outputs": [], "source": [ "Array = Union[np.ndarray, jax.Array]" ] }, { "cell_type": "markdown", "metadata": { "id": "gbqilnNSNpw1" }, "source": [ "# Auxiliary functions\n", "A minimal working example of how all-reduce mean/sum ops can be introduced into a JAXopt solver's `update` method by overriding `jax.value_and_grad`. More complex wrappers are, of course, possible."
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "awxxl498NwMI" }, "outputs": [], "source": [ "def pmean(fun: Callable[..., Any], axis_name: str = 'b') -> Callable[..., Any]:\n", "  \"\"\"Applies `jax.lax.pmean` across `axis_name` for all of `fun`'s outputs.\"\"\"\n", "  maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t\n", "  @functools.wraps(fun)\n", "  def wrapper(*args, **kwargs):\n", "    return jax.tree_util.tree_map(maybe_pmean, fun(*args, **kwargs))\n", "  return wrapper" ] }, { "cell_type": "markdown", "metadata": { "id": "clQeXb77ObhK" }, "source": [ "A small utility to shard `Array`s across the local devices:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "rMntHIEFOeyO" }, "outputs": [], "source": [ "def shard_array(array: Array) -> jax.Array:\n", "  \"\"\"Shards `array` along its leading dimension across the local devices.\"\"\"\n", "  return jax.device_put_sharded(\n", "      shards=list(common_utils.shard(array)),\n", "      devices=jax.local_devices())" ] }, { "cell_type": "markdown", "metadata": { "id": "4Pph9LUAOmhi" }, "source": [ "# Custom loop\n", "The following training loop uses data parallelism. Setting the `use_pmap` keyword argument to `False` disables it; we use this later to benchmark the impact of parallelism." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "tv1n_kkJO0Pf" }, "outputs": [], "source": [ "def fit(\n", "    data: Tuple[Array, Array],\n", "    init_params: Array,\n", "    stepsize: float = 0.0,\n", "    linesearch: str = 'zoom',\n", "    use_pmap: bool = False,\n", ") -> Tuple[np.ndarray, np.ndarray, float]:\n", "  \"\"\"Fits a multi-class logistic regression model for demonstration purposes.\n", "\n", "  Args:\n", "    data: A tuple `(X, y)` with the training covariates and categorical labels,\n", "      respectively.\n", "    init_params: The initial value of the model's weights.\n", "    stepsize: The stepsize to use for the solver. If set to `0`, linesearch will\n", "      be used instead.\n", "    linesearch: The linesearch algorithm to use. If `stepsize > 0`, linesearch\n", "      will be disabled.\n", "    use_pmap: Whether to distribute the computation across replicas or use only\n", "      the first device available.\n", "\n", "  Returns:\n", "    The per-step errors and runtimes, as well as the JIT-compile time for the\n", "    solver's `update` function.\n", "  \"\"\"\n", "  # Value and grad of the objective function for the solver.\n", "  value_and_grad_fun = jax.value_and_grad(jaxopt.objective.multiclass_logreg)\n", "  # When `jax.pmap`ing the computation, use JAXopt's option to provide a custom\n", "  # `value_and_grad` function to include the desired reduction operators. Here\n", "  # we average across replicas: `multiclass_logreg` is a mean over samples and\n", "  # every replica holds an equally sized shard, so averaging the per-replica\n", "  # values and gradients recovers the full-batch value and gradient.\n", "  if use_pmap:\n", "    value_and_grad_fun = pmean(value_and_grad_fun)\n", "\n", "  # To override `jax.value_and_grad` in a JAXopt solver, set the flag\n", "  # `value_and_grad` to `True` and pass the custom implementation of the\n", "  # `value_and_grad` function as `fun`.\n", "  solver = jaxopt.LBFGS(fun=value_and_grad_fun,\n", "                        value_and_grad=True,\n", "                        stepsize=stepsize,\n", "                        linesearch=linesearch)\n", "  # Apply the `jax.pmap` transform to the function to be computed in a\n", "  # distributed manner (the solver's `update` method in this case). Otherwise,\n", "  # we JIT compile it.\n", "  if use_pmap:\n", "    update = jax.pmap(solver.update, axis_name='b')\n", "  else:\n", "    update = jax.jit(solver.update)\n", "\n", "  # Initialize solver state.\n", "  state = solver.init_state(init_params, data=data)\n", "  params = init_params\n", "  # If using `pmap` for data-parallel training, model parameters are typically\n", "  # replicated across devices.\n", "  if use_pmap:\n", "    params, state = jax_utils.replicate((params, state))\n", "\n", "  # Finally, since in this demo we are *not* using mini-batches, it pays off to\n", "  # transfer data to the device beforehand. Otherwise, host-to-device transfers\n", "  # occur in each update. This is true regardless of whether we use distributed\n", "  # or single-device computation.\n", "  if use_pmap:  # Shard data and move it to the local devices.\n", "    data = jax.tree_util.tree_map(shard_array, data)\n", "  else:  # Just move data to the default device.\n", "    data = jax.tree_util.tree_map(jax.device_put, data)\n", "\n", "  # Pre-compile `update` (running it once and blocking until the result is\n", "  # ready) so that compilation does not affect the measured step times.\n", "  tic = time.time()\n", "  jax.tree_util.tree_map(lambda t: t.block_until_ready(),\n", "                         update(params, state, data))\n", "  compile_time = time.time() - tic\n", "\n", "  outer_tic = time.time()\n", "\n", "  step_times = np.zeros(MAXITER)\n", "  errors = np.zeros(MAXITER)\n", "  for it in range(MAXITER):\n", "    tic = time.time()\n", "    params, state = update(params, state, data)\n", "    jax.tree_util.tree_map(lambda t: t.block_until_ready(), (params, state))\n", "    step_times[it] = time.time() - tic\n", "    errors[it] = (jax_utils.unreplicate(state.error).item()\n", "                  if use_pmap else state.error.item())\n", "\n", "  print(\n", "      f'Total time elapsed with {linesearch} linesearch and pmap = {use_pmap}:',\n", "      round(time.time() - outer_tic, 2), 'seconds.')\n", "\n", "  return errors, step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "1sEssbjlPOvt" }, "source": [ "# Boilerplate\n", "Creates the dataset, calls `fit` with and without `jax.pmap`, and plots the results."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "jlz80AwbPRej" }, "outputs": [], "source": [ "def run():\n", "  \"\"\"Boilerplate to run the demo experiment.\"\"\"\n", "  data = datasets.make_classification(n_samples=NUM_SAMPLES,\n", "                                      n_features=NUM_FEATURES,\n", "                                      n_classes=NUM_CLASSES,\n", "                                      n_informative=50,\n", "                                      random_state=0)\n", "  init_params = jnp.zeros([NUM_FEATURES, NUM_CLASSES])\n", "\n", "  errors, step_times, compile_time = {}, {}, {}\n", "\n", "  for use_pmap in (True, False):\n", "    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n", "    _errors, _step_times, _compile_time = fit(data=data,\n", "                                              init_params=init_params,\n", "                                              stepsize=STEPSIZE,\n", "                                              linesearch=LINESEARCH,\n", "                                              use_pmap=use_pmap)\n", "    errors[exp_name] = _errors\n", "    step_times[exp_name] = _step_times\n", "    compile_time[exp_name] = _compile_time\n", "\n", "  plt.figure(figsize=(10, 6.18))\n", "  for use_pmap in (True, False):\n", "    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n", "    plt.plot(np.arange(MAXITER), errors[exp_name], label=exp_name)\n", "  plt.xlabel('Iterations', fontsize=16)\n", "  plt.ylabel('Gradient error', fontsize=16)\n", "  plt.yscale('log')\n", "  plt.legend(loc='best', fontsize=16)\n", "  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", "            fontsize=22)\n", "\n", "  plt.figure(figsize=(10, 6.18))\n", "  for use_pmap in (True, False):\n", "    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n", "    plt.plot(np.arange(MAXITER), step_times[exp_name], label=exp_name)\n", "  plt.xlabel('Iterations', fontsize=16)\n", "  plt.ylabel('Step time (seconds)', fontsize=16)\n", "  plt.legend(loc='best', fontsize=16)\n", "  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", "            fontsize=22)\n", "\n", "  return errors, step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "v6Oo8E5hPbyd" }, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 11,
"metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "byYCECkWPcbe", "outputId": "48b17428-0dff-4573-a8c4-d8c59f0f6c74" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_samples: 50000\n", "num_features: 784\n", "num_classes: 10\n", "maxiter: 100\n", "stepsize: 0.0\n", "linesearch (ignored if `stepsize` > 0): zoom\n", "\n", "Total time elapsed with zoom linesearch and pmap = True: 5.88 seconds.\n", "Total time elapsed with zoom linesearch and pmap = False: 20.39 seconds.\n", "Average speed-up (ignoring compile): 7.08\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "print(\"num_samples:\", NUM_SAMPLES)\n", "print(\"num_features:\", NUM_FEATURES)\n", "print(\"num_classes:\", NUM_CLASSES)\n", "print(\"maxiter:\", MAXITER)\n", "print(\"stepsize:\", STEPSIZE)\n", "print(\"linesearch (ignored if `stepsize` > 0):\", LINESEARCH)\n", "print()\n", "\n", "errors, step_times, compile_time = run()\n", "print('Average speed-up (ignoring compile):',\n", "      round((step_times['without_pmap'] / step_times['with_pmap']).mean(), 2))" ] } ], "metadata": { "accelerator": "TPU", "colab": { "provenance": [] }, "gpuClass": "standard", "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }