JAXopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

  • Hardware accelerated: our implementations run on GPU and TPU, in addition to CPU.

  • Batchable: multiple instances of the same optimization problem can be automatically vectorized using JAX’s vmap.

  • Differentiable: optimization problem solutions can be differentiated with respect to their inputs, either implicitly or via autodiff of unrolled algorithm iterations (see the sketch after this list).
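
As a minimal sketch of all three properties, the snippet below solves a toy ridge-regression problem with jaxopt.GradientDescent. The objective, data, and helper names (ridge_loss, solution_sqnorm, l2reg) are illustrative assumptions, not part of the JAXopt API; only GradientDescent, run, and the returned (params, state) pair come from the library.

import jax
import jax.numpy as jnp
import jaxopt

# Illustrative ridge-regression objective; l2reg is a hyperparameter.
def ridge_loss(params, l2reg, X, y):
    residuals = X @ params - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

# Toy data, purely for illustration.
X = jnp.ones((10, 3))
y = jnp.ones(10)
init = jnp.zeros(3)

solver = jaxopt.GradientDescent(fun=ridge_loss, maxiter=500)

# Hardware accelerated: this runs on CPU, GPU or TPU like any JAX computation.
params, state = solver.run(init, l2reg=0.1, X=X, y=y)

# Batchable: vectorize the same problem over several hyperparameter values.
l2regs = jnp.linspace(0.01, 1.0, 8)
batched_params = jax.vmap(
    lambda r: solver.run(init, l2reg=r, X=X, y=y).params)(l2regs)

# Differentiable: gradient of a function of the solution with respect to the
# hyperparameter, computed by implicit differentiation (the solver's default).
def solution_sqnorm(l2reg):
    return jnp.sum(solver.run(init, l2reg=l2reg, X=X, y=y).params ** 2)

dnorm_dl2reg = jax.grad(solution_sqnorm)(0.1)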

Installation

To install the latest release of JAXopt, use the following command:

pip install jaxopt

To install the development version, use the following command instead:

pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

python setup.py install

Support

If you run into issues, please let us know on our issue tracker.

License

JAXopt is licensed under the Apache 2.0 License.

Citing

If this software is useful to you, please consider citing the paper that describes its implicit differentiation framework:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
          and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian
          and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}