Implicit differentiation of lasso.

This example tunes the lasso regularization strength lam = exp(theta) by
gradient descent on the validation loss, differentiating through the inner
solver either implicitly (the default) or by unrolling its iterations.

from absl import app
from absl import flags

import jax
import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import OptaxSolver
from jaxopt import prox
from jaxopt import ProximalGradient
import optax

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing

flags.DEFINE_bool("unrolling", False, "Whether to use unrolling.")
flags.DEFINE_string("solver", "bcd", "Solver to use (bcd or pg).")
FLAGS = flags.FLAGS


def outer_objective(theta, init_inner, data):
  """Validation loss."""
  X_tr, X_val, y_tr, y_val = data
  # We use the bijective mapping lam = jnp.exp(theta) to ensure positivity.
  lam = jnp.exp(theta)

  if FLAGS.solver == "pg":
    solver = ProximalGradient(
        fun=objective.least_squares,
        prox=prox.prox_lasso,
        implicit_diff=not FLAGS.unrolling,
        maxiter=500)
  elif FLAGS.solver == "bcd":
    solver = BlockCoordinateDescent(
        fun=objective.least_squares,
        block_prox=prox.prox_lasso,
        implicit_diff=not FLAGS.unrolling,
        maxiter=500)
  else:
    raise ValueError("Unknown solver.")
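
  # When implicit_diff=True, the gradient of the solver's solution w.r.t. lam
  # is obtained by applying the implicit function theorem to the solver's
  # fixed-point conditions, so memory does not grow with maxiter. When
  # unrolling, JAX instead backpropagates through each solver iteration.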

  # The format is run(init_params, hyperparams_prox, *args, **kwargs)
  # where *args and **kwargs are passed to `fun`.
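  # For prox_lasso, hyperparams_prox is the l1 regularization strength lam.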
  w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params

  y_pred = jnp.dot(X_val, w_fit)
  loss_value = jnp.mean((y_pred - y_val) ** 2)

  # We return w_fit as auxiliary data.
  # Auxiliary data is stored in the optimizer state (see below).
  return loss_value, w_fit
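

def outer_gradient_sketch(theta, init_inner, data):
  """Illustrative only: gradient of the validation loss w.r.t. theta.

  Because the inner solver is differentiable (implicitly or by unrolling),
  `jax.grad` can be applied directly to `outer_objective`. OptaxSolver does
  this internally, so this helper is not called by `main`; it is a minimal
  sketch of how the hypergradient could be computed by hand.
  """
  # With has_aux=True, grad_fn returns a (gradient, aux) pair.
  grad_fn = jax.grad(outer_objective, has_aux=True)
  return grad_fn(theta, init_inner, data)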


def main(argv):
  del argv

  print("Solver:", FLAGS.solver)
  print("Unrolling:", FLAGS.unrolling)

  # Prepare data.
  # scikit-learn removed `load_boston`; the diabetes dataset is used here as
  # a drop-in regression substitute.
  X, y = datasets.load_diabetes(return_X_y=True)
  X = preprocessing.normalize(X)
  # data = (X_tr, X_val, y_tr, y_val)
  data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)

  # Initialize solver.
  solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
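  # has_aux=True tells the solver that outer_objective returns a (value, aux)
  # pair; the aux part (the fitted weights) is stored in the solver state.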
  theta = 1.0
  init_w = jnp.zeros(X.shape[1])
  state = solver.init_state(theta, init_inner=init_w, data=data)

  # Run outer loop.
  for _ in range(10):
    theta, state = solver.update(params=theta, state=state, init_inner=init_w,
                                 data=data)
    # The auxiliary data returned by the outer loss is stored in the state.
    init_w = state.aux
    print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")

if __name__ == "__main__":
  app.run(main)
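
The flags defined at the top of the script select the inner solver and the
differentiation mode. Assuming the file is saved as lasso_implicit_diff.py,
it can be run as, for example:

  python lasso_implicit_diff.py --solver=pg --unrolling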
