jaxopt.implicit_diff.root_vjp
- jaxopt.implicit_diff.root_vjp(optimality_fun, sol, args, cotangent, solve=<function solve_normal_cg>)
Vector-Jacobian product of a root.
The invariant is optimality_fun(sol, *args) == 0.
- Parameters
  - optimality_fun (Callable) – the optimality function to use.
  - sol (Any) – solution / root (pytree).
  - args (Tuple) – tuple containing the arguments with respect to which sol is differentiated.
  - cotangent (Any) – vector to left-multiply the Jacobian with (pytree, same structure as sol).
  - solve (Callable) – a linear solver of the form x = solve(matvec, b), where matvec(x) = Ax and the returned x solves Ax = b.
- Return type
  Any
- Returns
  A tuple of length len(args) containing the VJPs with respect to each argument. Writing vjps for this tuple, each vjps[i] has the same pytree structure as args[i].
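The docstring does not say why a linear solver is needed. The standard implicit-function-theorem derivation below explains it. It assumes the optimality function is differentiable and its Jacobian with respect to the root is invertible there.

The symbols F, θ, A, B, v and u are notation introduced here. F(x, θ) stands for optimality_fun(x, *args), θ for args, and v for cotangent. This sketches the underlying math, not JAXopt's exact code path:

```latex
\begin{aligned}
0 &= \tfrac{d}{d\theta} F\bigl(x^\star(\theta), \theta\bigr) = A\,\partial x^\star(\theta) + B,
  \qquad A = \partial_1 F(x^\star, \theta),\quad B = \partial_2 F(x^\star, \theta) \\
\partial x^\star(\theta) &= -A^{-1} B \\
v^\top \partial x^\star(\theta) &= -\bigl(A^{-\top} v\bigr)^\top B = u^\top B,
  \qquad \text{where } A^\top u = -v
\end{aligned}
```

The VJP of the root therefore costs one linear solve with the transposed Jacobian A^T, plus one ordinary VJP through B. A linear solver such as solve fills the first role.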
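The sketch below re-derives the same computation with plain JAX so the pieces above can be checked numerically. Everything in it is hypothetical and not part of JAXopt:

- root_vjp_sketch, dense_solve and linear_optimality are illustrative names.
- It handles only flat 1-D arrays, not general pytrees.
- It swaps the default solve_normal_cg for a dense solve that materializes the matrix.
- dense_solve follows the documented x = solve(matvec, b) form.

The result is cross-checked against differentiating jnp.linalg.solve directly.

```python
import jax
import jax.numpy as jnp


def dense_solve(matvec, b):
    """Hypothetical solver: returns x with M x = b, given only x -> M x.

    M is materialized as the Jacobian of the linear map matvec, so this
    only suits small, flat (1-D) arrays.
    """
    M = jax.jacfwd(matvec)(b)
    return jnp.linalg.solve(M, b)


def root_vjp_sketch(optimality_fun, sol, args, cotangent, solve=dense_solve):
    """Hypothetical implicit-function-theorem VJP of a root (flat arrays only)."""
    # matvec(u) = A^T u, where A is the Jacobian of optimality_fun w.r.t. sol.
    _, vjp_sol = jax.vjp(lambda x: optimality_fun(x, *args), sol)
    matvec = lambda u: vjp_sol(u)[0]
    # Solve A^T u = -cotangent.
    u = solve(matvec, -cotangent)
    # Return B^T u, one entry per element of args (B = Jacobian w.r.t. args).
    _, vjp_args = jax.vjp(lambda *a: optimality_fun(sol, *a), *args)
    return vjp_args(u)


# Hypothetical example: x solves M x = b, so the optimality function is M x - b.
def linear_optimality(x, M, b):
    return M @ x - b


M = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 4.0])
sol = jnp.linalg.solve(M, b)
v = jnp.array([1.0, -1.0])

vjp_M, vjp_b = root_vjp_sketch(linear_optimality, sol, (M, b), v)

# Cross-check: differentiate through jnp.linalg.solve directly.
_, vjp_direct = jax.vjp(jnp.linalg.solve, M, b)
ref_M, ref_b = vjp_direct(v)
```

Materializing the matrix costs a full Jacobian. For larger problems, a matrix-free iterative solver that only calls matvec, such as the default solve_normal_cg, is the more natural fit.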