.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/deep_learning/haiku_image_classif.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_deep_learning_haiku_image_classif.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_deep_learning_haiku_image_classif.py:


Image classification example with Haiku and JAXopt.
===================================================

.. GENERATED FROM PYTHON SOURCE LINES 20-202

.. code-block:: default


    import functools

    from absl import app
    from absl import flags

    import haiku as hk

    import jax
    import jax.numpy as jnp

    from jaxopt import loss
    from jaxopt import ArmijoSGD
    from jaxopt import OptaxSolver
    from jaxopt import PolyakSGD
    from jaxopt import tree_util

    import optax

    import tensorflow_datasets as tfds
    import tensorflow as tf


    dataset_names = [
        "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
    ]


    flags.DEFINE_float("l2reg", 1e-4, "L2 regularization.")
    flags.DEFINE_float("learning_rate", 0.001,
                       "Learning rate (used in adam, sgd).")
    flags.DEFINE_bool("manual_loop", False,
                      "Whether to use a manual training loop.")
    flags.DEFINE_integer("epochs", 5, "Number of passes over the dataset.")
    flags.DEFINE_float("max_stepsize", 0.1,
                       "Maximum step size (used in polyak-sgd, armijo-sgd).")
    flags.DEFINE_float("aggressiveness", 0.5,
                       "Aggressiveness of line search in armijo-sgd.")
    flags.DEFINE_float("momentum", 0.9,
                       "Momentum strength (used in sgd, polyak-sgd, armijo-sgd).")
    flags.DEFINE_enum("dataset", "mnist", dataset_names, "Dataset to train on.")
    flags.DEFINE_enum("model", "cnn", ["cnn", "mlp"], "Model architecture.")
    flags.DEFINE_enum("solver", "adam", ["adam", "sgd", "polyak-sgd", "armijo-sgd"],
                      "Solver to use.")
    flags.DEFINE_integer("train_batch_size", 256, "Batch size at train time.")
    flags.DEFINE_integer("test_batch_size", 1024, "Batch size at test time.")

    FLAGS = flags.FLAGS


    def load_dataset(split, *, is_training, batch_size):
      """Loads the dataset as a generator of batches."""
      version = 3
      ds, ds_info = tfds.load(
          f"{FLAGS.dataset}:{version}.*.*",
          as_supervised=True,  # remove useless keys
          split=split,
          with_info=True)
      ds = ds.cache().repeat()
      if is_training:
        ds = ds.shuffle(10 * batch_size, seed=0)
      ds = ds.batch(batch_size)
      return iter(tfds.as_numpy(ds)), ds_info


    def net_fun(batch, num_classes):
      """Create model."""
      x = batch[0].astype(jnp.float32) / 255.
      if FLAGS.model == "cnn":
        model = hk.Sequential([
            hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.relu,
            hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
            hk.Conv2D(output_channels=64, kernel_shape=(3, 3), padding="SAME"),
            jax.nn.relu,
            hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
            hk.Flatten(),
            hk.Linear(256),
            jax.nn.relu,
            hk.Linear(num_classes),
        ])
      else:
        model = hk.Sequential([
            hk.Flatten(),
            hk.Linear(300),
            jax.nn.relu,
            hk.Linear(100),
            jax.nn.relu,
            hk.Linear(num_classes),
        ])
      y = model(x)
      return y


    def main(argv):
      del argv

      # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
      # make it unavailable to JAX.
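      # A softer alternative (not used here) is to export the TensorFlow
      # environment variable TF_FORCE_GPU_ALLOW_GROWTH=true before launch,
      # which makes TF allocate GPU memory on demand instead of reserving it
      # all up front.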
      tf.config.experimental.set_visible_devices([], 'GPU')

      train_ds, ds_info = load_dataset("train", is_training=True,
                                       batch_size=FLAGS.train_batch_size)
      test_ds, _ = load_dataset("test", is_training=False,
                                batch_size=FLAGS.test_batch_size)

      num_classes = ds_info.features["label"].num_classes
      maxiter = (FLAGS.epochs * ds_info.splits['train'].num_examples //
                 FLAGS.train_batch_size)

      # Initialize parameters.
      net = functools.partial(net_fun, num_classes=num_classes)
      net = hk.without_apply_rng(hk.transform(net))

      logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

      def loss_fun(params, l2reg, data):
        """Compute the loss of the network."""
        logits = net.apply(params, data)
        _, labels = data
        sqnorm = tree_util.tree_l2_norm(params, squared=True)
        loss_value = jnp.mean(logistic_loss(labels, logits))
        return loss_value + 0.5 * l2reg * sqnorm

      @jax.jit
      def accuracy(params, data):
        _, labels = data
        predictions = net.apply(params, data)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == labels)

      def print_accuracy(params, state, *args, **kwargs):
        if state.iter_num % 10 == 0:
          # Periodically evaluate classification accuracy on the test set.
          test_accuracy = accuracy(params, next(test_ds))
          test_accuracy = jax.device_get(test_accuracy)
          print(f"[Step {state.iter_num}] Test accuracy: {test_accuracy:.3f}.")
        return params, state

      # Initialize solver.
      if FLAGS.solver == "adam":
        # Equivalent to:
        # opt = optax.chain(optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        #                   optax.scale(-FLAGS.learning_rate))
        opt = optax.adam(FLAGS.learning_rate)
        solver = OptaxSolver(opt=opt, fun=loss_fun, maxiter=maxiter,
                             pre_update=print_accuracy)

      elif FLAGS.solver == "sgd":
        opt = optax.sgd(FLAGS.learning_rate, FLAGS.momentum)
        solver = OptaxSolver(opt=opt, fun=loss_fun, maxiter=maxiter,
                             pre_update=print_accuracy)

      elif FLAGS.solver == "polyak-sgd":
        solver = PolyakSGD(fun=loss_fun, maxiter=maxiter,
                           momentum=FLAGS.momentum,
                           max_stepsize=FLAGS.max_stepsize,
                           pre_update=print_accuracy)

      elif FLAGS.solver == "armijo-sgd":
        solver = ArmijoSGD(fun=loss_fun, maxiter=maxiter,
                           aggressiveness=FLAGS.aggressiveness,
                           momentum=FLAGS.momentum,
                           max_stepsize=FLAGS.max_stepsize,
                           pre_update=print_accuracy)

      else:
        raise ValueError("Unknown solver: %s" % FLAGS.solver)

      params = net.init(jax.random.PRNGKey(42), next(train_ds))

      # Run training loop.
      # In JAXopt, stochastic solvers can be run either using a manual for loop
      # or using `run_iterator`. We include both here for demonstration purposes.
      if FLAGS.manual_loop:
        state = solver.init_state(params)

        for _ in range(maxiter):
          params, state = solver.update(params=params, state=state,
                                        l2reg=FLAGS.l2reg,
                                        data=next(train_ds))
      else:
        params, state = solver.run_iterator(
            init_params=params, iterator=train_ds, l2reg=FLAGS.l2reg)

      print_accuracy(params=params, state=state)


    if __name__ == "__main__":
      app.run(main)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_auto_examples_deep_learning_haiku_image_classif.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: haiku_image_classif.py <haiku_image_classif.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: haiku_image_classif.ipynb <haiku_image_classif.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
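
The manual ``init_state``/``update`` cycle used above generalizes beyond this
script. Below is a minimal, self-contained sketch of the same pattern on a toy
ridge-regression problem; the synthetic data, the ``ridge_loss`` function, and
the hyperparameter values are illustrative choices, not part of the example
above.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import optax
    from jaxopt import OptaxSolver


    def ridge_loss(params, l2reg, data):
      """Mean squared error plus an L2 penalty, mirroring `loss_fun` above."""
      X, y = data
      residuals = jnp.dot(X, params) - y
      return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)


    # Synthetic regression data: 32 examples, 3 features, known weights.
    key = jax.random.PRNGKey(0)
    X = jax.random.normal(key, (32, 3))
    y = jnp.dot(X, jnp.array([1.0, -2.0, 0.5]))

    solver = OptaxSolver(opt=optax.sgd(1e-2), fun=ridge_loss, maxiter=100)
    params = jnp.zeros(3)
    state = solver.init_state(params, l2reg=1e-3, data=(X, y))

    # Same manual loop as in `main` above, with a fixed batch in place of the
    # `next(train_ds)` stream.
    for _ in range(100):
      params, state = solver.update(params=params, state=state,
                                    l2reg=1e-3, data=(X, y))

    print(params)  # should approach [1.0, -2.0, 0.5]

With a real data iterator, the loop body would draw a fresh batch per step, or
the whole loop could be replaced by ``solver.run_iterator`` as in the example.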