{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1671547917546, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "DwcrPoWd4r4I" }, "outputs": [], "source": [ "# Copyright 2022 Google LLC.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "gmlhojcwL5oT" }, "source": [ "`jax.experimental.pjit` example using JAXopt.\n", "=========================================\n", "The purpose of this example is to illustrate how JAXopt solvers can be easily\n", "used for distributed training thanks to `jax.experimental.pjit`. In this case, we begin by\n", "implementing data parallel training of a multi-class logistic regression model\n", "on synthetic data.\n", "\n", "**NOTE: `jax.experimental.pjit` is not yet supported on Google Colab. 
Please connect to Google Cloud TPUs to execute the example.**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "executionInfo": { "elapsed": 54, "status": "ok", "timestamp": 1671547917733, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "3Ck_x_LyL7PE" }, "outputs": [], "source": [ "#@markdown The number of optimization steps to perform:\n", "MAXITER = 100 #@param{type:\"integer\"}\n", "#@markdown The number of samples in the (synthetic) dataset:\n", "NUM_SAMPLES = 50000 #@param{type:\"integer\"}\n", "#@markdown The number of features in the (synthetic) dataset:\n", "NUM_FEATURES = 784 #@param{type:\"integer\"}\n", "#@markdown The number of classes in the (synthetic) dataset:\n", "NUM_CLASSES = 10 #@param{type:\"integer\"}\n", "#@markdown The stepsize for the optimizer (set to 0.0 to use line search):\n", "STEPSIZE = 0.0 #@param{type:\"number\"}\n", "#@markdown The line search approach (either `'zoom'` or `'backtracking'`), ignored if `STEPSIZE > 0.0`:\n", "LINESEARCH = 'zoom' #@param{type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "bZmsaMOcONjQ" }, "source": [ "# Imports and TPU setup" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 1, "status": "ok", "timestamp": 1671547917854, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "J9uYAd6TBEEG" }, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "executionInfo": { "elapsed": 40809, "status": "ok", "timestamp": 1671547958814, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "JJuP-Wz_MBeJ" }, "outputs": [], "source": [ "import time\n", "from typing import Any, Callable, Tuple, Union\n", "\n", "# activate TPUs if available\n", "try:\n", "  import jax.tools.colab_tpu\n",
"  jax.tools.colab_tpu.setup_tpu()\n", "except KeyError:\n", "  print(\"TPU not found, continuing without it.\")\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "import jaxopt\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "from sklearn import datasets" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 2591, "status": "ok", "timestamp": 1671547961531, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "Wi9vI0SAMEMX", "outputId": "197ca951-400d-48c7-bf75-24ea7bc3a0c1" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "RF4MUVeTNjah" }, "source": [ "# Type aliases" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1671547961666, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "_NmZfvmkNlI-" }, "outputs": [], "source": [ "Array = Union[np.ndarray, jax.Array]" ] }, { "cell_type": "markdown", "metadata": { "id": "JkhZJHaxNmmU" }, "source": [ "# Auxiliary functions\n", "A minimal working example of how to create a `Mesh` for data parallel execution using `pjit`. 
Note that, unlike `pmap`, `pjit` makes it possible to seamlessly combine data and model parallel execution as well." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "executionInfo": { "elapsed": 68, "status": "ok", "timestamp": 1671547961887, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "wd_3OqM0NoXW" }, "outputs": [], "source": [ "from jax.sharding import Mesh\n", "from jax.sharding import PartitionSpec\n", "from jax.experimental.pjit import pjit\n", "\n", "\n", "def setup_data_parallel_mesh():\n", "  global_mesh = Mesh(np.asarray(jax.devices(), dtype=object), ['data'])\n", "  jax.experimental.maps.thread_resources.env = (\n", "      jax.experimental.maps.ResourceEnv(physical_mesh=global_mesh, loops=()))\n", "\n", "setup_data_parallel_mesh()" ] }, { "cell_type": "markdown", "metadata": { "id": "nKnR5yfpNpvw" }, "source": [ "# Custom-loop\n", "The following code uses data parallelism in the training loop. The `use_pjit` keyword argument lets us deactivate this parallelism. We'll use this feature later to benchmark the impact of parallelism." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "executionInfo": { "elapsed": 74, "status": "ok", "timestamp": 1671547962097, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "EXiLgFO_R8_c" }, "outputs": [], "source": [ "def fit(\n", "    data: Tuple[Array, Array],\n", "    init_params: Array,\n", "    stepsize: float = 0.0,\n", "    linesearch: str = 'zoom',\n", "    use_pjit: bool = False,\n", ") -> Tuple[np.ndarray, np.ndarray, float]:\n", "  \"\"\"Fits a multi-class logistic regression model for demonstration purposes.\n", "\n", "  Args:\n", "    data: A tuple `(X, y)` with the training covariates and categorical labels,\n", "      respectively.\n", "    init_params: The initial value of the model's weights.\n",
"    stepsize: The stepsize to use for the solver. If set to `0`, linesearch will\n", "      be used instead.\n", "    linesearch: The linesearch algorithm to use. If `stepsize > 0`, linesearch\n", "      will be disabled.\n", "    use_pjit: Whether to distribute the computation across replicas or use only\n", "      the first device available.\n", "\n", "  Returns:\n", "    The per-step errors and runtimes, as well as the JIT-compile time for the\n", "    solver's `update` function.\n", "  \"\"\"\n", "  # When using `pjit` to distribute the computation across devices, it is not\n", "  # necessary to override the `value_and_grad` of `fun` (though it is supported\n", "  # if desired for other reasons, e.g. gradient clipping).\n", "  solver = jaxopt.LBFGS(fun=jaxopt.objective.multiclass_logreg,\n", "                        stepsize=stepsize,\n", "                        linesearch=linesearch)\n", "  # Apply the `pjit` transform to the function to be computed in a\n", "  # distributed manner (the solver's `update` method in this case). Otherwise,\n", "  # we JIT compile it.\n", "  if use_pjit:\n", "    update = pjit(\n", "        solver.update,\n", "        in_axis_resources=(None, None, PartitionSpec('data')),\n", "        out_axis_resources=None)\n", "  else:\n", "    update = jax.jit(solver.update)\n", "\n", "  # Initialize solver state.\n", "  state = solver.init_state(init_params, data=data)\n", "  params = init_params\n", "  # When using `pjit` for data-parallel training, we do not need to explicitly\n", "  # replicate model parameters across devices. Instead, replication is specified\n", "  # via the `in_axis_resources` argument of the `pjit` transformation.\n", "\n", "  # Finally, since in this demo we are *not* using mini-batches, it pays off to\n", "  # transfer data to the device beforehand. Otherwise, host-to-device transfers\n",
"  # occur in each update. This is true regardless of whether we use distributed\n", "  # or single-device computation.\n", "  if use_pjit:  # Shards data and moves it to device.\n", "    data = pjit(\n", "        lambda X, y: (X, y),\n", "        in_axis_resources=PartitionSpec('data'),\n", "        out_axis_resources=PartitionSpec('data'))(*data)\n", "  else:  # Just move data to device.\n", "    data = jax.tree_map(jax.device_put, data)\n", "\n", "  # Pre-compiles update, preventing it from affecting step times.\n", "  tic = time.time()\n", "  _ = update(params, state, data)\n", "  compile_time = time.time() - tic\n", "\n", "  outer_tic = time.time()\n", "\n", "  step_times = np.zeros(MAXITER)\n", "  errors = np.zeros(MAXITER)\n", "  for it in range(MAXITER):\n", "    tic = time.time()\n", "    params, state = update(params, state, data)\n", "    jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n", "    step_times[it] = time.time() - tic\n", "    errors[it] = state.error.item()\n", "\n", "  print(\n", "      f'Total time elapsed with {linesearch} linesearch and pjit = {use_pjit}:',\n", "      round(time.time() - outer_tic, 2), 'seconds.')\n", "\n", "  return errors, step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "HZjIR0q9NsLg" }, "source": [ "# Boilerplate\n", "Creates the dataset, calls `fit` with and without `pjit`, and makes figures." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "executionInfo": { "elapsed": 74, "status": "ok", "timestamp": 1671547962305, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "4DTLFIPr4zfG" }, "outputs": [], "source": [ "def run():\n", "  \"\"\"Boilerplate to run the demo experiment.\"\"\"\n", "  data = datasets.make_classification(n_samples=NUM_SAMPLES,\n", "                                      n_features=NUM_FEATURES,\n", "                                      n_classes=NUM_CLASSES,\n", "                                      n_informative=50,\n", "                                      random_state=0)\n", "  init_params = jnp.zeros([NUM_FEATURES, NUM_CLASSES])\n", "\n", "  errors, step_times, compile_time = {}, {}, {}\n", "\n", "  for use_pjit in (True, False):\n", "    exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", "    _errors, _step_times, _compile_time = fit(data=data,\n", "                                              init_params=init_params,\n", "                                              stepsize=STEPSIZE,\n", "                                              linesearch=LINESEARCH,\n", "                                              use_pjit=use_pjit)\n", "    errors[exp_name] = _errors\n", "    step_times[exp_name] = _step_times\n", "    compile_time[exp_name] = _compile_time\n", "\n", "  plt.figure(figsize=(10, 6.18))\n", "  for use_pjit in (True, False):\n", "    exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", "    plt.plot(jnp.arange(MAXITER), errors[exp_name], label=exp_name)\n", "  plt.xlabel('Iterations', fontsize=16)\n", "  plt.ylabel('Gradient error', fontsize=16)\n", "  plt.yscale('log')\n", "  plt.legend(loc='best', fontsize=16)\n", "  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", "            fontsize=22)\n", "\n", "  plt.figure(figsize=(10, 6.18))\n", "  for use_pjit in (True, False):\n", "    exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", "    plt.plot(jnp.arange(MAXITER), step_times[exp_name], label=exp_name)\n", "  plt.xlabel('Iterations', fontsize=16)\n", "  plt.ylabel('Step time', fontsize=16)\n", "  plt.legend(loc='best', fontsize=16)\n", "  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", "            fontsize=22)\n", "\n", "  return errors, 
step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "fx1BGKFBNukq" }, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "height": 1000 }, "executionInfo": { "elapsed": 30301, "status": "ok", "timestamp": 1671547992729, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "haEIB6LV0ZA5", "outputId": "8eb9bc9f-c08f-48f0-dcf2-76eb1995654f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_samples: 50000\n", "num_features: 784\n", "num_classes: 10\n", "maxiter: 100\n", "stepsize: 0.0\n", "linesearch (ignored if `stepsize` > 0): zoom\n", "\n", "Total time elapsed with zoom linesearch and pjit = True: 3.79 seconds.\n", "Total time elapsed with zoom linesearch and pjit = False: 16.8 seconds.\n", "Average speed-up (ignoring compile): 9.57\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"num_samples:\", NUM_SAMPLES)\n", "print(\"num_features:\", NUM_FEATURES)\n", "print(\"num_classes:\", NUM_CLASSES)\n", "print(\"maxiter:\", MAXITER)\n", "print(\"stepsize:\", STEPSIZE)\n", "print(\"linesearch (ignored if `stepsize` > 0):\", LINESEARCH)\n", "print()\n", "\n", "errors, step_times, compile_time = run()\n", "print('Average speed-up (ignoring compile):',\n", "      round((step_times['without_pjit'] / step_times['with_pjit']).mean(), 2))" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "provenance": [ { "file_id": "18Ea9Vl8fRiCpya0DXitfklcf4kaKj_kN", "timestamp": 1671533866034 }, { "file_id": "1RaTFt691iv9n0PVyb3lcp9tMWCUje7bK", "timestamp": 1664566635253 } ] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }