jaxopt.ScipyRootFinding
- class jaxopt.ScipyRootFinding(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, optimality_fun=None, tol=None, options=None, use_jacrev=True)[source]
Wrapper around scipy.optimize.root.
It supports pytrees and implicit differentiation.
- Parameters
method (Optional[str]) –
dtype (Optional[Any]) –
jit (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
optimality_fun (Callable) –
tol (Optional[float]) –
options (Optional[Dict[str, Any]]) –
use_jacrev (bool) –
- optimality_fun
a smooth vector function of the form optimality_fun(x, *args, **kwargs) whose root is to be found. It must return as output a PyTree with structure identical to x.
- Type
Callable
- method
the method argument for scipy.optimize.root. Should be one of
‘hybr’
‘lm’
‘broyden1’
‘broyden2’
‘anderson’
‘linearmixing’
‘diagbroyden’
‘excitingmixing’
‘krylov’
‘df-sane’
- Type
Optional[str]
- tol
the tol argument for scipy.optimize.root.
- Type
Optional[float]
- options
the options argument for scipy.optimize.root.
- Type
Optional[Dict[str, Any]]
- dtype
if not None, cast all NumPy arrays to this dtype. Note that some methods relying on FORTRAN code, such as the L-BFGS-B solver for scipy.optimize.minimize, require casting to float64.
- Type
Optional[Any]
- jit
whether to JIT-compile the JAX-based function value and gradient evaluations.
- Type
bool
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- Type
Optional[Callable]
- has_aux
whether optimality_fun outputs one (False) or more values (True). When True, it is assumed by default that optimality_fun(…)[0] is the value of the optimality function.
- Type
bool
- use_jacrev
whether to compute the Jacobian of optimality_fun using jax.jacrev (True) or jax.jacfwd (False).
- Type
bool
- __init__(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, optimality_fun=None, tol=None, options=None, use_jacrev=True)
- Parameters
method (Optional[str]) –
dtype (Optional[Any]) –
jit (bool) –
implicit_diff_solve (Optional[Callable]) –
has_aux (bool) –
optimality_fun (Optional[Callable]) –
tol (Optional[float]) –
options (Optional[Dict[str, Any]]) –
use_jacrev (bool) –
- Return type
None
Methods
- __init__([method, dtype, jit, ...])
- attribute_names()
- attribute_values()
- l2_optimality_error(params, *args, **kwargs) – Computes the L2 optimality error.
- run(init_params, *args, **kwargs) – Runs the solver.
Attributes
- dtype
alias of numpy.float64
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.