jaxopt.BlockCoordinateDescent
- class jaxopt.BlockCoordinateDescent(fun, block_prox, maxiter=500, tol=0.0001, verbose=0, implicit_diff=True, implicit_diff_solve=None, jit='auto', unroll='auto')
Block coordinate solver.
This solver minimizes:
objective(params, hyperparams_prox, *args, **kwargs) = fun(params, *args, **kwargs) + non_smooth(params, hyperparams_prox)
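To make the objective concrete, the NumPy sketch below evaluates it for a hypothetical Lasso-type problem. Here fun is a least-squares loss, written as a smooth subfun of the linear predictions A @ params. That is the composite-linear shape the name CompositeLinearFunction suggests, but the exact jaxopt interface is not reproduced. non_smooth is an L1 penalty weighted by hyperparams_prox. The names A, y, subfun and non_smooth are illustrative assumptions, not jaxopt code.

```python
import numpy as np

# Hypothetical problem data: fun(params) = subfun(A @ params) is a least-squares
# loss composed with the linear map A; non_smooth is an L1 penalty.
A = np.array([[1.0, 0.0], [0.0, 2.0]])
y = np.array([1.0, 2.0])

def subfun(predictions):
    # Smooth function of the linear predictions A @ params.
    r = predictions - y
    return 0.5 * float(r @ r)

def fun(params):
    return subfun(A @ params)

def non_smooth(params, hyperparams_prox):
    # L1 penalty whose strength is given by hyperparams_prox (a scalar here).
    return hyperparams_prox * float(np.sum(np.abs(params)))

def objective(params, hyperparams_prox):
    # objective = fun(params) + non_smooth(params, hyperparams_prox)
    return fun(params) + non_smooth(params, hyperparams_prox)

print(objective(np.array([1.0, -1.0]), 0.5))  # 9.0
```

The solver never evaluates non_smooth itself. It only needs block_prox, the proximity operator of non_smooth applied block by block.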
- Parameters
fun (CompositeLinearFunction)
block_prox (Callable)
maxiter (int)
tol (float)
verbose (int)
implicit_diff (bool)
implicit_diff_solve (Optional[Callable])
jit (Union[str, bool])
unroll (Union[str, bool])
- fun
a smooth function of the form fun(params, *args, **kwargs). It should be an objective.CompositeLinearFunction object.
- Type
jaxopt._src.objective.CompositeLinearFunction
- block_prox
block-wise proximity operator associated with non_smooth, a function of the form block_prox(x[j], hyperparams_prox, scaling=1.0). See jaxopt.prox for examples.
- Type
Callable
- maxiter
maximum number of iterations, where one iteration is one epoch of block coordinate descent (one pass over all blocks).
- Type
int
- tol
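tolerance for the stopping criterion: the optimization loop stops once the error stored in the solver state is at most tol.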
- Type
float
- verbose
whether to print the error on every iteration. Warning: verbose=True automatically disables jit.
- Type
int
- implicit_diff
whether to use implicit differentiation (True) or to differentiate through the unrolled iterations (False).
- Type
bool
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- Type
Optional[Callable]
- jit
whether to JIT-compile the optimization loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the optimization loop (default: “auto”).
- Type
Union[str, bool]
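The following sketches what a block_prox may look like. Two block-wise proximity operators are written from scratch in NumPy, assuming each block x[j] is a row of params. The first is soft-thresholding, for an L1 penalty. The second is block soft-thresholding, for a group-lasso penalty. Both follow the documented block_prox(x[j], hyperparams_prox, scaling=1.0) form, with scaling multiplying the penalty strength. Ready-made operators are provided in jaxopt.prox; these functions are illustrations, not copies of them.

```python
import numpy as np

def block_prox_l1(x_j, hyperparams_prox, scaling=1.0):
    # Prox of scaling * hyperparams_prox * ||.||_1: elementwise soft-thresholding.
    thresh = scaling * hyperparams_prox
    return np.sign(x_j) * np.maximum(np.abs(x_j) - thresh, 0.0)

def block_prox_group(x_j, hyperparams_prox, scaling=1.0):
    # Prox of scaling * hyperparams_prox * ||.||_2 (group lasso): shrink the block's
    # norm by the threshold, or set the whole block to zero if its norm is smaller.
    norm = np.linalg.norm(x_j)
    thresh = scaling * hyperparams_prox
    if norm <= thresh:
        return np.zeros_like(x_j)
    return (1.0 - thresh / norm) * x_j

# params with two blocks (rows); block_prox is applied to each row separately.
x = np.array([[3.0, -0.5], [0.3, 0.4]])
print(np.array([block_prox_l1(x_j, 1.0) for x_j in x]))
print(np.array([block_prox_group(x_j, 1.0) for x_j in x]))
```

The group version sets an entire block to zero at once. That is how block CD with such a prox can produce block-sparse solutions.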
Methods
- __init__(fun, block_prox[, maxiter, tol, ...])
- attribute_names()
- attribute_values()
- init_state(init_params, hyperparams_prox, ...): Initialize the solver state.
- l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
- optimality_fun(params, hyperparams_prox, ...): Proximal-gradient fixed point residual.
- run(init_params, *args, **kwargs): Runs the optimization loop.
- update(params, state, hyperparams_prox, ...): Performs one epoch of block CD.
- init_state(init_params, hyperparams_prox, *args, **kwargs)
Initialize the solver state.
- Parameters
init_params (Any) – pytree containing the initial parameters.
hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
BlockCDState
- Returns
state
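init_state builds the state that update then carries from one epoch to the next. The sketch below is a hypothetical minimal version for a problem with fun(params) = subfun(A @ params). It caches the predictions A @ params so that later block updates never recompute the full matrix product. The SketchState fields are assumptions for illustration and need not match the actual fields of BlockCDState.

```python
import numpy as np
from typing import NamedTuple

class SketchState(NamedTuple):
    # Hypothetical fields; the real BlockCDState may differ.
    iter_num: int
    predictions: np.ndarray  # cached A @ params, updated incrementally by block steps
    error: float

def init_state(init_params, hyperparams_prox, A):
    # hyperparams_prox is not needed to build this state; it is accepted only to
    # mirror the documented init_state(init_params, hyperparams_prox, *args) shape.
    return SketchState(iter_num=0, predictions=A @ init_params, error=float("inf"))

A = np.array([[1.0, 2.0], [3.0, 4.0]])
state = init_state(np.array([1.0, 1.0]), 0.1, A)
print(state.iter_num, state.predictions, state.error)
```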
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error, i.e. the L2 norm of the residual returned by optimality_fun(params, *args, **kwargs).
- optimality_fun(params, hyperparams_prox, *args, **kwargs)
Proximal-gradient fixed point residual.
This function is compatible with @custom_root. The fixed point function is defined as:
fixed_point_fun(params, hyperparams_prox, *args, **kwargs) = prox(params - grad(fun)(params, *args, **kwargs), hyperparams_prox)
where:
prox = jax.vmap(block_prox, in_axes=(0, None))
The residual is defined as:
optimality_fun(params, hyperparams_prox, *args, **kwargs) = fixed_point_fun(params, hyperparams_prox, *args, **kwargs) - params
- Parameters
params (Any) – pytree containing the parameters.
hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Returns
residual: pytree with the same structure as params.
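A zero residual certifies a minimizer. To see this, the NumPy sketch below evaluates the fixed point and residual formulas on a hypothetical two-coordinate Lasso problem with a diagonal design. Each block is a single coordinate, and the minimizer is known in closed form. The list comprehension stands in for jax.vmap(block_prox, in_axes=(0, None)). The names a, y, grad_fun and lam are illustrative assumptions.

```python
import numpy as np

# Hypothetical Lasso problem with a diagonal design: fun(w) = 0.5 * ||a * w - y||^2.
a = np.array([1.0, 2.0])
y = np.array([3.0, 0.1])

def grad_fun(params):
    # Gradient of fun with respect to params.
    return a * (a * params - y)

def block_prox(x_j, hyperparams_prox, scaling=1.0):
    # Soft-thresholding: the prox of scaling * hyperparams_prox * |.|.
    return np.sign(x_j) * np.maximum(np.abs(x_j) - scaling * hyperparams_prox, 0.0)

def prox(x, hyperparams_prox):
    # Stand-in for jax.vmap(block_prox, in_axes=(0, None)): apply block_prox to
    # every block x[j] with the same hyperparams_prox and the default scaling=1.0.
    return np.array([block_prox(x_j, hyperparams_prox) for x_j in x])

def fixed_point_fun(params, hyperparams_prox):
    return prox(params - grad_fun(params), hyperparams_prox)

def optimality_fun(params, hyperparams_prox):
    return fixed_point_fun(params, hyperparams_prox) - params

lam = 1.0
# Closed-form minimizer for this diagonal problem: w_j = soft(a_j * y_j, lam) / a_j**2.
w_star = np.array([2.0, 0.0])
print(optimality_fun(w_star, lam))       # zero residual at the minimizer
print(optimality_fun(np.zeros(2), lam))  # nonzero residual away from it
```

Because the residual vanishes exactly at a minimizer, a root-finding formulation such as @custom_root can differentiate the solution implicitly through it.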
- run(init_params, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Any) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)
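Conceptually, run builds an initial state and then applies update until the stopping criterion holds. Because *args and **kwargs are forwarded to update, whose signature is update(params, state, hyperparams_prox, ...), hyperparams_prox reaches update through them. A typical call therefore passes it right after init_params.

The sketch below shows that control flow in plain Python. A toy update halves the distance to a target each epoch. The driver, the toy functions and the dict-based state are hypothetical stand-ins, and jaxopt's real loop may be JIT-compiled or unrolled depending on jit and unroll.

```python
import numpy as np

def run(init_state, update, init_params, *args, maxiter=500, tol=1e-4):
    # Hypothetical driver: one call to update is one epoch; stop once the
    # error reported in the state is at most tol, or after maxiter epochs.
    params = init_params
    state = init_state(init_params, *args)
    while state["iter_num"] < maxiter and state["error"] > tol:
        params, state = update(params, state, *args)
    return params, state

def toy_init_state(init_params, target):
    # The error starts at infinity so that at least one epoch is performed.
    return {"iter_num": 0, "error": np.inf}

def toy_update(params, state, target):
    # Stand-in for one block CD epoch: move halfway towards target and report
    # the size of the step as the error.
    new_params = params + 0.5 * (target - params)
    error = float(np.linalg.norm(new_params - params))
    return new_params, {"iter_num": state["iter_num"] + 1, "error": error}

params, state = run(toy_init_state, toy_update, np.zeros(1), np.ones(1))
print(state["iter_num"], params)
```

The step sizes here are 0.5, 0.25, ..., so the loop stops after 14 epochs, when the step first drops to 1e-4 or below. Passing a smaller maxiter stops it earlier, regardless of tol.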
- update(params, state, hyperparams_prox, *args, **kwargs)
Performs one epoch of block CD.
- Parameters
params (Any) – pytree containing the parameters.
state (NamedTuple) – named tuple containing the solver state.
hyperparams_prox (Any) – pytree containing hyperparameters of block_prox.
*args – additional positional arguments to be passed to fun.
**kwargs – additional keyword arguments to be passed to fun.
- Return type
OptStep
- Returns
(params, state)
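One epoch of block coordinate descent can be sketched in NumPy for a Lasso problem 0.5 * ||A @ w - y||^2 + lam * ||w||_1 with single-coordinate blocks. The sketch assumes a per-block step size of 1 / L_j, where L_j = ||A[:, j]||^2 is the Lipschitz constant of the j-th partial gradient. It keeps the predictions A @ w up to date incrementally. The reported error is the square root of the summed squared block changes, an assumption standing in for whatever error the real state stores. This illustrates the scheme and is not jaxopt's implementation.

```python
import numpy as np

def soft_threshold(x_j, lam, scaling=1.0):
    # block_prox for the L1 penalty: prox of scaling * lam * |.|
    return np.sign(x_j) * np.maximum(np.abs(x_j) - scaling * lam, 0.0)

def bcd_epoch(params, predictions, lam, A, y):
    """One epoch of proximal block CD on 0.5*||A w - y||^2 + lam*||w||_1."""
    params = params.copy()
    sq_change = 0.0
    for j in range(params.shape[0]):
        step = 1.0 / (A[:, j] @ A[:, j])             # 1 / L_j
        g_j = A[:, j] @ (predictions - y)           # partial gradient for block j
        new_j = soft_threshold(params[j] - step * g_j, lam, scaling=step)
        delta = new_j - params[j]
        predictions = predictions + A[:, j] * delta  # keep A @ params up to date
        params[j] = new_j
        sq_change += delta ** 2
    return params, predictions, float(np.sqrt(sq_change))

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 5))
y = rng.standard_normal(20)
lam = 1.0
w = np.zeros(5)
pred = A @ w
for _ in range(1000):
    w, pred, err = bcd_epoch(w, pred, lam, A, y)
    if err <= 1e-12:
        break
print(w)
```

With this step size, each block step exactly minimizes the objective over that coordinate. The objective therefore never increases within an epoch.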