.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep_learning/plot_sgd_solvers.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_deep_learning_plot_sgd_solvers.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py:

Comparison of different SGD algorithms.
=======================================

The purpose of this example is to illustrate the benefits of adaptive
stepsize algorithms. We chose a classification task whose classes are easily
separable, with no regularization, so that the network can fit the training
data perfectly (the interpolating regime).

We compare:

* SGD with Polyak stepsize (theoretical guarantees require interpolation)
* SGD with Armijo line search (theoretical guarantees require interpolation)
* SGD with constant stepsize
* RMSProp

The reported ``training loss`` is an estimate of the true training loss,
computed on the current minibatch. The experiment was run without momentum,
using popular default values for the learning rates.

.. GENERATED FROM PYTHON SOURCE LINES 38-220

.. image-sg:: /auto_examples/deep_learning/images/sphx_glr_plot_sgd_solvers_001.png
   :alt: Classification on fashion_mnist
   :srcset: /auto_examples/deep_learning/images/sphx_glr_plot_sgd_solvers_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [Polyak SGD] (   0), error=0.272, loss=2.301, stepsize=1.000000
    [Polyak SGD] (  20), error=1.919, loss=1.482, stepsize=0.402169
    [Polyak SGD] (  40), error=4.585, loss=1.666, stepsize=0.079269
    [Polyak SGD] (  60), error=1.422, loss=0.891, stepsize=0.440547
    [Polyak SGD] (  80), error=2.356, loss=1.200, stepsize=0.216168
    [Polyak SGD] ( 100), error=3.512, loss=1.183, stepsize=0.095881
    101 iterations of Polyak SGD took 15.15s

    [Armijo SGD] (   0), error=0.227, loss=2.313, stepsize=1.000000
    [Armijo SGD] (  20), error=2.221, loss=1.331, stepsize=0.065971
    [Armijo SGD] (  40), error=2.062, loss=1.080, stepsize=0.033777
    [Armijo SGD] (  60), error=3.499, loss=1.059, stepsize=0.027022
    [Armijo SGD] (  80), error=1.456, loss=0.863, stepsize=0.021617
    [Armijo SGD] ( 100), error=1.689, loss=0.819, stepsize=0.021617
    101 iterations of Armijo SGD took 15.80s

    [RMSProp] (   0), error=0.223, loss=2.299
    [RMSProp] (  20), error=0.869, loss=1.913
    [RMSProp] (  40), error=1.151, loss=1.453
    [RMSProp] (  60), error=1.115, loss=1.211
    [RMSProp] (  80), error=1.627, loss=1.088
    [RMSProp] ( 100), error=1.581, loss=1.002
    101 iterations of RMSProp took 13.23s

    [SGD] (   0), error=0.190, loss=2.306
    [SGD] (  20), error=0.384, loss=2.236
    [SGD] (  40), error=2.765, loss=1.878
    [SGD] (  60), error=1.222, loss=1.092
    [SGD] (  80), error=1.637, loss=1.146
    [SGD] ( 100), error=4.392, loss=1.671
    101 iterations of SGD took 12.89s

|
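Both adaptive solvers set their stepsize from the current minibatch. As a
rough sketch (up to implementation details such as ``PolyakSGD``'s damping
term :math:`\delta`, and assuming the optimal loss is zero, which holds under
interpolation), the Polyak stepsize on minibatch :math:`B_t` is

.. math::

    \gamma_t = \min\left(\frac{f_{B_t}(w_t)}{\|\nabla f_{B_t}(w_t)\|^2 + \delta},\;
    \gamma_{\max}\right),

while the Armijo variant backtracks from :math:`\gamma_{\max}` until a
sufficient-decrease condition of the form

.. math::

    f_{B_t}\big(w_t - \gamma \nabla f_{B_t}(w_t)\big) \le
    f_{B_t}(w_t) - c\, \gamma\, \|\nabla f_{B_t}(w_t)\|^2

holds, with the coefficient :math:`c` derived from the ``aggressiveness``
parameter. Constant-stepsize SGD and RMSProp simply use the fixed learning
rates given by the flags below.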
.. code-block:: default


    from absl import flags
    import sys
    from timeit import default_timer as timer

    import jax
    import jax.numpy as jnp

    from jaxopt import loss
    from jaxopt import ArmijoSGD
    from jaxopt import PolyakSGD
    from jaxopt import OptaxSolver

    import optax

    from flax import linen as nn

    import numpy as onp

    import matplotlib.pyplot as plt

    from scipy.signal import convolve

    import tensorflow_datasets as tfds
    import tensorflow as tf


    dataset_names = [
        "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
    ]


    flags.DEFINE_float("rmsprop_stepsize", 1e-3, "RMSProp stepsize.")
    flags.DEFINE_float("gd_stepsize", 1e-1, "Gradient descent stepsize.")
    flags.DEFINE_float("aggressiveness", 0.9,
                       "Aggressiveness of the line search in armijo-sgd.")
    flags.DEFINE_float("max_stepsize", 1.0,
                       "Maximum stepsize (used in polyak-sgd, armijo-sgd).")
    flags.DEFINE_integer("maxiter", 100, "Maximum number of iterations.")
    flags.DEFINE_integer("batch_size", 128, "Batch size.")
    flags.DEFINE_integer("net_width", 4, "Width multiplier of the network.")
    flags.DEFINE_enum("dataset", "fashion_mnist", dataset_names,
                      "Dataset to train on.")


    jax.config.update("jax_platform_name", "cpu")

    plt.rcParams.update({'font.size': 13})


    # Load dataset.
    def load_dataset(dataset, batch_size):
      """Loads the dataset as a generator of batches."""
      train_ds, ds_info = tfds.load(f"{dataset}:3.*.*", split="train",
                                    as_supervised=True, with_info=True)
      train_ds = train_ds.repeat()
      train_ds = train_ds.shuffle(10 * batch_size, seed=0)
      train_ds = train_ds.batch(batch_size)
      return tfds.as_numpy(train_ds), ds_info


    # Define NN architecture.
    class CNN(nn.Module):
      """A simple CNN: two conv/pool stages followed by two dense layers."""
      num_classes: int
      net_width: int

      @nn.compact
      def __call__(self, x):
        x = nn.Conv(features=self.net_width, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=self.net_width * 2, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=self.net_width * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)
        return x


    def main(argv):
      # Parse flags manually to avoid conflicts between absl.app.run
      # and sphinx-gallery.
      flags.FLAGS(argv)
      FLAGS = flags.FLAGS

      # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
      # make it unavailable to JAX.
      tf.config.experimental.set_visible_devices([], 'GPU')

      train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size)

      # Initialize parameters.
      input_shape = (1,) + ds_info.features["image"].shape
      rng = jax.random.PRNGKey(0)
      num_classes = ds_info.features["label"].num_classes
      net = CNN(num_classes, FLAGS.net_width)

      logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

      def loss_fun(params, data):
        """Computes the mean logistic loss of the network on a minibatch."""
        inputs, labels = data
        x = inputs.astype(jnp.float32) / 255.
        logits = net.apply({"params": params}, x)
        return jnp.mean(logistic_loss(labels, logits))

      def run_all(opt, name):
        """Trains the network on the dataset with ``opt``."""
        # Use the same starting hyper-parameters for a fair evaluation.
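        # All solvers share the same PRNG key, so each run below starts
        # from identical initial weights.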
        params = net.init(rng, jnp.zeros(input_shape))["params"]

        # Use the same batch order for a fair evaluation.
        ds_iterator = iter(train_ds)

        state = opt.init_state(params, data=next(ds_iterator))

        @jax.jit
        def jitted_update(params, state, batch):
          return opt.update(params, state, batch)

        value_and_grad_loss = jax.value_and_grad(loss_fun)

        errors, steps, losses = [], [], []
        start = timer()
        for iter_num, batch in zip(range(opt.maxiter + 1), ds_iterator):
          # ``loss_value`` (rather than ``loss``) avoids shadowing the
          # ``jaxopt.loss`` module imported above.
          loss_value, _ = value_and_grad_loss(params, data=batch)
          params, state = jitted_update(params, state, batch)
          if "Armijo" in name or "Polyak" in name:
            steps.append(state.stepsize)
          errors.append(state.error)
          losses.append(loss_value)
          if iter_num % 20 == 0:
            if "Armijo" in name or "Polyak" in name:
              print(f"[{name}] ({iter_num:4d}), error={errors[-1]:.3f}, "
                    f"loss={losses[-1]:.3f}, stepsize={steps[-1]:.6f}")
            else:
              print(f"[{name}] ({iter_num:4d}), error={errors[-1]:.3f}, "
                    f"loss={losses[-1]:.3f}")
        duration = timer() - start
        print(f"{opt.maxiter + 1} iterations of {name} took {duration:.2f}s\n")

        # Smooth the loss curve with a width-5 moving average.
        losses = convolve(losses, onp.ones(5) / 5, mode='valid',
                          method='direct')
        return errors[:len(losses)], steps[:len(losses)], losses

      polyak = PolyakSGD(fun=loss_fun,
                         max_stepsize=FLAGS.max_stepsize,
                         maxiter=FLAGS.maxiter)

      armijo = ArmijoSGD(fun=loss_fun,
                         aggressiveness=FLAGS.aggressiveness,
                         reset_option='goldstein',
                         max_stepsize=FLAGS.max_stepsize,
                         maxls=20,
                         maxiter=FLAGS.maxiter)

      rmsprop = OptaxSolver(fun=loss_fun, maxiter=FLAGS.maxiter,
                            opt=optax.rmsprop(learning_rate=FLAGS.rmsprop_stepsize))

      gd = OptaxSolver(fun=loss_fun, maxiter=FLAGS.maxiter,
                       opt=optax.sgd(learning_rate=FLAGS.gd_stepsize,
                                     nesterov=False))

      errors_data, losses_data, stepsize_data = [], [], []
      for opt, name in zip([polyak, armijo, rmsprop, gd],
                           ["Polyak SGD", "Armijo SGD", "RMSProp", "SGD"]):
        errors, steps, losses = run_all(opt, name)
        errors_data.append(errors)
        losses_data.append(losses)
        stepsize_data.append(steps)

      fig = plt.figure(figsize=(18, 12))
      fig.suptitle(f'Classification on {FLAGS.dataset}')
      spec = fig.add_gridspec(ncols=2, nrows=2)
      spec_slices = [spec[0, :], spec[1:, :]]
      datas = stepsize_data, losses_data
      names = ['Stepsize', 'Training loss']

      for spec_slice, single_data, name in zip(spec_slices, datas, names):
        polyak_data, armijo_data, rmsprop_data, gd_data = single_data
        ax = fig.add_subplot(spec_slice)
        ax.set_xlabel("Iter num")
        ax.set_ylabel(f"{name} (logscale)")
        iter_nums = jnp.arange(1, len(polyak_data) + 1)
        ax.plot(iter_nums, polyak_data, ':', c='red', linewidth=1.,
                label=f'Polyak max_stepsize={FLAGS.max_stepsize}')
        ax.plot(iter_nums, armijo_data, '--', c='green', linewidth=1.,
                label=f'Armijo aggressiveness={FLAGS.aggressiveness}')
        if name != 'Stepsize':
          # SGD and RMSProp use a constant stepsize, so they only appear
          # in the training-loss panel.
          ax.plot(iter_nums, gd_data, '-', c='blue', linewidth=1.,
                  label=f'SGD stepsize={FLAGS.gd_stepsize:.3f}')
          ax.plot(iter_nums, rmsprop_data, '.-', c='orange', linewidth=1.,
                  label=f'RMSProp stepsize={FLAGS.rmsprop_stepsize:.3f}')
        ax.set_yscale('log')
        ax.legend(loc='upper right', framealpha=1.)

      spec.tight_layout(fig)
      plt.show()


    if __name__ == "__main__":
      main(sys.argv)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 1.032 seconds)


.. _sphx_glr_download_auto_examples_deep_learning_plot_sgd_solvers.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sgd_solvers.py <plot_sgd_solvers.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sgd_solvers.ipynb <plot_sgd_solvers.ipynb>`
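Since ``main`` parses ``argv`` itself, the example can also be driven
programmatically with non-default flag values. A minimal sketch (the flag
settings are illustrative, not tuned recommendations; absl treats the first
``argv`` entry as the program name):

.. code-block:: python

    # Hypothetical invocation: train on MNIST for 200 iterations with a
    # smaller maximum stepsize for the adaptive solvers.
    main(["plot_sgd_solvers.py",
          "--dataset=mnist",
          "--maxiter=200",
          "--max_stepsize=0.5"])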
.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_