jaxopt.linear_solve.solve_normal_cg

jaxopt.linear_solve.solve_normal_cg(matvec, b, ridge=None, init=None, **kwargs)

Solves the normal equation A^T A x = A^T b using conjugate gradient.

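This can be used to solve Ax = b with conjugate gradient when A is not Hermitian positive definite, because A^T A is always symmetric positive semidefinite (and positive definite when A has full column rank).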

Parameters
  • matvec (Callable) – function computing the product between A and a vector, i.e. matvec(x) = A x.

  • b (Any) – right-hand side b, given as a pytree.

  • ridge (Optional[float]) – optional ridge regularization strength; when set, the solver works on (A^T A + ridge * I) x = A^T b instead.

  • init (Optional[Any]) – optional initial guess for normal conjugate gradient, with the same pytree structure as the solution.

  • **kwargs – additional keyword arguments passed to the underlying conjugate gradient solver.

Return type

Any

Returns

A pytree with the same structure as b.