jaxopt.ScipyMinimize
- class jaxopt.ScipyMinimize(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, fun=None, callback=None, tol=None, options=None, maxiter=500, value_and_grad=False)
Wrapper around scipy.optimize.minimize.
This wrapper handles unconstrained minimization only. It supports pytrees and implicit differentiation.
- Parameters
method (Optional[str])
dtype (Optional[Any])
jit (bool)
implicit_diff_solve (Optional[Callable])
has_aux (bool)
fun (Callable)
callback (Callable)
tol (Optional[float])
options (Optional[Dict[str, Any]])
maxiter (int)
value_and_grad (Union[bool, Callable])
- fun
a smooth function of the form fun(x, *args, **kwargs).
- Type
Callable
- method
the method argument for scipy.optimize.minimize. Should be one of:
* ‘Nelder-Mead’
* ‘Powell’
* ‘CG’
* ‘BFGS’
* ‘Newton-CG’
* ‘L-BFGS-B’
* ‘TNC’
* ‘COBYLA’
* ‘SLSQP’
* ‘trust-constr’
* ‘dogleg’
* ‘trust-ncg’
* ‘trust-exact’
* ‘trust-krylov’
- Type
Optional[str]
- tol
the tol argument for scipy.optimize.minimize.
- Type
Optional[float]
- options
the options argument for scipy.optimize.minimize.
- Type
Optional[Dict[str, Any]]
- callback
called after each iteration, as callback(xk), where xk is the current parameter vector.
- Type
Callable
- dtype
if not None, cast all NumPy arrays to this dtype. Note that some methods relying on FORTRAN code, such as the L-BFGS-B solver for scipy.optimize.minimize, require casting to float64.
- Type
Optional[Any]
- jit
whether to JIT-compile the JAX-based value and gradient evaluations.
- Type
bool
- implicit_diff_solve
the linear system solver used for implicit differentiation.
- Type
Optional[Callable]
- has_aux
whether fun outputs one value (False) or more (True). When True, fun(…)[0] is assumed by default to be the objective.
- Type
bool
- value_and_grad
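whether fun returns only the objective value (False) or a (value, gradient) pair (True); a Callable may also be given. See base.make_funs_with_aux for more detail.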
- Type
Union[bool, Callable]
- __init__(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, fun=None, callback=None, tol=None, options=None, maxiter=500, value_and_grad=False)
- Parameters
method (Optional[str])
dtype (Optional[Any])
jit (bool)
implicit_diff_solve (Optional[Callable])
has_aux (bool)
fun (Optional[Callable])
callback (Optional[Callable])
tol (Optional[float])
options (Optional[Dict[str, Any]])
maxiter (int)
value_and_grad (Union[bool, Callable])
- Return type
None
Methods
__init__([method, dtype, jit, ...])
attribute_names()
attribute_values()
l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
optimality_fun(sol, *args, **kwargs): Optimality function mapping compatible with @custom_root.
run(init_params, *args, **kwargs): Runs the solver.
Attributes
maxiter
- dtype
alias of float64
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.