Linear system solving
This section is concerned with solving problems of the form
with unknown \(x\) for a linear operator \(A\) and vector \(b\).
Indirect solvers
|
Solves |
|
Solves the normal equation |
|
Solves |
|
Solves |
Indirect solvers iteratively solve the linear system up to some precision. Example:
from jaxopt import linear_solve
import numpy as onp
onp.random.seed(42)
A = onp.random.randn(3, 3)
b = onp.random.randn(3)
def matvec_A(x):
return jnp.dot(A, x)
sol = linear_solve.solve_normal_cg(matvec_A, b, tol=1e-5)
print(sol)
sol = linear_solve.solve_gmres(matvec_A, b, tol=1e-5)
print(sol)
sol = linear_solve.solve_bicgstab(matvec_A, b, tol=1e-5)
print(sol)
The above solvers support ridge regularization with the ridge
option.
They can be warm-started using the init
option.
Other options, such as tol
or maxiter
, are also supported.
Direct solvers
|
Solves |
|
Solves |
Direct solvers are based on matrix decompositions. They need to materialize the matrix in memory.
Example:
from jaxopt import linear_solve
import numpy as onp
onp.random.seed(42)
A = onp.random.randn(3, 3)
b = onp.random.randn(3)
def matvec_A(x):
return jnp.dot(A, x)
sol = linear_solve.solve_lu(matvec_A, b)
print(sol)
Iterative refinement
|
Iterativement refinement algorithm. |
Iterative refinement
is a meta-algorithm for solving the linear system Ax = b
based on
a provided linear system solver. Our implementation is a slight generalization
of the standard algorithm. It starts with \((r_0, x_0) = (b, 0)\) and
iterates
where \(\bar{A}\) is some approximation of A, with preferably better preconditonning than A. By default, we use \(\bar{A} = A\), which is the standard iterative refinement algorithm. This method has the advantage of converging even if the solve step is inaccurate. This is particularly useful for ill-posed problems. Example:
from functools import partial
import jax.numpy as jnp
import numpy as onp
from jaxopt import IterativeRefinement
from jaxopt.linear_solve import solve_gmres
# ill-conditioned linear system
A = jnp.array([[3.9, 1.65], [6.845, 2.9]])
b = jnp.array([5.5, 9.7])
print(f"Condition number: {onp.linalg.cond(A):.0f}")
# Condition number: 4647
ridge = 1e-2
tol = 1e-7
x = solve_gmres(lambda x: jnp.dot(A, x), b, tol=tol)
print(f"GMRES only error: {jnp.linalg.norm(A @ x - b):.7f}")
# GMRES only error: nan
solve_gmres_ridge = partial(solve_gmres, ridge=ridge)
x_ridge = solve_gmres_ridge(lambda x: jnp.dot(A, x), b, tol=tol, ridge=ridge)
print(f"GMRES+ridge error: {jnp.linalg.norm(A @ x_ridge - b):.7f}")
# GMRES+ridge error: 0.0333328
solver = IterativeRefinement(solve=solve_gmres_ridge,
tol=tol, maxiter=100)
x_refined, state = solver.run(init_params=None, params_A=A, b=b)
print(f"Iterativement Refinement error: {jnp.linalg.norm(A @ x_refined - b):.7f}")
# Iterativement Refinement error: 0.0000000