jaxopt.implicit_diff.root_jvp

jaxopt.implicit_diff.root_jvp(optimality_fun, sol, args, tangents, solve=<function solve_normal_cg>)

Jacobian-vector product of a root with respect to args, in the direction given by tangents.

The invariant is optimality_fun(sol, *args) == 0.
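
For context, the derivative determined by this invariant follows from the implicit function theorem. The sketch below uses notation introduced here for illustration, not names from the library: F stands for optimality_fun, x*(θ) for sol viewed as a function of args, and v for tangents. It assumes that ∂₁F is invertible at the root.

```latex
F(x^\star(\theta), \theta) = 0
\;\Longrightarrow\;
\partial_1 F \,\frac{\partial x^\star}{\partial \theta} + \partial_2 F = 0
\;\Longrightarrow\;
\frac{\partial x^\star}{\partial \theta}\, v
  = -\left(\partial_1 F\right)^{-1} \partial_2 F \, v
```

Under this reading, the result is obtained by solving one linear system in ∂₁F, which is presumably the role of the solve argument.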

Parameters
  • optimality_fun (Callable) – the optimality function, called as optimality_fun(sol, *args); sol is a root of it.

  • sol (Any) – solution / root (pytree).

  • args (Tuple) – tuple containing the arguments with respect to which to differentiate.

  • tangents (Tuple) – a tuple of the same length as args. Each tangents[i] has the same pytree structure as args[i].

  • solve (Callable) – a linear solver of the form solve(matvec, b).
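
To illustrate the solve(matvec, b) form, here is a minimal sketch of a custom solver. The name dense_solve is hypothetical and not part of jaxopt, and the sketch assumes b is a flat array (not a general pytree) and that the linear operator is small enough to materialize as a dense matrix.

```python
import jax
import jax.numpy as jnp


def dense_solve(matvec, b):
    """Hypothetical solver: returns x such that matvec(x) == b, for a flat array b."""
    # matvec is linear, so its Jacobian at any point is the underlying matrix.
    A = jax.jacobian(matvec)(jnp.zeros_like(b))
    return jnp.linalg.solve(A, b)


# Small check on an explicit 2x2 system.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = dense_solve(lambda u: A @ u, b)
```

A function with this shape could be passed as solve=dense_solve. For large or pytree-valued problems, a matrix-free solver such as the default is likely the better choice.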

Return type

Any

Returns

the Jacobian-vector product, a pytree with the same structure as sol.