jaxopt.ScipyMinimize

class jaxopt.ScipyMinimize(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, fun=None, callback=None, tol=None, options=None, maxiter=500, value_and_grad=False)

scipy.optimize.minimize wrapper

This wrapper is for unconstrained minimization only. It supports pytrees and implicit differentiation.

Parameters
  • method (Optional[str]) –

  • dtype (Optional[Any]) –

  • jit (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • fun (Callable) –

  • callback (Callable) –

  • tol (Optional[float]) –

  • options (Optional[Dict[str, Any]]) –

  • maxiter (int) –

  • value_and_grad (Union[bool, Callable]) –

fun

a smooth function of the form fun(x, *args, **kwargs).

Type

Callable

method

the method argument for scipy.optimize.minimize. Should be one of:
  • ‘Nelder-Mead’
  • ‘Powell’
  • ‘CG’
  • ‘BFGS’
  • ‘Newton-CG’
  • ‘L-BFGS-B’
  • ‘TNC’
  • ‘COBYLA’
  • ‘SLSQP’
  • ‘trust-constr’
  • ‘dogleg’
  • ‘trust-ncg’
  • ‘trust-exact’
  • ‘trust-krylov’

Type

Optional[str]

tol

the tol argument for scipy.optimize.minimize.

Type

Optional[float]

options

the options argument for scipy.optimize.minimize.

Type

Optional[Dict[str, Any]]

callback

called after each iteration, as callback(xk), where xk is the current parameter vector.

Type

Callable

dtype

if not None, cast all NumPy arrays to this dtype. Note that some methods relying on FORTRAN code, such as the L-BFGS-B solver for scipy.optimize.minimize, require casting to float64.

Type

Optional[Any]

jit

whether to JIT-compile the JAX-based value and gradient evaluations.

Type

bool

implicit_diff_solve

the linear system solver used for implicit differentiation.

Type

Optional[Callable]

has_aux

whether fun outputs one value (False) or more values (True). When True, fun(…)[0] is assumed by default to be the objective.

Type

bool

value_and_grad

whether fun just returns the value (False) or both the value and gradient (True). See base.make_funs_with_aux for more detail.

Type

Union[bool, Callable]

__init__(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, fun=None, callback=None, tol=None, options=None, maxiter=500, value_and_grad=False)
Parameters
  • method (Optional[str]) –

  • dtype (Optional[Any]) –

  • jit (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • fun (Optional[Callable]) –

  • callback (Optional[Callable]) –

  • tol (Optional[float]) –

  • options (Optional[Dict[str, Any]]) –

  • maxiter (int) –

  • value_and_grad (Union[bool, Callable]) –

Return type

None

Methods

__init__([method, dtype, jit, ...])

attribute_names()

attribute_values()

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(sol, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the solver.

Attributes

callback

fun

has_aux

implicit_diff_solve

jit

maxiter

method

options

tol

value_and_grad

dtype

alias of float64

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(sol, *args, **kwargs)

Optimality function mapping compatible with @custom_root.

run(init_params, *args, **kwargs)

Runs the solver.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, info).