Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.

This implementation is strongly inspired by the PyTorch code snippets in [3].

A similar model called “implicit deep learning” is also proposed in [2].

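Concretely, a DEQ layer outputs the fixed point z* of a parameterized block f,
i.e. the solution of

  z* = f(z*, x; theta).

By the implicit function theorem, the backward pass only requires solving a
linear system involving the Jacobian of f at z* (selected below by the
``backward_solver`` flag), so the forward solver iterations never need to be
differentiated through.
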
In practice, BatchNorm layers and the initialization of the convolution weights are important to ensure that the fixed-point iterations converge.

[1] Bai, S., Kolter, J.Z. and Koltun, V., 2019. Deep Equilibrium Models. Advances in Neural Information Processing Systems, 32, pp.690-701.

[2] El Ghaoui, L., Gu, F., Travacca, B., Askari, A. and Tsai, A., 2021. Implicit deep learning. SIAM Journal on Mathematics of Data Science, 3(3), pp.930-958.

[3] http://implicit-layers-tutorial.org/deep_equilibrium_models/

from functools import partial
from typing import Any, Mapping, Tuple, Callable

from absl import app
from absl import flags

import flax
from flax import linen as nn

import jax
import jax.numpy as jnp

import jaxopt
from jaxopt import loss
from jaxopt import OptaxSolver
from jaxopt import FixedPointIteration
from jaxopt import AndersonAcceleration
from jaxopt.linear_solve import solve_gmres, solve_normal_cg
from jaxopt.tree_util import tree_add, tree_sub, tree_l2_norm

import optax

import tensorflow_datasets as tfds
import tensorflow as tf


dataset_names = [
    "mnist", "kmnist", "emnist", "fashion_mnist", "cifar10", "cifar100"
]

# training hyper-parameters
flags.DEFINE_float("l2reg", 0., "L2 regularization.")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
flags.DEFINE_integer("maxiter", 100, "Maximum number of iterations.")
flags.DEFINE_enum("dataset", "mnist", dataset_names, "Dataset to train on.")
flags.DEFINE_integer("net_width", 1, "Multiplicator of neural network width.")
flags.DEFINE_integer("evaluation_frequency", 1,
                     "Number of iterations between two evaluation measures.")
flags.DEFINE_integer("train_batch_size", 256, "Batch size at train time.")
flags.DEFINE_integer("test_batch_size", 1024, "Batch size at test time.")


solvers = ["normal_cg", "gmres", "anderson"]
flags.DEFINE_enum("backward_solver", "normal_cg", solvers,
                  "Solver of linear sytem in implicit differentiation.")

# anderson acceleration parameters
flags.DEFINE_enum("forward_solver", "anderson", ["anderson", "fixed_point"],
                  "Whether to use Anderson acceleration.")
flags.DEFINE_integer("forward_maxiter", 20, "Number of fixed point iterations.")
flags.DEFINE_float("forward_tol", 1e-2, "Tolerance in fixed point iterations.")
flags.DEFINE_integer("anderson_history_size", 5,
                     "Size of history in Anderson updates.")
flags.DEFINE_float("anderson_ridge", 1e-4,
                   "Ridge regularization in Anderson updates.")

FLAGS = flags.FLAGS


def load_dataset(split, *, is_training, batch_size):
  """Loads the dataset as a generator of batches."""
  ds, ds_info = tfds.load(f"{FLAGS.dataset}:3.*.*", split=split,
                          as_supervised=True, with_info=True)
  ds = ds.cache().repeat()
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)
  ds = ds.batch(batch_size)
  return iter(tfds.as_numpy(ds)), ds_info


class ResNetBlock(nn.Module):
  """ResNet block."""

  channels: int
  channels_bottleneck: int
  num_groups: int = 8
  kernel_size: Tuple[int, int] = (3, 3)
  use_bias: bool = False
  act: Callable = nn.relu

  @nn.compact
  def __call__(self, z, x):
    # Note that stddev=0.01 is important to avoid divergence.
    # Empirically it ensures that fixed point iterations converge.
    y = z
    y = nn.Conv(features=self.channels_bottleneck, kernel_size=self.kernel_size,
                padding="SAME", use_bias=self.use_bias,
                kernel_init=jax.nn.initializers.normal(stddev=0.01))(y)
    y = self.act(y)
    y = nn.GroupNorm(num_groups=self.num_groups)(y)
    y = nn.Conv(features=self.channels, kernel_size=self.kernel_size,
                padding="SAME", use_bias=self.use_bias,
                kernel_init=jax.nn.initializers.normal(stddev=0.01))(y)
    y = y + x
    y = nn.GroupNorm(num_groups=self.num_groups)(y)
    y = y + z
    y = self.act(y)
    y = nn.GroupNorm(num_groups=self.num_groups)(y)
    return y

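# Illustrative sketch (not called by the training script): the block maps a
# pair (z, x) of identically-shaped feature maps to an output of the same
# shape as z, which is what makes the fixed point z* = block(z*, x) well
# defined.  The single-example 32x32x24 shape below is an arbitrary choice.
def _demo_resnet_block_shapes():
  block = ResNetBlock(channels=24, channels_bottleneck=32)
  z = jnp.zeros((32, 32, 24))
  x = jnp.zeros((32, 32, 24))
  params = block.init(jax.random.PRNGKey(0), z, x)
  y = block.apply(params, z, x)
  return y.shape  # same as z.shape: (32, 32, 24)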

class DEQFixedPoint(nn.Module):
  """Batched computation of ``block`` using ``fixed_point_solver``."""

  block: Any  # nn.Module
  fixed_point_solver: Any  # AndersonAcceleration or FixedPointIteration

  @nn.compact
  def __call__(self, x):
    # Initialize block parameters from the shape of a single example (x[0]).
    init = lambda rng, x: self.block.init(rng, x[0], x[0])
    block_params = self.param("block_params", init, x)

    def block_apply(z, x, block_params):
      return self.block.apply(block_params, z, x)

    solver = self.fixed_point_solver(fixed_point_fun=block_apply)
    def batch_run(x, block_params):
      return solver.run(x, x, block_params)[0]

    # We use vmap since we want to compute the fixed point separately for each
    # example in the batch.
    return jax.vmap(batch_run, in_axes=(0,None), out_axes=0)(x, block_params)

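# Illustrative sketch (not called by the training script): what ``solver.run``
# computes on a toy scalar map.  Both solvers converge to the fixed point
# z* = cos(z*) ~ 0.739; Anderson acceleration typically needs fewer
# iterations.  ``run`` returns a (params, state) pair, which is why
# ``batch_run`` above keeps only element ``[0]``.
def _demo_fixed_point_solvers():
  f = lambda z: jnp.cos(z)  # a contraction with a unique fixed point
  z0 = jnp.zeros(1)
  z_fp = FixedPointIteration(f, maxiter=100, tol=1e-6).run(z0).params
  z_aa = AndersonAcceleration(f, history_size=5, maxiter=100,
                              tol=1e-6).run(z0).params
  return z_fp, z_aa  # both approximately 0.7390851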

class FullDEQ(nn.Module):
  """DEQ model."""

  num_classes: int
  channels: int
  channels_bottleneck: int
  fixed_point_solver: Callable

  @nn.compact
  def __call__(self, x, train):
    # Note that x is a batch of examples:
    # because of BatchNorm we cannot define the forward pass in the network for
    # a single example.
    x = nn.Conv(features=self.channels, kernel_size=(3,3), use_bias=True,
                padding="SAME")(x)
    x = nn.BatchNorm(use_running_average=not train, momentum=0.9,
                     epsilon=1e-5)(x)
    block = ResNetBlock(self.channels, self.channels_bottleneck)
    deq_fixed_point = DEQFixedPoint(block, self.fixed_point_solver)
    x = deq_fixed_point(x)
    x = nn.BatchNorm(use_running_average=not train, momentum=0.9,
                     epsilon=1e-5)(x)
    x = nn.avg_pool(x, window_shape=(8,8), padding="SAME")
    x = x.reshape(x.shape[:-3] + (-1,))  # flatten
    x = nn.Dense(self.num_classes)(x)
    return x

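# Usage sketch (illustrative, not executed by the script): instantiating the
# full model with a small fixed-point solver and checking the logits shape on
# an MNIST-sized batch.  The solver settings and shapes below are arbitrary
# assumptions.
def _demo_full_deq_shapes():
  solver = partial(FixedPointIteration, maxiter=5, tol=1e-2)
  model = FullDEQ(num_classes=10, channels=24, channels_bottleneck=32,
                  fixed_point_solver=solver)
  x = jnp.ones((2, 28, 28, 1))
  variables = model.init(jax.random.PRNGKey(0), x, train=True)
  logits = model.apply(variables, x, train=False)
  return logits.shape  # (2, 10)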

# For completeness, we also allow Anderson acceleration for solving
# the implicit differentiation linear system occurring in the backward pass.
def solve_linear_system_fixed_point(matvec, v):
  """Solve linear system matvec(u) = v.

  The solution u* of the system is the fixed point of:
    T(u) = matvec(u) + u - v
  """
  def fixed_point_fun(u):
    d_1_T_transpose_u = tree_add(matvec(u), u)
    return tree_sub(d_1_T_transpose_u, v)

  aa = AndersonAcceleration(fixed_point_fun,
                            history_size=FLAGS.anderson_history_size, tol=1e-2,
                            ridge=FLAGS.anderson_ridge, maxiter=20)
  return aa.run(v)[0]

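# Sanity-check sketch (illustrative, not called by the script): the fixed
# point of T(u) = matvec(u) + u - v indeed solves matvec(u) = v.  The toy
# matrix is chosen so that T is a contraction, and the Anderson
# hyperparameters are hard-coded instead of read from FLAGS.
def _demo_linear_system_as_fixed_point():
  A = jnp.array([[-0.5, 0.1],
                 [0.0, -0.4]])
  v = jnp.array([1.0, 2.0])
  matvec = lambda u: A @ u
  T = lambda u: tree_sub(tree_add(matvec(u), u), v)
  aa = AndersonAcceleration(T, history_size=5, ridge=1e-4, tol=1e-6,
                            maxiter=100)
  u_star = aa.run(v).params
  return matvec(u_star)  # approximately equal to v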

def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Solver used for implicit differentiation (backward pass).
  if FLAGS.backward_solver == "normal_cg":
    implicit_solver = partial(solve_normal_cg, tol=1e-2, maxiter=20)
  elif FLAGS.backward_solver == "gmres":
    implicit_solver = partial(solve_gmres, tol=1e-2, maxiter=20)
  elif FLAGS.backward_solver == "anderson":
    implicit_solver = solve_linear_system_fixed_point

  # Solver used for fixed point resolution (forward pass).
  if FLAGS.forward_solver == "anderson":
    fixed_point_solver = partial(AndersonAcceleration,
                                 history_size=FLAGS.anderson_history_size,
                                 ridge=FLAGS.anderson_ridge,
                                 maxiter=FLAGS.forward_maxiter,
                                 tol=FLAGS.forward_tol, implicit_diff=True,
                                 implicit_diff_solve=implicit_solver)
  else:
    fixed_point_solver = partial(FixedPointIteration,
                                 maxiter=FLAGS.forward_maxiter,
                                 tol=FLAGS.forward_tol, implicit_diff=True,
                                 implicit_diff_solve=implicit_solver)
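
  # Note that ``fixed_point_solver`` is a solver constructor, not an instance:
  # it is passed to FullDEQ and instantiated inside DEQFixedPoint with
  # ``fixed_point_fun=block_apply``.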

  train_ds, ds_info = load_dataset("train", is_training=True,
                                    batch_size=FLAGS.train_batch_size)
  test_ds, _ = load_dataset("test", is_training=False,
                            batch_size=FLAGS.test_batch_size)
  input_shape = (1,) + ds_info.features["image"].shape
  num_classes = ds_info.features["label"].num_classes

  net = FullDEQ(num_classes, 3 * 8 * FLAGS.net_width, 4 * 8 * FLAGS.net_width,
                fixed_point_solver)

  def predict(all_params, images, train=False):
    """Forward pass in the network on the images."""
    x = images.astype(jnp.float32) / 255.
    mutable = ["batch_stats"] if train else False
    return net.apply(all_params, x, train=train, mutable=mutable)

  logistic_loss = jax.vmap(loss.multiclass_logistic_loss)

  def loss_from_logits(params, l2reg, logits, labels):
    sqnorm = tree_l2_norm(params, squared=True)
    mean_loss = jnp.mean(logistic_loss(labels, logits))
    return mean_loss + 0.5 * l2reg * sqnorm

  @jax.jit
  def accuracy_and_loss(params, l2reg, data, aux):
    all_vars = {"params": params, "batch_stats": aux}
    images, labels = data
    logits = predict(all_vars, images)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    loss = loss_from_logits(params, l2reg, logits, labels)
    return accuracy, loss

  @jax.jit
  def loss_fun(params, l2reg, data, aux):
    all_vars = {"params": params, "batch_stats": aux}
    images, labels = data
    logits, net_state = predict(all_vars, images, train=True)
    loss = loss_from_logits(params, l2reg, logits, labels)
    return loss, net_state["batch_stats"]

  def print_evaluation(params, state, l2reg, data, aux):
    # We don't need `data` because we evaluate on the test set.
    del data

    if state.iter_num % FLAGS.evaluation_frequency == 0:
      # Periodically evaluate classification accuracy on test set.
      accuracy, loss = accuracy_and_loss(params, l2reg, data=next(test_ds),
                                         aux=aux)

      print(f"[Step {state.iter_num}] "
            f"Test accuracy: {accuracy:.3f} "
            f"Test loss: {loss:.3f}.")

    return params, state

  # Initialize solver and parameters.
  solver = OptaxSolver(opt=optax.adam(FLAGS.learning_rate),
                       fun=loss_fun,
                       maxiter=FLAGS.maxiter,
                       pre_update=print_evaluation,
                       has_aux=True)

  rng = jax.random.PRNGKey(0)
  init_vars = net.init(rng, jnp.ones(input_shape), train=True)
  params = init_vars["params"]
  batch_stats = init_vars["batch_stats"]
  state = solver.init_state(params)

  for _ in range(solver.maxiter):
    params, state = solver.update(params=params, state=state,
                                  l2reg=FLAGS.l2reg, data=next(train_ds),
                                  aux=batch_stats)
    batch_stats = state.aux


if __name__ == "__main__":
  app.run(main)
