API at a glance

Optimization

Unconstrained

jaxopt.BFGS(fun[, value_and_grad, has_aux, ...])

BFGS solver.

jaxopt.GradientDescent(fun[, prox, ...])

Gradient Descent solver.

jaxopt.LBFGS(fun[, value_and_grad, has_aux, ...])

LBFGS solver.

jaxopt.ScipyMinimize([method, dtype, jit, ...])

scipy.optimize.minimize wrapper.

jaxopt.NonlinearCG(fun[, value_and_grad, ...])

Nonlinear conjugate gradient solver.

Constrained

jaxopt.LBFGSB(fun[, value_and_grad, ...])

L-BFGS-B solver.

jaxopt.MirrorDescent(fun, projection_grad, ...)

Mirror descent solver.

jaxopt.ProjectedGradient(fun, projection[, ...])

Projected gradient solver.

jaxopt.ScipyBoundedMinimize([method, dtype, ...])

scipy.optimize.minimize wrapper for bound-constrained problems.

Quadratic programming

jaxopt.BoxCDQP([maxiter, tol, verbose, ...])

Coordinate descent solver for box-constrained QPs.

jaxopt.BoxOSQP([matvec_Q, matvec_A, fun, ...])

Operator Splitting Solver for Quadratic Programs.

jaxopt.CvxpyQP([solver, implicit_diff_solve])

Wraps CVXPY's quadratic solver with implicit diff support.

jaxopt.EqualityConstrainedQP([matvec_Q, ...])

Quadratic programming with equality constraints only.

jaxopt.OSQP(*[, matvec_Q, matvec_A, ...])

OSQP solver for general quadratic programming.

Non-smooth

jaxopt.ProximalGradient(fun[, prox, ...])

Proximal gradient solver.

jaxopt.BlockCoordinateDescent(fun, block_prox)

Block coordinate descent solver.

Stochastic

jaxopt.ArmijoSGD(fun[, value_and_grad, ...])

SGD with Armijo line search.

jaxopt.OptaxSolver(fun, opt[, ...])

Optax solver.

jaxopt.PolyakSGD(fun[, value_and_grad, ...])

SGD with Polyak step size.

Loss functions

jaxopt.loss.binary_logistic_loss(label, logit)

Binary logistic loss.

jaxopt.loss.binary_sparsemax_loss(label, logit)

Binary sparsemax loss.

jaxopt.loss.binary_hinge_loss(label, score)

Binary hinge loss.

jaxopt.loss.binary_perceptron_loss(label, score)

Binary perceptron loss.

jaxopt.loss.sparse_plus(x)

Sparse plus function.

jaxopt.loss.sparse_sigmoid(x)

Sparse sigmoid function.

jaxopt.loss.huber_loss(target, pred[, delta])

Huber loss.

jaxopt.loss.multiclass_logistic_loss(label, ...)

Multiclass logistic loss.

jaxopt.loss.multiclass_sparsemax_loss(label, ...)

Multiclass sparsemax loss.

jaxopt.loss.multiclass_hinge_loss(label, scores)

Multiclass hinge loss.

jaxopt.loss.multiclass_perceptron_loss(...)

Multiclass perceptron loss.

Linear system solving

jaxopt.linear_solve.solve_lu(matvec, b)

Solves A x = b using an LU decomposition (jax.numpy.linalg.solve).

jaxopt.linear_solve.solve_cholesky(matvec, b)

Solves A x = b using a Cholesky decomposition.

jaxopt.linear_solve.solve_cg(matvec, b[, ...])

Solves A x = b using conjugate gradient.

jaxopt.linear_solve.solve_normal_cg(matvec, b)

Solves the normal equation A^T A x = A^T b using conjugate gradient.

jaxopt.linear_solve.solve_gmres(matvec, b[, ...])

Solves A x = b using GMRES.

jaxopt.linear_solve.solve_bicgstab(matvec, b)

Solves A x = b using BiCGSTAB.

jaxopt.IterativeRefinement([matvec_A, ...])

Iterative refinement algorithm.

Nonlinear least squares

jaxopt.GaussNewton(residual_fun[, maxiter, ...])

Gauss-Newton nonlinear least-squares solver.

jaxopt.LevenbergMarquardt(residual_fun[, ...])

Levenberg-Marquardt nonlinear least-squares solver.

Root finding

jaxopt.Bisection(optimality_fun, lower, upper)

One-dimensional root finding using bisection.

jaxopt.Broyden(fun[, has_aux, maxiter, tol, ...])

Limited-memory Broyden solver.

jaxopt.ScipyRootFinding([method, dtype, ...])

scipy.optimize.root wrapper.

Fixed point resolution

jaxopt.FixedPointIteration(fixed_point_fun)

Fixed point iteration method.

jaxopt.AndersonAcceleration(fixed_point_fun)

Anderson acceleration.

jaxopt.AndersonWrapper(solver[, ...])

Wrapper for accelerating JAXopt solvers.

Implicit differentiation

jaxopt.implicit_diff.custom_root(optimality_fun)

Decorator for adding implicit differentiation to a root solver.

jaxopt.implicit_diff.custom_fixed_point(...)

Decorator for adding implicit differentiation to a fixed point solver.

jaxopt.implicit_diff.root_jvp(...[, solve])

Jacobian-vector product of a root.

jaxopt.implicit_diff.root_vjp(...[, solve])

Vector-Jacobian product of a root.

Perturbed optimizers

jaxopt.perturbations.make_perturbed_argmax(...)

Transforms an argmax function into a differentiable version with perturbations.

jaxopt.perturbations.make_perturbed_max(...)

Turns an argmax into a differentiable version of the max with perturbations.

jaxopt.perturbations.make_perturbed_fun(fun)

Transforms a function into a differentiable version with perturbations.

jaxopt.perturbations.Gumbel()

Gumbel distribution.

jaxopt.perturbations.Normal()

Normal distribution.

Isotonic regression

jaxopt.isotonic.isotonic_l2_pav(y[, y_min, ...])

Solves an isotonic regression problem using the pool adjacent violators (PAV) algorithm.

Tree utilities

jaxopt.tree_util.tree_add(tree, *rest[, is_leaf])

Tree addition.

jaxopt.tree_util.tree_sub(tree, *rest[, is_leaf])

Tree subtraction.

jaxopt.tree_util.tree_mul(tree, *rest[, is_leaf])

Tree multiplication.

jaxopt.tree_util.tree_div(tree, *rest[, is_leaf])

Tree division.

jaxopt.tree_util.tree_scalar_mul(scalar, tree_x)

Compute scalar * tree_x.

jaxopt.tree_util.tree_add_scalar_mul(tree_x, ...)

Compute tree_x + scalar * tree_y.

jaxopt.tree_util.tree_vdot(tree_x, tree_y)

Compute the inner product <tree_x, tree_y>.

jaxopt.tree_util.tree_sum(tree_x)

Compute sum(tree_x).

jaxopt.tree_util.tree_l2_norm(tree_x[, squared])

Compute the l2 norm ||tree_x||.

jaxopt.tree_util.tree_zeros_like(tree_x)

Creates an all-zero tree with the same structure as tree_x.