Image classification example with Haiku and JAXopt.
import functools
from absl import app
from absl import flags
import haiku as hk
import jax
import jax.numpy as jnp
from jaxopt import loss
from jaxopt import ArmijoSGD
from jaxopt import OptaxSolver
from jaxopt import PolyakSGD
from jaxopt import tree_util
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
dataset_names = [
"mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
]
flags.DEFINE_float("l2reg", 1e-4, "L2 regularization.")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate (used in adam).")
flags.DEFINE_bool("manual_loop", False, "Whether to use a manual training loop.")
flags.DEFINE_integer("epochs", 5, "Number of passes over the dataset.")
flags.DEFINE_float("max_stepsize", 0.1, "Maximum step size (used in polyak-sgd, armijo-sgd).")
flags.DEFINE_float("aggressiveness", 0.5, "Aggressiveness of line search in armijo-sgd.")
flags.DEFINE_float("momentum", 0.9, "Momentum strength (used in adam, polyak-sgd, armijo-sgd).")
flags.DEFINE_enum("dataset", "mnist", dataset_names, "Dataset to train on.")
flags.DEFINE_enum("model", "cnn", ["cnn", "mlp"], "Model architecture.")
flags.DEFINE_enum("solver", "adam", ["adam", "sgd", "polyak-sgd", "armijo-sgd"], "Solver to use.")
flags.DEFINE_integer("train_batch_size", 256, "Batch size at train time.")
flags.DEFINE_integer("test_batch_size", 1024, "Batch size at test time.")
FLAGS = flags.FLAGS
def load_dataset(split, *, is_training, batch_size):
"""Loads the dataset as a generator of batches."""
version = 3
ds, ds_info = tfds.load(
f"{FLAGS.dataset}:{version}.*.*",
as_supervised=True,  # return (image, label) tuples instead of feature dicts
split=split,
with_info=True)
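# Cache decoded examples and repeat indefinitely so that `next()` on the
# resulting iterator never raises StopIteration during training.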
ds = ds.cache().repeat()
if is_training:
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds)), ds_info
def net_fun(batch, num_classes):
"""Create model."""
x = batch[0].astype(jnp.float32) / 255.
if FLAGS.model == "cnn":
model = hk.Sequential([
hk.Conv2D(output_channels=32, kernel_shape=(3, 3), padding="SAME"),
jax.nn.relu,
hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
hk.Conv2D(output_channels=64, kernel_shape=(3, 3), padding="SAME"),
jax.nn.relu,
hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
hk.Flatten(),
hk.Linear(256),
jax.nn.relu,
hk.Linear(num_classes),
])
else:
model = hk.Sequential([
hk.Flatten(),
hk.Linear(300),
jax.nn.relu,
hk.Linear(100),
jax.nn.relu,
hk.Linear(num_classes),
])
y = model(x)
return y
def main(argv):
del argv
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
train_ds, ds_info = load_dataset("train", is_training=True, batch_size=FLAGS.train_batch_size)
test_ds, _ = load_dataset("test", is_training=False, batch_size=FLAGS.test_batch_size)
num_classes = ds_info.features["label"].num_classes
maxiter = FLAGS.epochs * ds_info.splits['train'].num_examples // FLAGS.train_batch_size
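# The solvers below stop after `maxiter` updates, i.e. roughly FLAGS.epochs
# passes over the training split.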
# Initialize parameters.
net = functools.partial(net_fun, num_classes=num_classes)
net = hk.without_apply_rng(hk.transform(net))
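# `multiclass_logistic_loss` operates on a single (label, logits) pair, so we
# vmap it over the batch dimension.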
logistic_loss = jax.vmap(loss.multiclass_logistic_loss)
def loss_fun(params, l2reg, data):
"""Compute the loss of the network."""
logits = net.apply(params, data)
_, labels = data
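# Squared L2 norm over all leaves of the parameter pytree.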
sqnorm = tree_util.tree_l2_norm(params, squared=True)
loss_value = jnp.mean(logistic_loss(labels, logits))
return loss_value + 0.5 * l2reg * sqnorm
@jax.jit
def accuracy(params, data):
_, labels = data
predictions = net.apply(params, data)
return jnp.mean(jnp.argmax(predictions, axis=-1) == labels)
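# This hook is passed to the solvers below via `pre_update`; it must return
# (params, state) so the solver can keep updating them.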
def print_accuracy(params, state, *args, **kwargs):
if state.iter_num % 10 == 0:
# Periodically evaluate classification accuracy on test set.
test_accuracy = accuracy(params, next(test_ds))
test_accuracy = jax.device_get(test_accuracy)
print(f"[Step {state.iter_num}] Test accuracy: {test_accuracy:.3f}.")
return params, state
# Initialize solver.
if FLAGS.solver == "adam":
# Equivalent to:
# opt = optax.chain(optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
# optax.scale(-FLAGS.learning_rate))
opt = optax.adam(FLAGS.learning_rate)
solver = OptaxSolver(opt=opt, fun=loss_fun, maxiter=maxiter,
pre_update=print_accuracy)
elif FLAGS.solver == "sgd":
opt = optax.sgd(FLAGS.learning_rate, FLAGS.momentum)
solver = OptaxSolver(opt=opt, fun=loss_fun, maxiter=maxiter,
pre_update=print_accuracy)
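# PolyakSGD derives its step size from the current loss value and gradient
# norm (capped at max_stepsize), while ArmijoSGD runs an Armijo backtracking
# line search at every step.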
elif FLAGS.solver == "polyak-sgd":
solver = PolyakSGD(fun=loss_fun, maxiter=maxiter,
momentum=FLAGS.momentum,
max_stepsize=FLAGS.max_stepsize,
pre_update=print_accuracy)
elif FLAGS.solver == "armijo-sgd":
solver = ArmijoSGD(fun=loss_fun, maxiter=maxiter,
aggressiveness=FLAGS.aggressiveness,
momentum=FLAGS.momentum,
max_stepsize=FLAGS.max_stepsize,
pre_update=print_accuracy)
else:
raise ValueError("Unknown solver: %s" % FLAGS.solver)
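# Haiku infers parameter shapes from an example input, so initialization
# consumes one training batch.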
params = net.init(jax.random.PRNGKey(42), next(train_ds))
# Run training loop.
# In JAXopt, stochastic solvers can be run either using a manual for loop or
# using `run_iterator`. We include both here for demonstration purposes.
if FLAGS.manual_loop:
state = solver.init_state(params)
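# Each call to `update` consumes one fresh batch from the training iterator.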
for _ in range(maxiter):
params, state = solver.update(params=params, state=state,
l2reg=FLAGS.l2reg,
data=next(train_ds))
else:
params, state = solver.run_iterator(
init_params=params, iterator=train_ds, l2reg=FLAGS.l2reg)
print_accuracy(params=params, state=state)
if __name__ == "__main__":
app.run(main)
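For reference, the example can be launched from the command line with any of the flags defined above; the script filename below is only illustrative.

python haiku_image_classif.py --dataset=fashion_mnist --solver=polyak-sgd --epochs=10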