jaxopt.BlockCoordinateDescent

class jaxopt.BlockCoordinateDescent(fun, block_prox, maxiter=500, tol=0.0001, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')[source]

Block coordinate solver.

This solver minimizes:

objective(params, hyperparams_prox, *args, **kwargs) =
  fun(params, *args, **kwargs) + non_smooth(params, hyperparams_prox)
Parameters
  • fun (CompositeLinearFunction) –

  • block_prox (Callable) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

fun

a smooth function of the form fun(params, *args, **kwargs). It should be an objective.CompositeLinearFunction object.

Type

jaxopt._src.objective.CompositeLinearFunction

block_prox

block-wise proximity operator associated with non_smooth, a function of the form block_prox(x[j], hyperparams_prox, scaling=1.0). See jaxopt.prox for examples.

Type

Callable
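To make the expected signature concrete, here is a minimal sketch of a hand-written block proximity operator. The name block_soft_threshold and the sample values are hypothetical and not part of jaxopt; the function shrinks the L2 norm of each block, and jax.vmap applies it row by row in the same way as the prox definition used by optimality_fun.

```python
import jax
import jax.numpy as jnp


def block_soft_threshold(x, l2reg, scaling=1.0):
    """Hypothetical block prox: shrinks the L2 norm of the block x by l2reg * scaling."""
    norm = jnp.linalg.norm(x)
    # Guard against division by zero when the block is exactly zero.
    safe_norm = jnp.maximum(norm, 1e-12)
    shrink = jnp.maximum(1.0 - l2reg * scaling / safe_norm, 0.0)
    return shrink * x


# Two blocks (rows): one with norm 5.0, one with norm 0.5.
params = jnp.array([[3.0, 4.0], [0.3, 0.4]])

# Apply the block prox to every row, sharing the same hyperparameter.
prox_all = jax.vmap(block_soft_threshold, in_axes=(0, None))
shrunk = prox_all(params, 1.0)
# The first row is scaled by 0.8; the second row is set to zero
# because its norm is below the threshold.
```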

maxiter

maximum number of iterations, where each iteration is one epoch of block coordinate descent.

Type

int

tol

tolerance of the stopping criterion.

Type

float

verbose

whether to print the error on every iteration. Warning: verbose=True will automatically disable jit.

Type

int

implicit_diff

whether to differentiate the solution by implicit differentiation (True) or by autodiff of the unrolled iterations (False).

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

jit

whether to JIT-compile the optimization loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the optimization loop (default: “auto”).

Type

Union[str, bool]

__init__(fun, block_prox, maxiter=500, tol=0.0001, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Parameters
  • fun (CompositeLinearFunction) –

  • block_prox (Callable) –

  • maxiter (int) –

  • tol (float) –

  • verbose (int) –

  • implicit_diff (bool) –

  • implicit_diff_solve (Optional[Callable]) –

  • jit (Union[str, bool]) –

  • unroll (Union[str, bool]) –

Return type

None

Methods

__init__(fun, block_prox[, maxiter, tol, ...])

attribute_names()

attribute_values()

init_state(init_params, hyperparams_prox, ...)

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, hyperparams_prox, ...)

Proximal-gradient fixed point residual.

run(init_params, *args, **kwargs)

Runs the optimization loop.

update(params, state, hyperparams_prox, ...)

Performs one epoch of block CD.

Attributes

implicit_diff

implicit_diff_solve

jit

maxiter

tol

unroll

verbose

fun

block_prox

init_state(init_params, hyperparams_prox, *args, **kwargs)[source]

Initialize the solver state.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

BlockCDState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

optimality_fun(params, hyperparams_prox, *args, **kwargs)[source]

Proximal-gradient fixed point residual.

This function is compatible with @custom_root.

The fixed point function is defined as:

fixed_point_fun(params, hyperparams_prox, *args, **kwargs) =
  prox(params - grad(fun)(params, *args, **kwargs), hyperparams_prox)

where:

prox = jax.vmap(block_prox, in_axes=(0, None))

The residual is defined as:

optimality_fun(params, hyperparams_prox, *args, **kwargs) =
  fixed_point_fun(params, hyperparams_prox, *args, **kwargs) - params
Parameters
  • params (Any) – pytree containing the parameters.

  • hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Returns

residual, a pytree with the same structure as params.
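The definitions above can be written out directly in JAX. The following is a minimal sketch under stated assumptions, not the library's implementation: fun is a hypothetical least-squares loss, block_prox is a hand-written soft-thresholding operator, and the data are made up.

```python
import jax
import jax.numpy as jnp


def fun(params, X, y):
    """Hypothetical smooth part: mean squared error with a 1/2 factor."""
    return 0.5 * jnp.mean((X @ params - y) ** 2)


def block_prox(x, l1reg, scaling=1.0):
    """Hand-written soft-thresholding, used here as the block-wise prox."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - l1reg * scaling, 0.0)


def fixed_point_fun(params, hyperparams_prox, *args, **kwargs):
    # Apply the block prox to each block (here, each coordinate).
    prox = jax.vmap(block_prox, in_axes=(0, None))
    step = params - jax.grad(fun)(params, *args, **kwargs)
    return prox(step, hyperparams_prox)


def optimality_fun(params, hyperparams_prox, *args, **kwargs):
    # Residual of the fixed point equation; zero at a solution.
    return fixed_point_fun(params, hyperparams_prox, *args, **kwargs) - params


X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
w0 = jnp.zeros(2)

# With a large penalty, params = 0 is already a fixed point: residual is zero.
res_zero = optimality_fun(w0, 2.0, X, y)
# With a small penalty, params = 0 is not optimal: residual is nonzero.
res_small = optimality_fun(w0, 0.1, X, y)
```

The residual has the same shape as params, so its norm can serve as an optimality measure.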

run(init_params, *args, **kwargs)

Runs the optimization loop.

Parameters
  • init_params (Any) – pytree containing the initial parameters.

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)

update(params, state, hyperparams_prox, *args, **kwargs)[source]

Performs one epoch of block CD.

Parameters
  • params (Any) – pytree containing the parameters.

  • state (NamedTuple) – named tuple containing the solver state.

  • hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.

  • *args – additional positional arguments to be passed to fun.

  • **kwargs – additional keyword arguments to be passed to fun.

Return type

OptStep

Returns

(params, state)