jaxopt.ScipyRootFinding

class jaxopt.ScipyRootFinding(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, optimality_fun=None, tol=None, options=None, use_jacrev=True)

Wrapper around scipy.optimize.root.

It supports pytrees and implicit differentiation.

Parameters
  • method (Optional[str]) –

  • dtype (Optional[Any]) –

  • jit (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • optimality_fun (Callable) –

  • tol (Optional[float]) –

  • options (Optional[Dict[str, Any]]) –

  • use_jacrev (bool) –

optimality_fun

a smooth vector function of the form optimality_fun(x, *args, **kwargs) whose root is to be found. It must return a pytree with the same structure as x.

Type

Callable

method

the method argument for scipy.optimize.root. It should be one of:

  • ‘hybr’

  • ‘lm’

  • ‘broyden1’

  • ‘broyden2’

  • ‘anderson’

  • ‘linearmixing’

  • ‘diagbroyden’

  • ‘excitingmixing’

  • ‘krylov’

  • ‘df-sane’

Type

Optional[str]

tol

the tol argument for scipy.optimize.root.

Type

Optional[float]

options

the options argument for scipy.optimize.root.

Type

Optional[Dict[str, Any]]

dtype

if not None, cast all NumPy arrays exchanged with SciPy to this dtype. Note that some SciPy methods relying on FORTRAN code, such as the L-BFGS-B solver of scipy.optimize.minimize, require casting to float64.

Type

Optional[Any]

jit

whether to JIT-compile the JAX-based evaluations of optimality_fun and its Jacobian.

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

has_aux

whether optimality_fun outputs one value (False) or more values (True). When True, it is assumed by default that optimality_fun(…)[0] is the optimality function.

Type

bool

use_jacrev

whether to compute the Jacobian of optimality_fun using jax.jacrev (True) or jax.jacfwd (False).

Type

bool

__init__(method=None, dtype=<class 'numpy.float64'>, jit=True, implicit_diff_solve=None, has_aux=False, optimality_fun=None, tol=None, options=None, use_jacrev=True)

Parameters
  • method (Optional[str]) –

  • dtype (Optional[Any]) –

  • jit (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • optimality_fun (Optional[Callable]) –

  • tol (Optional[float]) –

  • options (Optional[Dict[str, Any]]) –

  • use_jacrev (bool) –

Return type

None

Methods

__init__([method, dtype, jit, ...])

attribute_names()

attribute_values()

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_params, *args, **kwargs)

Runs the solver.

Attributes

has_aux

implicit_diff_solve

jit

method

optimality_fun

options

tol

use_jacrev

dtype

alias of float64

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_params, *args, **kwargs)

Runs the solver.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to optimality_fun.

  • **kwargs – additional keyword arguments to be passed to optimality_fun.

Return type

OptStep

Returns

(params, info).