jaxopt.implicit_diff.root_jvp
- jaxopt.implicit_diff.root_jvp(optimality_fun, sol, args, tangents, solve=<function solve_normal_cg>)
Jacobian-vector product of a root.
The invariant is
optimality_fun(sol, *args) == 0.
- Parameters
optimality_fun (Callable) – the optimality function to use.
sol (Any) – solution / root (pytree).
args (Tuple) – tuple containing the arguments with respect to which to differentiate.
tangents (Tuple) – a tuple of the same size as len(args). Each tangents[i] has the same pytree structure as args[i].
solve (Callable) – a linear solver of the form solve(matvec, b).
- Return type
Any
- Returns
a pytree with the same structure as sol.