Source code for jaxopt._src.block_cd

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of block coordinate descent in JAX."""

import inspect

from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional
from typing import Union

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src import loop
from jaxopt._src import objective
from jaxopt._src import tree_util


class BlockCDState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int
  error: float
  predictions: jnp.ndarray  # cached ``linop.matvec(params)``
  subfun_g: jnp.ndarray  # gradient of ``fun.subfun`` at ``predictions``

  num_fun_eval: int = 0
  num_prox_eval: int = 0
  num_grad_eval: int = 0


@dataclass(eq=False)
class BlockCoordinateDescent(base.IterativeSolver):
  """Block coordinate descent solver.

  This solver minimizes::

    objective(params, hyperparams_prox, *args, **kwargs) =
      fun(params, *args, **kwargs) + non_smooth(params, hyperparams_prox)

  Attributes:
    fun: a smooth function of the form ``fun(params, *args, **kwargs)``.
      It should be an ``objective.CompositeLinearFunction`` object.
    block_prox: block-wise proximity operator associated with ``non_smooth``,
      a function of the form ``block_prox(x[j], hyperparams_prox,
      scaling=1.0)``. See ``jaxopt.prox`` for examples.
    maxiter: maximum number of proximal gradient descent iterations.
    tol: tolerance to use.
    verbose: whether to print the error on every iteration or not.
      Warning: verbose=True will automatically disable jit.

    implicit_diff: whether to enable implicit diff or autodiff of unrolled
      iterations.
    implicit_diff_solve: the linear system solver to use.

    jit: whether to JIT-compile the optimization loop (default: "auto").
    unroll: whether to unroll the optimization loop (default: "auto").
  """
  fun: objective.CompositeLinearFunction
  block_prox: Callable
  maxiter: int = 500
  tol: float = 1e-4
  verbose: int = 0
  implicit_diff: bool = True
  implicit_diff_solve: Optional[Callable] = None
  jit: base.AutoOrBoolean = "auto"
  unroll: base.AutoOrBoolean = "auto"
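
  # Illustrative decomposition (a sketch of the general form above, not a
  # statement about any particular objective): for the lasso one would take
  #
  #   fun(w, data=(X, y)) = 0.5 * ||X @ w - y||**2   (smooth part)
  #   non_smooth(w, lam)  = lam * ||w||_1            (handled by block_prox)
  #
  # with ``block_prox = jaxopt.prox.prox_lasso``.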

  def init_state(self,
                 init_params: Any,
                 hyperparams_prox: Any,
                 *args,
                 **kwargs) -> BlockCDState:
    """Initialize the solver state.

    Args:
      init_params: pytree containing the initial parameters.
      hyperparams_prox: pytree containing hyperparameters of ``block_prox``.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Returns:
      state
    """
    del hyperparams_prox  # Not used.
    linop = self.fun.make_linop(*args, **kwargs)
    predictions = linop.matvec(init_params)
    subfun_g = self._grad_subfun(predictions, *args, **kwargs)
    return BlockCDState(iter_num=jnp.asarray(0),
                        predictions=predictions,
                        subfun_g=subfun_g,
                        error=jnp.asarray(jnp.inf),
                        num_fun_eval=jnp.array(1, base.NUM_EVAL_DTYPE),
                        num_grad_eval=jnp.array(1, base.NUM_EVAL_DTYPE),
                        num_prox_eval=jnp.array(0, base.NUM_EVAL_DTYPE))

  def update(self,
             params: Any,
             state: NamedTuple,
             hyperparams_prox: Any,
             *args,
             **kwargs) -> base.OptStep:
    """Performs one epoch of block CD.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      hyperparams_prox: pytree containing hyperparameters of ``block_prox``.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Returns:
      (params, state)
    """
    linop = self.fun.make_linop(*args, **kwargs)
    stepsizes = 1.0 / self.fun.columnwise_lipschitz_const(*args, **kwargs)

    # TODO: ability to permute block order.

    def body_fun(i, tup):
      x, subfun_g, predictions, sqerror_sum = tup
      x_i_old = x[i]
      g_i = linop.rmatvec_element(subfun_g, i)
      b = self.fun.b(*args, **kwargs)
      if b is not None:
        g_i += b[i]
      x_i_new = self.block_prox(x[i] - stepsizes[i] * g_i,
                                hyperparams_prox,
                                stepsizes[i])
      diff_i = x_i_new - x_i_old
      # A cheap-to-compute lower bound of self.l2_optimality_error.
      sqerror_sum += jnp.sum(diff_i ** 2)
      x = x.at[i].set(x_i_new)
      predictions = linop.update_matvec(predictions, diff_i, i)
      subfun_g = self._grad_subfun(predictions, *args, **kwargs)
      return x, subfun_g, predictions, sqerror_sum

    init = (params, state.subfun_g, state.predictions, 0)
    n_for = params.shape[0]
    # FIXME: use a function similar to cond in order to have
    # a for loop that can be potentially non-jitted.
    # This would allow unit-testing the number of function evaluations.
    # (zramzi)
    params, subfun_g, predictions, sqerror_sum = jax.lax.fori_loop(
        lower=0, upper=n_for, body_fun=body_fun, init_val=init)
    state = BlockCDState(iter_num=state.iter_num + 1,
                         predictions=predictions,
                         subfun_g=subfun_g,
                         error=jnp.sqrt(sqerror_sum),
                         num_fun_eval=state.num_fun_eval + n_for,
                         num_grad_eval=state.num_grad_eval + n_for,
                         num_prox_eval=state.num_prox_eval + n_for)
    return base.OptStep(params=params, state=state)

  def _fixed_point_fun(self, params, hyperparams_prox, *args, **kwargs):
    grad_step = params - self._grad_fun(params, *args, **kwargs)
    return self._prox(grad_step, hyperparams_prox)

  def optimality_fun(self,
                     params: Any,
                     hyperparams_prox: Any,
                     *args,
                     **kwargs) -> Any:
    """Proximal-gradient fixed point residual.

    This function is compatible with ``@custom_root``.

    The fixed point function is defined as::

      fixed_point_fun(params, hyperparams_prox, *args, **kwargs) =
        prox(params - grad(fun)(params, *args, **kwargs), hyperparams_prox)

    where::

      prox = jax.vmap(block_prox, in_axes=(0, None))

    The residual is defined as::

      optimality_fun(params, hyperparams_prox, *args, **kwargs) =
        fixed_point_fun(params, hyperparams_prox, *args, **kwargs) - params

    Args:
      params: pytree containing the parameters.
      hyperparams_prox: pytree containing hyperparameters of ``block_prox``.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Returns:
      residual: pytree with the same structure as ``params``.
    """
    fp = self._fixed_point_fun(params, hyperparams_prox, *args, **kwargs)
    return fp - params
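
  # For example (illustrative), with ``block_prox = jaxopt.prox.prox_lasso``
  # the residual above becomes the soft-thresholding optimality condition
  #
  #   soft_threshold(params - grad(fun)(params), lam) - params,
  #
  # which is zero exactly at a lasso solution.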

  def __post_init__(self):
    if not isinstance(self.fun, objective.CompositeLinearFunction):
      raise AttributeError("fun should be an instance of "
                           "objective.CompositeLinearFunction.")

    # Pre-compile useful functions.
    self._grad_fun = jax.grad(self.fun)
    self._grad_subfun = jax.grad(self.fun.subfun)
    self._prox = jax.vmap(self.block_prox, in_axes=(0, None))

    # Sets up the reference signature.
    signature = inspect.signature(self.fun.subfun)
    parameters = list(signature.parameters.values())
    new_param = inspect.Parameter(name="hyperparams_prox",
                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
    parameters.insert(1, new_param)
    self.reference_signature = inspect.Signature(parameters)
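

if __name__ == "__main__":
  # Minimal usage sketch (illustrative; not part of the original module).
  # It assumes the public ``jaxopt.objective.least_squares`` composite-linear
  # objective and the ``jaxopt.prox.prox_lasso`` block proximity operator;
  # the toy data below is arbitrary.
  from jaxopt import prox

  X = jnp.ones((10, 3))  # toy design matrix
  y = jnp.zeros(10)      # toy targets
  w_init = jnp.zeros(3)  # one block per coefficient

  bcd = BlockCoordinateDescent(fun=objective.least_squares,
                               block_prox=prox.prox_lasso)
  sol = bcd.run(w_init, hyperparams_prox=1.0, data=(X, y))
  print(sol.params, sol.state.error)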