jaxopt.IterativeRefinement

class jaxopt.IterativeRefinement(matvec_A=None, matvec_A_bar=None, solve=functools.partial(<function solve_gmres>, ridge=1e-06), maxiter=10, tol=1e-07, verbose=0, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Iterative refinement algorithm.

This is a meta-algorithm for solving the linear system Ax = b based on a provided linear system solver. Our implementation is a slight generalization of the standard algorithm. It starts with \((r_0, x_0) = (b, 0)\) and iterates

\[\begin{split}\begin{aligned} \delta_t &= \text{solution of } \bar{A} \delta = r_{t-1}\\ x_t &= x_{t-1} + \delta_t\\ r_t &= b - A x_t \end{aligned}\end{split}\]

where \(\bar{A}\) is some approximation of A, ideally one that is better conditioned than A. By default, \(\bar{A} = A\), which gives the standard iterative refinement algorithm.
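The iteration above fits in a few lines. The following NumPy sketch is not jaxopt's implementation: the function `iterative_refinement` and its arguments are hypothetical, A and \(\bar{A}\) are accessed only through callables, and the loop stops on an absolute tolerance on the residual norm, which is an assumption about how tol is applied. Choosing \(\bar{A} = \mathrm{diag}(A)\) turns each inner solve into an elementwise division and the loop into the classical Jacobi iteration; choosing \(\bar{A} = A\) with an exact inner solve converges in a single step.

```python
import numpy as np


def iterative_refinement(matvec_A, solve_A_bar, b, x0=None, maxiter=100, tol=1e-10):
    """Hypothetical sketch of the generalized refinement loop.

    matvec_A(x) returns A @ x; solve_A_bar(r) returns an (approximate)
    solution of A_bar @ delta = r.
    """
    x = np.zeros_like(b) if x0 is None else np.array(x0, dtype=float)
    r = b - matvec_A(x)                 # r_0 = b when x_0 = 0
    for t in range(1, maxiter + 1):
        delta = solve_A_bar(r)          # delta_t solves A_bar delta = r_{t-1}
        x = x + delta                   # x_t = x_{t-1} + delta_t
        r = b - matvec_A(x)             # r_t = b - A x_t
        if np.linalg.norm(r) <= tol:    # assumed: absolute tolerance on ||r_t||
            break
    return x, t


A = np.array([[4.0, 1.0, 0.0],
              [1.0, 5.0, 2.0],
              [0.0, 2.0, 6.0]])
b = np.array([1.0, 2.0, 3.0])

# A_bar = diag(A): the inner "solve" is an elementwise division (Jacobi).
x_jacobi, iters = iterative_refinement(lambda x: A @ x, lambda r: r / np.diag(A), b)

# A_bar = A with an exact inner solve: one step suffices.
x_exact, iters_exact = iterative_refinement(
    lambda x: A @ x, lambda r: np.linalg.solve(A, r), b)
```

The Jacobi variant converges here because this A is strictly diagonally dominant; for a general A, a crude \(\bar{A}\) may not yield a contraction.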

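This method can converge even when the solve step is inaccurate, because each iteration recomputes the residual with A itself and uses it to correct the error left by the previous step. This is particularly useful for ill-posed problems.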
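To illustrate why an inaccurate solve is tolerable, the NumPy sketch below (an illustration, not jaxopt code) uses a deliberately low-precision inner solve: the matrix, the right-hand side and the correction are rounded to float32, while residuals are recomputed in float64. The test matrix is an assumption chosen to be well conditioned.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
A = rng.standard_normal((n, n)) + n * np.eye(n)  # well conditioned by construction
x_true = rng.standard_normal(n)
b = A @ x_true
A32 = A.astype(np.float32)


def solve_low_precision(r):
    # Inaccurate inner solve: matrix, right-hand side and result are all
    # rounded to float32 (roughly 7 significant digits).
    return np.linalg.solve(A32, r.astype(np.float32)).astype(np.float64)


x = np.zeros(n)
r = b.copy()                              # (r_0, x_0) = (b, 0)
errors = []
for _ in range(5):
    x = x + solve_low_precision(r)        # x_t = x_{t-1} + correction
    r = b - A @ x                         # residual recomputed in float64
    errors.append(np.linalg.norm(x - x_true) / np.linalg.norm(x_true))
```

The first entry of `errors` reflects float32 accuracy; the later entries should approach float64 accuracy, even though every inner solve is only float32-accurate.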

Parameters
  • matvec_A (Optional[Callable]) –

  • matvec_A_bar (Optional[Callable]) –

  • solve (Callable) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

matvec_A

(optional) a Callable matvec_A(A, x). By default, matvec_A(A, x) = tree_dot(A, x), where the pytree A has the same structure as x.

Type

Optional[Callable]
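Because matvec_A only has to compute the product with A, the argument A can be any pytree of parameters rather than a dense matrix. As a hypothetical example, the NumPy function below stores a tridiagonal operator as a tuple (lower, diag, upper), whose structure does not match x, and computes A x without forming the dense matrix. The check compares it with the dense product.

```python
import numpy as np


def tridiag_matvec(A, x):
    """Hypothetical matvec_A(A, x) for a tridiagonal operator.

    A = (lower, diag, upper) with len(lower) == len(upper) == len(x) - 1.
    The dense matrix is never formed.
    """
    lower, diag, upper = A
    y = diag * x              # main diagonal (new array, safe to update in place)
    y[1:] += lower * x[:-1]   # sub-diagonal contribution
    y[:-1] += upper * x[1:]   # super-diagonal contribution
    return y


n = 5
A = (np.full(n - 1, -1.0), np.full(n, 4.0), np.full(n - 1, -1.0))
x = np.arange(1.0, n + 1)

dense = np.diag(A[1]) + np.diag(A[0], k=-1) + np.diag(A[2], k=1)
y = tridiag_matvec(A, x)
```

JAX arrays are immutable, so a JAX version would replace the in-place updates with `y = y.at[1:].add(lower * x[:-1])` and `y = y.at[:-1].add(upper * x[1:])`; the tuple would then be passed as the A argument of run.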

matvec_A_bar

(optional) a Callable. If None, then \(\bar{A}=A\). Otherwise, a Callable matvec_A_bar(x).

Type

Optional[Callable]

solve

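a Callable that accepts A as first argument, b as second, and a warm start init as third argument; the default, solve_gmres, receives A in the form of a matrix-vector product Callable (matvec in the Parameters list below). This solver may be inaccurate and run with low precision.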

Type

Callable

Parameters
  • matvec (Callable) –

  • b (Any) –

  • ridge (Optional[float]) –

  • init (Optional[Any]) –

  • tol (float) –

Return type

Any
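The default solve is functools.partial(solve_gmres, ridge=1e-06). Assuming, as the parameter name suggests, that ridge adds ridge·x to every matrix-vector product, each inner solve targets the shifted system \((\bar{A} + \text{ridge}\, I)\delta = r\) and is therefore slightly biased. The NumPy sketch below (not jaxopt's implementation) shows that refinement removes this bias. It swaps GMRES for a fixed number of conjugate-gradient steps, so it assumes a symmetric positive definite A; it exaggerates ridge to 0.1 so the bias is visible; and it follows the (matvec, b, ridge, init) argument order listed above. The name `solve_ridge_cg` is hypothetical.

```python
import numpy as np


def solve_ridge_cg(matvec, b, ridge=0.1, init=None, maxiter=30):
    """Hypothetical inexact solver: a few CG steps on (A + ridge * I) x = b."""
    def op(v):
        return matvec(v) + ridge * v              # ridge-shifted operator

    x = np.zeros_like(b) if init is None else np.array(init, dtype=float)
    r = b - op(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) <= 1e-14 * np.linalg.norm(b):
            break
        Ap = op(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rng = np.random.default_rng(0)
n = 20
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Q @ np.diag(np.linspace(1.0, 10.0, n)) @ Q.T   # SPD, eigenvalues in [1, 10]
x_true = rng.standard_normal(n)
b = A @ x_true


def matvec_A(v):
    return A @ v


x = np.zeros(n)
r = b.copy()
errors = []
for _ in range(20):
    x = x + solve_ridge_cg(matvec_A, r)   # biased correction (ridge-shifted)
    r = b - matvec_A(x)                   # unshifted residual corrects the bias
    errors.append(np.linalg.norm(x - x_true) / np.linalg.norm(x_true))
```

With eigenvalues λ of A, the error along each eigenvector shrinks by ridge / (λ + ridge) per outer step, so a small ridge costs only a few extra iterations.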

maxiter

maximum number of iterations (default: 10).

Type

int

tol

absolute tolerance for the stopping criterion (default: 1e-7).

Type

float

verbose

If verbose=1, print the error at each iteration.

Type

int

implicit_diff

whether to enable implicit differentiation or autodiff of unrolled iterations.

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]
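The optimality_fun method suggests that gradients are obtained by implicit differentiation of an optimality condition. Assuming that condition is F(x, A, b) = A x - b = 0 (the formula is not stated on this page), differentiating a scalar loss L(x*) requires one extra linear solve with the transposed matrix, \(A^\top v = \partial L / \partial x\); this is the kind of system a solver such as implicit_diff_solve would handle. The NumPy sketch below checks the resulting gradients against finite differences for the linear loss L(x) = c·x. Here np.linalg.solve stands in for both the forward solver and the implicit-differentiation solver.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n)) + 2 * n * np.eye(n)  # comfortably invertible
b = rng.standard_normal(n)
c = rng.standard_normal(n)                            # loss L(x) = c @ x


def loss(A, b):
    # Forward pass: np.linalg.solve stands in for the refined solution x*.
    return c @ np.linalg.solve(A, b)


x_star = np.linalg.solve(A, b)

# Implicit differentiation of F(x, A, b) = A x - b = 0 (assumed condition):
# a single transposed solve A^T v = dL/dx yields both gradients.
v = np.linalg.solve(A.T, c)
grad_b = v                       # dL/db = A^{-T} c
grad_A = -np.outer(v, x_star)    # dL/dA = -v x*^T

# Central finite differences on one entry of b and one entry of A.
eps = 1e-6
e0 = np.eye(n)[0]
fd_b0 = (loss(A, b + eps * e0) - loss(A, b - eps * e0)) / (2 * eps)
E12 = np.zeros((n, n))
E12[1, 2] = 1.0
fd_A12 = (loss(A + eps * E12, b) - loss(A - eps * E12, b)) / (2 * eps)
```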

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

References

[1] J. H. Wilkinson. Rounding Errors in Algebraic Processes. Prentice Hall, Englewood Cliffs, NJ, 1963.

[2] C. B. Moler. Iterative refinement in floating point. Journal of the ACM, 14(2):316-321, 1967.

[3] Iterative refinement. Wikipedia. https://en.wikipedia.org/wiki/Iterative_refinement

__init__(matvec_A=None, matvec_A_bar=None, solve=functools.partial(<function solve_gmres>, ridge=1e-06), maxiter=10, tol=1e-07, verbose=0, implicit_diff_solve=None, jit='auto', unroll='auto')
Return type

None

Methods

__init__([matvec_A, matvec_A_bar, solve, ...])

attribute_names()

attribute_values()

init_params(A, b[, A_bar])

init_state(init_params, A, b[, A_bar])

l2_optimality_error(params, A, b[, A_bar])

Computes the L2 optimality error.

optimality_fun(params, A, b[, A_bar])

run(init_params, A, b[, A_bar])

Runs the iterative refinement.

solve(b, *[, ridge, init, tol])

Solves A x = b using GMRES.

update(params, state, A, b[, A_bar])

Attributes

implicit_diff_solve

jit

matvec_A

matvec_A_bar

maxiter

tol

unroll

verbose

l2_optimality_error(params, A, b, A_bar=None)[source]

Computes the L2 optimality error.

Parameters
  • params (Any) –

  • A (Any) –

  • b (Any) –

  • A_bar (Optional[Any]) –

run(init_params, A, b, A_bar=None)[source]

Runs the iterative refinement.

Parameters
  • init_params – initial value of params, used as a warm start.

  • A (Any) – params for self.matvec_A.

  • b (Any) – vector b in Ax=b.

  • A_bar (Optional[Any]) – optional parameters for matvec_A_bar.

Returns

(params, state), where params is the approximate solution x of Ax = b, a pytree with the same structure as b.
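A good init_params can shorten run considerably, because the first residual b - A x_0 is then already small. The NumPy sketch below (the function `refine` is hypothetical and is not jaxopt's run) uses \(\bar{A}\) = lower triangle of A, which makes each step a Gauss-Seidel sweep, and returns the solution together with an iteration count that stands in for the solver state.

```python
import numpy as np


def refine(A, b, init, maxiter=200, tol=1e-10):
    """Hypothetical sketch: refinement with A_bar = tril(A), warm-started at init."""
    A_bar = np.tril(A)                      # lower triangle incl. diagonal (Gauss-Seidel)
    x = np.array(init, dtype=float)
    r = b - A @ x                           # first residual b - A x_0
    iters = 0
    while np.linalg.norm(r) > tol and iters < maxiter:
        x = x + np.linalg.solve(A_bar, r)   # correction from A_bar delta = r
        r = b - A @ x
        iters += 1
    return x, iters


A = np.array([[4.0, -1.0, 0.0, 0.0],
              [-1.0, 4.0, -1.0, 0.0],
              [0.0, -1.0, 4.0, -1.0],
              [0.0, 0.0, -1.0, 4.0]])
b = np.ones(4)

x_cold, iters_cold = refine(A, b, np.zeros_like(b))
# Warm start near the solution, e.g. a result from a previous solve with nearby data.
x_warm, iters_warm = refine(A, b, x_cold + 1e-8)
```

In a real implementation the correction would use a triangular solve rather than the general np.linalg.solve.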

solve(b, *, ridge=1e-06, init=None, tol=1e-05, **kwargs)

Solves A x = b using GMRES.

Parameters
  • matvec (Callable) – product between A and a vector.

  • b (Any) – pytree.

  • ridge (Optional[float]) – optional ridge regularization.

  • init (Optional[Any]) – optional initialization to be used by gmres.

  • **kwargs – additional keyword arguments for the solver.

  • tol (float) –

Returns

pytree with same structure as b.

Return type

Any