jaxopt.implicit_diff.root_vjp

jaxopt.implicit_diff.root_vjp(optimality_fun, sol, args, cotangent, solve=<function solve_normal_cg>)

Vector-Jacobian product of a root.

The invariant is optimality_fun(sol, *args) == 0.
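The derivation below sketches where this product comes from. It uses the implicit function theorem and assumes that the Jacobian of optimality_fun with respect to sol is invertible at the root. Write F(x, θ) for optimality_fun(x, *args), with θ standing for args collectively, x*(θ) for sol, and v for cotangent:

```latex
\begin{aligned}
&F(x^\star(\theta), \theta) = 0
  \;\Longrightarrow\; A\,\partial x^\star(\theta) + B = 0,
  \quad A = \partial_1 F(x^\star(\theta), \theta),\;
        B = \partial_2 F(x^\star(\theta), \theta), \\
&\partial x^\star(\theta) = -A^{-1} B, \\
&v^\top \partial x^\star(\theta) = -v^\top A^{-1} B = u^\top B,
  \quad \text{where } A^\top u = -v.
\end{aligned}
```

In this formulation, the VJP needs one linear solve with A^T, which is the role of the solve parameter. It is followed by an ordinary VJP of optimality_fun with respect to args, evaluated at u.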

Parameters
  • optimality_fun (Callable) – the optimality function, called as optimality_fun(sol, *args), which evaluates to zero at the root.

  • sol (Any) – solution / root (pytree).

  • args (Tuple) – tuple containing the arguments with respect to which we wish to differentiate sol.

  • cotangent (Any) – vector to left-multiply the Jacobian with (pytree, same structure as sol).

  • solve (Callable) – a linear solver of the form x = solve(matvec, b), which returns the solution x of Ax = b, where matvec(x) = Ax.
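For small problems where the root is a single 1-D array, one option is a solver that builds the matrix from matvec and then calls a dense solve. The sketch below follows the solve(matvec, b) form described above. `dense_solve` is a hypothetical helper and is not part of jaxopt. It costs n matvec calls and O(n^2) memory, so it is only meant for small n.

```python
import jax
import jax.numpy as jnp


def dense_solve(matvec, b):
    """Hypothetical dense solver with the solve(matvec, b) form.

    Assumes b is a single 1-D array (not a general pytree) and that the
    operator behind matvec is square and invertible.
    """
    n = b.shape[0]
    # Column i of A is matvec(e_i); vmap applies matvec to every column of I.
    A = jax.vmap(matvec, in_axes=1, out_axes=1)(jnp.eye(n, dtype=b.dtype))
    return jnp.linalg.solve(A, b)


# Small check: solve [[3, 1], [1, 2]] x = [9, 8], whose solution is x = [2, 3].
M = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = dense_solve(lambda v: M @ v, b)
```

When sol is a single 1-D array, such a function could be passed as solve=dense_solve. For large problems, an iterative solver such as the default solve_normal_cg avoids forming the matrix.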

Return type

Any

Returns

tuple of length len(args) containing the VJP with respect to each argument. Each vjps[i] has the same pytree structure as args[i].