jaxopt.implicit_diff.root_jvp
- jaxopt.implicit_diff.root_jvp(optimality_fun, sol, args, tangents, solve=<function solve_normal_cg>)
Jacobian-vector product of a root.
The invariant is optimality_fun(sol, *args) == 0, i.e. sol is a root of optimality_fun for the given args.
- Parameters
  - optimality_fun (Callable) – the optimality function, called as optimality_fun(sol, *args).
  - sol (Any) – solution / root (pytree).
  - args (Tuple) – tuple containing the arguments with respect to which to differentiate.
  - tangents (Tuple) – a tuple of the same length as args. Each tangents[i] has the same pytree structure as args[i].
  - solve (Callable) – a linear solver of the form solve(matvec, b), which returns x solving matvec(x) = b. Defaults to solve_normal_cg.
- Return type
  Any
- Returns
  the Jacobian-vector product of sol with respect to args along tangents, as a pytree with the same structure as sol.