jaxopt.implicit_diff.custom_root
- jaxopt.implicit_diff.custom_root(optimality_fun, has_aux=False, solve=<function solve_normal_cg>, reference_signature=None)
Decorator for adding implicit differentiation to a root solver.
- Parameters
optimality_fun (Callable) – an equation function, optimality_fun(params, *args). The invariant is optimality_fun(sol, *args) == 0 at the solution / root sol.
has_aux (bool) – whether the decorated solver function returns auxiliary data.
solve (Callable) – a linear solver of the form solve(matvec, b).
reference_signature (Optional[Callable]) – optional function signature (i.e. arguments and keyword arguments) with which the solver and optimality functions are expected to agree. Defaults to optimality_fun. The solver and optimality functions are required to share the same input signature, but both might be defined in such a way that the signature correspondence is ambiguous (e.g. if both accept catch-all **kwargs). To satisfy custom_root's requirement, any function with an unambiguous signature can be provided here.
- Returns
A solver function decorator, i.e., custom_root(optimality_fun)(solver_fun).