jaxopt.GaussNewton

class jaxopt.GaussNewton(residual_fun, maxiter=30, tol=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')[source]

Gauss-Newton nonlinear least-squares solver.

Given a residual function f: R^n -> R^m, where f(x) = residual_fun(x, *args, **kwargs), GaussNewton finds a local minimizer of the least-squares cost 0.5 * sum(f(x) ** 2), i.e. it solves argmin_x 0.5 * sum(f(x) ** 2).

Parameters
  • residual_fun (Callable) –

  • maxiter (int) –

  • tol (float) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

residual_fun

a smooth function of the form residual_fun(x, *args, **kwargs) that returns the residuals f(x).

Type

Callable

maxiter

maximum number of iterations.

Type

int

tol

tolerance of the stopping criterion on the optimality error.

Type

float

implicit_diff

whether to differentiate the solution by implicit differentiation (True) or by autodiff through the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

bool

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(residual_fun, maxiter=30, tol=1e-05, verbose=False, implicit_diff=True, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
Parameters
  • residual_fun (Callable) –

  • maxiter (int) –

  • tol (float) –

  • verbose (bool) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • has_aux (bool) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(residual_fun[, maxiter, tol, ...])

attribute_names()

attribute_values()

init_state(init_params, *args, **kwargs)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of the least-squares solver.

Attributes

has_aux

implicit_diff

implicit_diff_solve

jit

maxiter

tol

unroll

verbose

residual_fun

init_state(init_params, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to residual_fun.

  • **kwargs – additional keyword arguments to be passed to residual_fun.

Return type

GaussNewtonState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, *args, **kwargs)[source]

Performs one iteration of the least-squares solver.

Parameters
  • params – pytree containing the parameters.

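  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to residual_fun.

  • **kwargs – additional keyword arguments to be passed to residual_fun.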

Return type

OptStep

Returns

(params, state)