jaxopt.IterativeRefinement
- class jaxopt.IterativeRefinement(matvec_A=None, matvec_A_bar=None, solve=functools.partial(<function solve_gmres>, ridge=1e-06), maxiter=10, tol=1e-07, verbose=0, implicit_diff_solve=None, jit='auto', unroll='auto')
Iterative refinement algorithm.
This is a meta-algorithm for solving the linear system \(Ax = b\) based on a provided linear system solver. Our implementation is a slight generalization of the standard algorithm. It starts with \((r_0, x_0) = (b, 0)\) and iterates
\[\begin{aligned} \delta_t &= \text{solution of } \bar{A} \delta = r_{t-1}\\ x_t &= x_{t-1} + \delta_t\\ r_t &= b - A x_t \end{aligned}\]
where \(\bar{A}\) is some approximation of \(A\), preferably better conditioned (easier to solve with) than \(A\). By default, we use \(\bar{A} = A\), which is the standard iterative refinement algorithm.
This method can converge even if the solve step is inaccurate (for example, run in low precision), as long as each correction is accurate enough to shrink the error. This is particularly useful for ill-posed problems.
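The iteration above can be sketched in NumPy. This is a minimal illustration, not jaxopt's implementation. Here \(\bar{A} = A\), but the inner solve deliberately runs in float32, while the residual \(r_t = b - A x_t\) is computed in float64. The helper names `iterative_refinement` and `low_precision_solve` are hypothetical.

```python
import numpy as np

def iterative_refinement(A, b, inner_solve, maxiter=10, tol=1e-12):
    """x_t = x_{t-1} + delta_t, where delta_t approximately solves A_bar delta = r_{t-1}."""
    x = np.zeros_like(b)        # x_0 = 0
    r = b.copy()                # r_0 = b
    for _ in range(maxiter):
        delta = inner_solve(r)  # possibly inaccurate correction
        x = x + delta           # x_t = x_{t-1} + delta_t
        r = b - A @ x           # residual recomputed in float64
        if np.linalg.norm(r) <= tol:
            break
    return x

rng = np.random.default_rng(0)
n = 50
A = rng.standard_normal((n, n)) + n * np.eye(n)  # well-conditioned test matrix
b = rng.standard_normal(n)

A32 = A.astype(np.float32)

def low_precision_solve(r):
    # Solve A delta = r in float32 only, then cast back to float64.
    return np.linalg.solve(A32, r.astype(np.float32)).astype(np.float64)

x_once = low_precision_solve(b)                           # one float32 solve
x_ref = iterative_refinement(A, b, low_precision_solve)   # refined solution
print(np.linalg.norm(b - A @ x_once), np.linalg.norm(b - A @ x_ref))
```

A single float32 solve leaves a residual at roughly single-precision level. A few refinement steps bring it down to roughly double-precision level, because each step only has to correct the remaining error.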
- Parameters
matvec_A (Optional[Callable]) –
matvec_A_bar (Optional[Callable]) –
solve (Callable) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- matvec_A
(optional) a Callable matvec_A(A, x). By default, matvec_A(A, x) = tree_dot(A, x), where pytree A matches x structure.
- Type
Optional[Callable]
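matvec_A receives the parameters A as its first argument, so A does not have to be a dense matrix. The NumPy sketch below assumes a hypothetical structured operator \(A = \operatorname{diag}(d) + U U^\top\), stored as the pair `(d, U)`. In jaxopt the callable would operate on JAX arrays instead.

```python
import numpy as np

# Hypothetical parameterization: A = diag(d) + U @ U.T, passed around as (d, U)
# so the matvec never forms the n x n matrix explicitly.
def matvec_A(A, x):
    d, U = A
    return d * x + U @ (U.T @ x)

rng = np.random.default_rng(1)
n, k = 6, 2
d = rng.uniform(1.0, 2.0, size=n)
U = rng.standard_normal((n, k))
x = rng.standard_normal(n)

dense_A = np.diag(d) + U @ U.T  # only built here to check the matvec
print(np.allclose(matvec_A((d, U), x), dense_A @ x))  # True
```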
- matvec_A_bar
(optional) a Callable. If None, then \(\bar{A}=A\). Otherwise, a Callable matvec_A_bar(x).
- Type
Optional[Callable]
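A non-default \(\bar{A}\) only needs to be cheap to solve with. The NumPy sketch below assumes \(\bar{A} = \operatorname{diag}(A)\), a Jacobi-style choice made here for illustration (it is not a jaxopt default). With this choice, the update \(x_t = x_{t-1} + \bar{A}^{-1} r_{t-1}\) is exactly the Jacobi iteration. The test matrix is made strictly diagonally dominant so that the error contracts by at least a factor of 2 per step in the max norm.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 40
A = rng.uniform(-1.0, 1.0, size=(n, n))
A += np.diag(2.0 * np.abs(A).sum(axis=1))  # strictly diagonally dominant
b = rng.standard_normal(n)

diag = np.diag(A).copy()

def solve_A_bar(r):
    # A_bar = diag(A): solving A_bar delta = r is an elementwise division.
    return r / diag

x = np.zeros(n)   # x_0 = 0
r = b.copy()      # r_0 = b
for t in range(200):
    x = x + solve_A_bar(r)  # x_t = x_{t-1} + A_bar^{-1} r_{t-1}
    r = b - A @ x           # r_t = b - A x_t
    if np.linalg.norm(r) <= 1e-10:
        break
print(t, np.linalg.norm(r))
```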
- solve
a Callable that accepts A as first argument, b as second, and a warm start init as third argument. This solver can be inaccurate and run with low precision.
- Type
Callable
- Parameters
matvec (Callable) –
b (Any) –
ridge (Optional[float]) –
init (Optional[Any]) –
tol (float) –
- Return type
Any
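The sketch below is a hypothetical inexact solver that follows the parameter list above (matvec, b, ridge, init, tol). It is written in NumPy rather than JAX. It assumes A is symmetric positive definite, and it assumes that ridge means solving \((A + \text{ridge}\, I)x = b\). It runs a capped number of conjugate-gradient steps in float32, so its answer is deliberately only approximate. The extra `maxiter` keyword is our own addition.

```python
import numpy as np

def low_precision_cg(matvec, b, ridge=None, init=None, tol=1e-5, maxiter=20):
    """Hypothetical inexact solver: float32 conjugate gradient on (A + ridge*I) x = b."""
    def op(v):
        out = np.asarray(matvec(v), dtype=np.float32)
        if ridge is not None:
            out = out + np.float32(ridge) * v  # assumed ridge semantics
        return out

    b32 = np.asarray(b, dtype=np.float32)
    x = np.zeros_like(b32) if init is None else np.asarray(init, dtype=np.float32)
    r = b32 - op(x)  # initial residual (uses the warm start if given)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol:
            break
        Ap = op(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x.astype(np.float64)

rng = np.random.default_rng(3)
n = 30
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)  # symmetric positive definite
b = rng.standard_normal(n)
x = low_precision_cg(lambda v: A @ v, b, ridge=1e-6)
print(np.linalg.norm(b - A @ x))
```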
- maxiter
maximum number of iterations (default: 10).
- Type
int
- tol
absolute tolerance for the stopping criterion (default: 1e-7).
- Type
float
- verbose
If verbose=1, print error at each iteration.
- Type
int
- implicit_diff
whether to enable implicit diff or autodiff of unrolled iterations.
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
- __init__(matvec_A=None, matvec_A_bar=None, solve=functools.partial(<function solve_gmres>, ridge=1e-06), maxiter=10, tol=1e-07, verbose=0, implicit_diff_solve=None, jit='auto', unroll='auto')
- Parameters
matvec_A (Optional[Callable]) –
matvec_A_bar (Optional[Callable]) –
solve (Callable) –
maxiter (int) –
tol (float) –
verbose (int) –
implicit_diff_solve (Optional[Callable]) –
jit (Union[str, bool]) –
unroll (Union[str, bool]) –
- Return type
None
Methods
__init__([matvec_A, matvec_A_bar, solve, ...])
attribute_names()
attribute_values()
init_params(A, b[, A_bar])
init_state(init_params, A, b[, A_bar])
l2_optimality_error(params, A, b[, A_bar]) – Computes the L2 optimality error.
optimality_fun(params, A, b[, A_bar])
run(init_params, A, b[, A_bar]) – Runs the iterative refinement.
solve(b, *[, ridge, init, tol]) – Solves A x = b using GMRES.
update(params, state, A, b[, A_bar])
Attributes
- l2_optimality_error(params, A, b, A_bar=None)
Computes the L2 optimality error.
- Parameters
params (Any) –
A (Any) –
b (Any) –
A_bar (Optional[Any]) –
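The NumPy sketch below shows what this error measures. It rests on our own reading, which this page does not confirm: that optimality_fun returns the residual \(A x - b\) and that l2_optimality_error returns its Euclidean norm. That norm is zero exactly at a solution, so comparing it with tol is a natural stopping test. Both functions here are simplified stand-ins with a dense A.

```python
import numpy as np

def optimality_fun(x, A, b):
    # Residual of the linear system: zero exactly when A x = b.
    return A @ x - b

def l2_optimality_error(x, A, b):
    return np.linalg.norm(optimality_fun(x, A, b))

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = np.linalg.solve(A, b)
print(l2_optimality_error(x, A, b) < 1e-12)    # True
print(l2_optimality_error(np.zeros(2), A, b))  # norm of b, about 2.236
```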
- run(init_params, A, b, A_bar=None)
Runs the iterative refinement.
- Parameters
init_params – initial estimate of x, used as a warm start.
A (Any) – params for self.matvec_A.
b (Any) – vector b in Ax=b.
A_bar (Optional[Any]) – optional parameters for matvec_A_bar.
- Returns
(params, state), where params is the approximate solution x of Ax=b.
- solve(b, *, ridge=1e-06, init=None, tol=1e-05, **kwargs)
Solves A x = b using GMRES.
- Parameters
matvec (Callable) – product between A and a vector.
b (Any) – pytree.
ridge (Optional[float]) – optional ridge regularization.
init (Optional[Any]) – optional initialization to be used by GMRES.
**kwargs – additional keyword arguments for solver.
tol (float) –
- Returns
pytree with same structure as b.
- Return type
Any
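The default inner solver is regularized (solve=functools.partial(solve_gmres, ridge=1e-06)), so a single solve is biased. The NumPy sketch below illustrates why the refinement loop still recovers the unregularized solution. It assumes that ridge means solving \((A + \text{ridge}\, I)\delta = r\), and it uses np.linalg.solve as a stand-in for GMRES. The ridge value 1e-4 and the ill-conditioned test matrix are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 20
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
eigs = np.logspace(-4, 0, n)   # eigenvalues from 1e-4 to 1
A = Q @ np.diag(eigs) @ Q.T    # symmetric, ill-conditioned (cond = 1e4)
x_true = rng.standard_normal(n)
b = A @ x_true

ridge = 1e-4
A_ridge = A + ridge * np.eye(n)

def ridge_solve(r):
    # Exact solve of the regularized system (A + ridge * I) delta = r.
    return np.linalg.solve(A_ridge, r)

x_once = ridge_solve(b)  # one regularized solve: biased toward zero
x = np.zeros(n)          # x_0 = 0
r = b.copy()             # r_0 = b
for _ in range(50):
    x = x + ridge_solve(r)  # x_t = x_{t-1} + (A + ridge*I)^{-1} r_{t-1}
    r = b - A @ x           # r_t = b - A x_t
print(np.linalg.norm(x_once - x_true), np.linalg.norm(x - x_true))
```

The error evolves as \(e_t = \text{ridge}\,(A + \text{ridge}\, I)^{-1} e_{t-1}\). Each eigencomponent therefore shrinks by \(\text{ridge}/(\mu_i + \text{ridge})\) per step, where \(\mu_i\) are the eigenvalues of \(A\). That factor is below 1 for any positive ridge when \(A\) is positive definite. A larger ridge makes the inner solve better conditioned but slows this contraction, with the slowest rate set by \(\mu_{\min}\).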