jaxopt.implicit_diff.custom_root

jaxopt.implicit_diff.custom_root(optimality_fun, has_aux=False, solve=<function solve_normal_cg>, reference_signature=None)

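Decorator that adds implicit differentiation to a root solver. Gradients of the solver's output with respect to its arguments are computed from optimality_fun via the implicit function theorem, rather than by differentiating through the solver's iterations.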

Parameters
  • optimality_fun (Callable) – the equation function, called as optimality_fun(params, *args). At the solution (root) sol, it satisfies optimality_fun(sol, *args) == 0.

  • has_aux (bool) – whether the decorated solver function returns auxiliary data. If True, the solver returns a pair (sol, aux), and only sol is differentiated implicitly.

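  • solve (Callable) – the linear solver used for the linear system that arises during implicit differentiation. It has the form solve(matvec, b) and returns x such that matvec(x) = b.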

  • reference_signature (Optional[Callable]) – an optional function whose signature (positional and keyword arguments) the solver and optimality functions are expected to agree with. Defaults to optimality_fun. The solver and optimality functions must share the same input signature, but either may be defined so that the correspondence is ambiguous (e.g. if both accept a catch-all **kwargs). In that case, pass any function with an unambiguous signature here to satisfy custom_root's requirement.

Returns

A solver-function decorator: custom_root(optimality_fun)(solver_fun) returns a version of solver_fun whose output is differentiated implicitly.