jaxopt.implicit_diff.root_vjp
- jaxopt.implicit_diff.root_vjp(optimality_fun, sol, args, cotangent, solve=<function solve_normal_cg>)
Vector-Jacobian product of a root.
The invariant is optimality_fun(sol, *args) == 0, i.e. sol is a root of optimality_fun in its first argument for the given args.
- Parameters
optimality_fun (Callable) – the optimality function to use.
sol (Any) – solution / root (pytree).
args (Tuple) – tuple containing the arguments with respect to which we wish to differentiate sol.
cotangent (Any) – vector with which to left-multiply the Jacobian of sol with respect to args (pytree, same structure as sol).
solve (Callable) – a linear solver of the form x = solve(matvec, b), where matvec(x) = Ax and the returned x satisfies Ax = b. Defaults to solve_normal_cg.
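The solve parameter only has to map a linear operator, given as the function matvec, and a right-hand side b to the x that satisfies Ax = b. A minimal sketch, assuming b is a flat 1-D array rather than a general pytree; dense_solve is a hypothetical name, not a jaxopt function. Because matvec is linear, its Jacobian is the matrix A itself, which can be materialized with jax.jacfwd and passed to a direct solver.

```python
import jax
import jax.numpy as jnp

def dense_solve(matvec, b):
    # Hypothetical solver; assumes b is a 1-D array, not an arbitrary pytree.
    # matvec is linear, so its Jacobian (at any point) is the matrix A.
    A = jax.jacfwd(matvec)(jnp.zeros_like(b))
    # Solve A x = b with a dense direct solver.
    return jnp.linalg.solve(A, b)

# Small check: matvec(v) = A v for a known 2x2 system.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = dense_solve(lambda v: A @ v, b)
print(x)
```

Such a callable could be passed as solve=dense_solve. Materializing A costs one forward-mode pass per unknown and O(n^2) memory, so this only suits small problems; the default solve_normal_cg stays matrix-free.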
- Return type
Any
- Returns
tuple of length len(args) containing the VJPs with respect to each argument. Each vjps[i] has the same pytree structure as args[i].