Linear system solving

This section is concerned with solving problems of the form

\[Ax = b\]

with unknown \(x\) for a linear operator \(A\) and vector \(b\). The solvers in jaxopt.linear_solve specify the operator \(A\) through a function matvec that computes the matrix-vector product \(Ax\).

Indirect solvers

jaxopt.linear_solve.solve_cg(matvec, b[, ...])

Solves A x = b using conjugate gradient; assumes A is Hermitian positive definite.

jaxopt.linear_solve.solve_normal_cg(matvec, b)

Solves the normal equation A^T A x = A^T b using conjugate gradient.

jaxopt.linear_solve.solve_gmres(matvec, b[, ...])

Solves A x = b using GMRES.

jaxopt.linear_solve.solve_bicgstab(matvec, b)

Solves A x = b using BiCGSTAB.

Indirect solvers iteratively solve the linear system up to some precision. Example:

from jaxopt import linear_solve
import jax.numpy as jnp
import numpy as onp

onp.random.seed(42)
A = onp.random.randn(3, 3)
b = onp.random.randn(3)

def matvec_A(x):
  return jnp.dot(A, x)

sol = linear_solve.solve_normal_cg(matvec_A, b, tol=1e-5)
print(sol)

sol = linear_solve.solve_gmres(matvec_A, b, tol=1e-5)
print(sol)

sol = linear_solve.solve_bicgstab(matvec_A, b, tol=1e-5)
print(sol)

The above solvers support ridge regularization via the ridge option and can be warm-started via the init option. Other options, such as tol or maxiter, are also supported.
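For instance, ridge regularization stabilizes ill-conditioned systems, and warm-starting from a nearby solution can reduce the number of iterations. A minimal sketch, reusing matvec_A and b from above (the ridge value 1e-3 and the choice of warm-start vector are illustrative, not library defaults):

# Ridge regularization: normal CG solves (A^T A + ridge * I) x = A^T b.
sol_ridge = linear_solve.solve_normal_cg(matvec_A, b, ridge=1e-3, tol=1e-5)

# Warm start: begin the GMRES iterations from a previous solution.
sol_warm = linear_solve.solve_gmres(matvec_A, b, init=sol_ridge,
                                    tol=1e-5, maxiter=100)
print(sol_warm)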

Direct solvers

jaxopt.linear_solve.solve_lu(matvec, b)

Solves A x = b using an LU decomposition.

jaxopt.linear_solve.solve_cholesky(matvec, b)

Solves A x = b using a Cholesky decomposition; assumes A is Hermitian positive definite.

Direct solvers are based on matrix decompositions. They materialize the operator as an explicit matrix in memory, which makes them best suited to small and medium-sized systems.

Example:

from jaxopt import linear_solve
import jax.numpy as jnp
import numpy as onp

onp.random.seed(42)
A = onp.random.randn(3, 3)
b = onp.random.randn(3)

def matvec_A(x):
  return jnp.dot(A, x)

sol = linear_solve.solve_lu(matvec_A, b)
print(sol)
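Note that solve_cholesky requires a Hermitian positive-definite operator, which the random A above generally is not. A minimal sketch, first building a positive-definite matrix from A (the shift 1e-3 is an illustrative choice):

# A^T A + eps * I is symmetric positive definite for any eps > 0.
C = A.T @ A + 1e-3 * onp.eye(3)

def matvec_C(x):
  return jnp.dot(C, x)

sol = linear_solve.solve_cholesky(matvec_C, b)
print(sol)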

Iterative refinement

jaxopt.IterativeRefinement([matvec_A, ...])

Iterative refinement algorithm.

Iterative refinement is a meta-algorithm for solving the linear system Ax = b based on a provided linear system solver. Our implementation is a slight generalization of the standard algorithm. It starts with \((r_0, x_0) = (b, 0)\) and iterates

\[\begin{aligned}
x &= \text{solution of } \bar{A} x = r_{t-1} \\
x_t &= x_{t-1} + x \\
r_t &= b - A x_t
\end{aligned}\]

where \(\bar{A}\) is some approximation of \(A\), preferably with better conditioning than \(A\). By default, we use \(\bar{A} = A\), which recovers the standard iterative refinement algorithm. This method has the advantage of converging even if the solve step is inaccurate, which is particularly useful for ill-posed problems.
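To see why, assume the inner solve of \(\bar{A} x = r_{t-1}\) is exact and substitute the update into the residual; the residual then obeys the recursion

\[r_t = \left(I - A \bar{A}^{-1}\right) r_{t-1} = \left(I - A \bar{A}^{-1}\right)^t b,\]

so it vanishes whenever the spectral radius of \(I - A \bar{A}^{-1}\) is smaller than one, that is, whenever \(\bar{A}^{-1}\) is even a rough approximation of \(A^{-1}\). Example: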

from functools import partial
import jax.numpy as jnp
import numpy as onp
from jaxopt import IterativeRefinement
from jaxopt.linear_solve import solve_gmres

# ill-conditioned linear system
A = jnp.array([[3.9, 1.65], [6.845, 2.9]])
b = jnp.array([5.5, 9.7])
print(f"Condition number: {onp.linalg.cond(A):.0f}")
# Condition number: 4647

ridge = 1e-2
tol = 1e-7

x = solve_gmres(lambda x: jnp.dot(A, x), b, tol=tol)
print(f"GMRES only error: {jnp.linalg.norm(A @ x - b):.7f}")
# GMRES only error: nan

solve_gmres_ridge = partial(solve_gmres, ridge=ridge)

x_ridge = solve_gmres_ridge(lambda x: jnp.dot(A, x), b, tol=tol)
print(f"GMRES+ridge error: {jnp.linalg.norm(A @ x_ridge - b):.7f}")
# GMRES+ridge error: 0.0333328

solver = IterativeRefinement(solve=solve_gmres_ridge,
                             tol=tol, maxiter=100)
x_refined, state = solver.run(init_params=None, params_A=A, b=b)
print(f"Iterativement Refinement error: {jnp.linalg.norm(A @ x_refined - b):.7f}")
# Iterativement Refinement error: 0.0000000
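In the notation above, the ridge-regularized solver plays the role of \(\bar{A}^{-1}\) with \(\bar{A} = A + \text{ridge}\, I\): each inner solve is only approximate, yet refinement drives the residual of the original, unregularized system to zero.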