jaxopt.implicit_diff.custom_fixed_point
- jaxopt.implicit_diff.custom_fixed_point(fixed_point_fun, has_aux=False, solve=<function solve_normal_cg>, reference_signature=None)
Decorator for adding implicit differentiation to a fixed-point solver.
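The idea behind implicit differentiation of a fixed point: if x* = f(x*, theta), differentiating both sides gives (I - df/dx) dx* = df/dtheta dtheta. Gradients can therefore be computed from a linear system at the solution, without backpropagating through the solver's iterations. The block below is a minimal plain-JAX sketch of that idea, not jaxopt's implementation. `fixed_point_solve` and `f` are hypothetical names. The map is assumed to be a contraction in x, so both the forward iteration and the backward linear system can be solved by simple iteration. The count of 100 iterations is an arbitrary choice.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper illustrating the technique; this is NOT jaxopt's code.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point_solve(f, x0, theta):
    # Forward pass: plain fixed-point iteration x <- f(x, theta).
    # This loop is never differentiated; custom_vjp supplies the gradient.
    return jax.lax.fori_loop(0, 100, lambda _, x: f(x, theta), x0)


def _fixed_point_fwd(f, x0, theta):
    x_star = fixed_point_solve(f, x0, theta)
    # Keep the solution (plus x0 for its cotangent shape) for the backward pass.
    return x_star, (x_star, x0, theta)


def _fixed_point_bwd(f, res, g):
    x_star, x0, theta = res
    # VJPs of f at the solution, w.r.t. x and w.r.t. theta.
    _, vjp_x = jax.vjp(lambda x: f(x, theta), x_star)
    _, vjp_theta = jax.vjp(lambda t: f(x_star, t), theta)
    # Solve (I - df/dx)^T u = g, i.e. u = g + (df/dx)^T u, by iteration.
    # This converges because f is assumed to be a contraction in x.
    u = jax.lax.fori_loop(0, 100, lambda _, u: g + vjp_x(u)[0], g)
    # The solution does not depend on the initial guess, so x0 gets zero cotangent.
    return jnp.zeros_like(x0), vjp_theta(u)[0]


fixed_point_solve.defvjp(_fixed_point_fwd, _fixed_point_bwd)


def f(x, theta):
    # Contraction in x: |df/dx| = 0.5 * sech(x)**2 <= 0.5.
    return 0.5 * jnp.tanh(x) + theta


theta = jnp.array([0.1, -0.3, 0.7])
x0 = jnp.zeros(3)
x_star = fixed_point_solve(f, x0, theta)
# f acts elementwise, so the gradient of the sum is the Jacobian's diagonal.
grad_implicit = jax.grad(lambda t: jnp.sum(fixed_point_solve(f, x0, t)))(theta)
```

For this elementwise map, the implicit-function formula predicts dx*/dtheta = 1 / (1 - 0.5 * sech(x*)**2) per component.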
- Parameters
fixed_point_fun (Callable) – the fixed-point map, called as fixed_point_fun(params, *args). The invariant is fixed_point_fun(sol, *args) == sol at the solution sol.
has_aux (bool) – whether the decorated solver function returns auxiliary data. If True, the solver returns a pair whose first element is the solution, and only that element is differentiated implicitly.
solve (Callable) – a linear solver of the form solve(matvec, b), used to solve the linear system that arises when differentiating through the fixed point.
reference_signature (Optional[Callable]) – optional function whose signature (i.e. positional and keyword arguments) the solver and fixed-point functions are expected to agree with. Defaults to fixed_point_fun. The solver and fixed-point functions must share the same input signature, but either may be defined so that the correspondence is ambiguous (e.g. if both accept catch-all **kwargs). In that case, any function with an unambiguous signature can be provided here to satisfy custom_fixed_point’s requirement.
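Any callable with the solve(matvec, b) form can be passed as solve. As we understand jaxopt's backward pass, matvec multiplies by the transposed Jacobian of the optimality condition fixed_point_fun(x, *args) - x. That matrix is generally non-symmetric, which may explain why the default is solve_normal_cg rather than plain conjugate gradient. The sketch below is a hypothetical dense solver, `dense_solve`. It assumes b is a single 1-D array and that the matrix is small enough to materialize.

```python
import jax
import jax.numpy as jnp


def dense_solve(matvec, b):
    # Hypothetical solver with the solve(matvec, b) form.
    # Assumes b is one 1-D array and the implied matrix fits in memory.
    # matvec is linear, so its Jacobian (at any point) is exactly the matrix
    # it applies; we evaluate it at zero to build that matrix explicitly.
    mat = jax.jacfwd(matvec)(jnp.zeros_like(b))
    # Direct LU-based solve: no symmetry or definiteness is required.
    return jnp.linalg.solve(mat, b)


# Quick check with an explicit non-symmetric matrix.
A = jnp.array([[3.0, 1.0], [-1.0, 2.0]])
b = jnp.array([1.0, 2.0])
x = dense_solve(lambda u: A @ u, b)
```

Under these assumptions it could be passed as custom_fixed_point(fixed_point_fun, solve=dense_solve). For large or pytree-valued solutions, an iterative matrix-free solver such as the default is the more natural choice.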
- Returns
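A decorator for the solver function: custom_fixed_point(fixed_point_fun)(solver_fun) returns a version of solver_fun whose output is differentiated implicitly through the fixed-point condition rather than by unrolling the solver.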