.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/implicit_diff/lasso_implicit_diff.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_implicit_diff_lasso_implicit_diff.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_implicit_diff_lasso_implicit_diff.py:


Implicit differentiation of lasso.
==================================

.. GENERATED FROM PYTHON SOURCE LINES 19-103

.. code-block:: default


    from absl import app
    from absl import flags

    import jax
    import jax.numpy as jnp

    from jaxopt import BlockCoordinateDescent
    from jaxopt import objective
    from jaxopt import OptaxSolver
    from jaxopt import prox
    from jaxopt import ProximalGradient
    import optax

    from sklearn import datasets
    from sklearn import model_selection
    from sklearn import preprocessing

    flags.DEFINE_bool("unrolling", False, "Whether to use unrolling.")
    flags.DEFINE_string("solver", "bcd", "Solver to use (bcd or pg).")
    FLAGS = flags.FLAGS


    def outer_objective(theta, init_inner, data):
        """Validation loss.

        Fits a lasso model on the training split with the inner solver
        selected by ``FLAGS.solver`` and evaluates it on the validation split.

        Args:
          theta: log of the lasso regularization strength; the inner problem
            is solved with ``lam = exp(theta)``.
          init_inner: initial weights handed to the inner solver.
          data: tuple ``(X_tr, X_val, y_tr, y_val)``.

        Returns:
          A pair ``(loss_value, w_fit)``: the mean squared error on the
          validation split and the fitted weights (returned as auxiliary
          data).

        Raises:
          ValueError: if ``FLAGS.solver`` is neither ``"pg"`` nor ``"bcd"``.
        """
        X_tr, X_val, y_tr, y_val = data
        # We use the bijective mapping lam = jnp.exp(theta) to ensure
        # positivity.
        lam = jnp.exp(theta)

        if FLAGS.solver == "pg":
            solver = ProximalGradient(
                fun=objective.least_squares,
                prox=prox.prox_lasso,
                implicit_diff=not FLAGS.unrolling,
                maxiter=500)
        elif FLAGS.solver == "bcd":
            solver = BlockCoordinateDescent(
                fun=objective.least_squares,
                block_prox=prox.prox_lasso,
                implicit_diff=not FLAGS.unrolling,
                maxiter=500)
        else:
            raise ValueError("Unknown solver.")

        # The format is run(init_params, hyperparams_prox, *args, **kwargs)
        # where *args and **kwargs are passed to `fun`.
        w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params

        y_pred = jnp.dot(X_val, w_fit)
        loss_value = jnp.mean((y_pred - y_val) ** 2)

        # We return w_fit as auxiliary data.
        # Auxiliary data is stored in the optimizer state (see below).
        return loss_value, w_fit


    def main(argv):
        """Run ten Adam steps on ``theta`` to minimize the validation loss.

        Args:
          argv: command-line arguments forwarded by ``app.run``; unused.
        """
        del argv

        print("Solver:", FLAGS.solver)
        print("Unrolling:", FLAGS.unrolling)

        # Prepare data.
        # `load_boston` was deprecated in scikit-learn 1.0 and removed in 1.2;
        # `load_diabetes` is a built-in regression dataset exposing the same
        # `return_X_y=True` interface.
        X, y = datasets.load_diabetes(return_X_y=True)
        X = preprocessing.normalize(X)
        # data = (X_tr, X_val, y_tr, y_val)
        data = model_selection.train_test_split(
            X, y, test_size=0.33, random_state=0)

        # Initialize solver.
        solver = OptaxSolver(
            opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
        theta = 1.0
        init_w = jnp.zeros(X.shape[1])
        state = solver.init_state(theta, init_inner=init_w, data=data)

        # Run outer loop.
        for _ in range(10):
            theta, state = solver.update(
                params=theta, state=state, init_inner=init_w, data=data)
            # The auxiliary data returned by the outer loss is stored in the
            # state; reuse it to warm-start the next inner solve.
            init_w = state.aux
            print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")


    if __name__ == "__main__":
        app.run(main)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_auto_examples_implicit_diff_lasso_implicit_diff.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: lasso_implicit_diff.py <lasso_implicit_diff.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: lasso_implicit_diff.ipynb <lasso_implicit_diff.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_