Source code for jaxopt._src.backtracking_linesearch

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bactracking line search algorithm."""

from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_scalar_mul
from jaxopt.tree_util import tree_vdot_real
from jaxopt.tree_util import tree_conj

class BacktrackingLineSearchState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int
  value: float
  error: float  # Maximum violation of the selected line search conditions (0 when satisfied).
  done: bool  # True once error <= tol.
  params: Any
  grad: Any
  num_fun_eval: int
  num_grad_eval: int
  failed: bool  # True if maxiter was reached without satisfying the conditions.
  aux: Optional[Any] = None


@dataclass(eq=False)
class BacktrackingLineSearch(base.IterativeLineSearch):
  """Backtracking line search.

  Supports complex variables.

  Attributes:
    fun: a function of the form ``fun(params, *args, **kwargs)``, where
      ``params`` are parameters of the model, ``*args`` and ``**kwargs`` are
      additional arguments.
    value_and_grad: if ``False``, ``fun`` should return the function value
      only. If ``True``, ``fun`` should return both the function value and the
      gradient.
    has_aux: if ``False``, ``fun`` should return the function value only.
      If ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux``
      is a pytree of auxiliary values.
    condition: either "armijo", "goldstein", "strong-wolfe" or "wolfe".
    c1: constant used by the (strong) Wolfe condition.
    c2: constant strictly less than 1 used by the (strong) Wolfe condition.
    decrease_factor: factor by which to decrease the stepsize during line
      search (default: 0.8).
    max_stepsize: upper bound on stepsize.
    maxiter: maximum number of line search iterations.
    tol: tolerance of the stopping criterion.
    verbose: whether to print error on every iteration or not.
      verbose=True will automatically disable jit.
    jit: whether to JIT-compile the optimization loop (default: "auto").
    unroll: whether to unroll the optimization loop (default: "auto").
  """
  fun: Callable
  value_and_grad: bool = False
  has_aux: bool = False
  maxiter: int = 30
  tol: float = 0.0
  condition: str = "strong-wolfe"
  c1: float = 1e-4
  c2: float = 0.9
  decrease_factor: float = 0.8
  max_stepsize: float = 1.0
  verbose: int = 0
  jit: base.AutoOrBoolean = "auto"
  unroll: base.AutoOrBoolean = "auto"
  def init_state(
      self,
      init_stepsize: float,
      params: Any,
      value: Optional[float] = None,
      grad: Optional[Any] = None,
      descent_direction: Optional[Any] = None,
      fun_args: list = [],
      fun_kwargs: dict = {},
  ) -> BacktrackingLineSearchState:
    """Initialize the line search state.

    Args:
      init_stepsize: initial step size value.
      params: current parameters.
      value: current function value (recomputed if None).
      grad: current gradient (recomputed if None).
      descent_direction: ignored.
      fun_args: additional positional arguments to be passed to ``fun``.
      fun_kwargs: additional keyword arguments to be passed to ``fun``.
    Returns:
      state
    """
    del descent_direction  # Not used.

    num_fun_eval = 0
    num_grad_eval = 0
    if value is None or grad is None:
      if self.has_aux:
        (value, _), grad = self._value_and_grad_fun(
            params, *fun_args, **fun_kwargs
        )
      else:
        value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
      num_fun_eval += 1
      num_grad_eval += 1

    return BacktrackingLineSearchState(
        iter_num=jnp.asarray(0),
        value=value,
        aux=None,  # we do not need to have aux in the initial state
        error=jnp.asarray(jnp.inf),
        params=params,
        num_fun_eval=num_fun_eval,
        num_grad_eval=num_grad_eval,
        done=jnp.asarray(False),
        grad=grad,
        failed=jnp.asarray(False),
    )
  def update(
      self,
      stepsize: float,
      state: NamedTuple,
      params: Any,
      value: Optional[float] = None,
      grad: Optional[Any] = None,
      descent_direction: Optional[Any] = None,
      fun_args: list = [],
      fun_kwargs: dict = {},
  ) -> base.LineSearchStep:
    """Performs one iteration of backtracking line search.

    Args:
      stepsize: current estimate of the step size.
      state: named tuple containing the line search state.
      params: current parameters.
      value: current function value (recomputed if None).
      grad: current gradient (recomputed if None).
      descent_direction: descent direction (negative gradient if None).
      fun_args: additional positional arguments to be passed to ``fun``.
      fun_kwargs: additional keyword arguments to be passed to ``fun``.
    Returns:
      (params, state)
    """
    # Ensure that stepsize does not exceed upper bound.
    stepsize = jnp.minimum(self.max_stepsize, stepsize)
    num_fun_eval = state.num_fun_eval
    num_grad_eval = state.num_grad_eval

    if value is None or grad is None:
      if self.has_aux:
        (value, _), grad = self._value_and_grad_fun(params, *fun_args,
                                                    **fun_kwargs)
      else:
        value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
      num_fun_eval += 1
      num_grad_eval += 1

    if descent_direction is None:
      descent_direction = tree_scalar_mul(-1, tree_conj(grad))

    gd_vdot = tree_vdot_real(tree_conj(grad), descent_direction)
    # For backtracking linesearches, we want to compute the next point
    # from the basepoint, i.e. x_i = x_0 + s_i * p.
    new_params = tree_add_scalar_mul(params, stepsize, descent_direction)
    if self.has_aux:
      (new_value, new_aux), new_grad = self._value_and_grad_fun(
          new_params, *fun_args, **fun_kwargs
      )
    else:
      new_value, new_grad = self._value_and_grad_fun(
          new_params, *fun_args, **fun_kwargs
      )
      new_aux = None
    new_gd_vdot = tree_vdot_real(tree_conj(new_grad), descent_direction)
    # Every condition requires the new function value, but not every one
    # requires the new gradient value (we'll assume that this code is called
    # under `jit`).
    num_fun_eval += 1

    # Armijo condition (upper bound on admissible step size).
    # cond1 = new_value <= value + self.c1 * stepsize * gd_vdot
    # See equation (3.6a), Numerical Optimization, Second edition.
    diff_cond1 = new_value - (value + self.c1 * stepsize * gd_vdot)
    error_cond1 = jnp.maximum(diff_cond1, 0.0)
    error = error_cond1

    if self.condition == "armijo":
      # We don't need to do any extra work, since this is covered by the
      # sufficient decrease condition above.
      pass

    elif self.condition == "strong-wolfe":
      # cond2 = abs(new_gd_vdot) <= c2 * abs(gd_vdot)
      # See equation (3.7b), Numerical Optimization, Second edition.
      diff_cond2 = jnp.abs(new_gd_vdot) - self.c2 * jnp.abs(gd_vdot)
      error_cond2 = jnp.maximum(diff_cond2, 0.0)
      error = jnp.maximum(error_cond1, error_cond2)
      num_grad_eval += 1

    elif self.condition == "wolfe":
      # cond2 = new_gd_vdot >= c2 * gd_vdot
      # See equation (3.6b), Numerical Optimization, Second edition.
      diff_cond2 = self.c2 * gd_vdot - new_gd_vdot
      error_cond2 = jnp.maximum(diff_cond2, 0.0)
      error = jnp.maximum(error_cond1, error_cond2)
      num_grad_eval += 1

    elif self.condition == "goldstein":
      # cond2 = new_value >= value + (1 - self.c1) * stepsize * gd_vdot
      diff_cond2 = value + (1 - self.c1) * stepsize * gd_vdot - new_value
      error_cond2 = jnp.maximum(diff_cond2, 0.0)
      error = jnp.maximum(error_cond1, error_cond2)

    else:
      raise ValueError("condition should be one of "
                       "'armijo', 'goldstein', 'strong-wolfe' or 'wolfe'.")

    new_stepsize = jnp.where(error <= self.tol,
                             stepsize,
                             stepsize * self.decrease_factor)
    done = state.done | (error <= self.tol)
    failed = state.failed | ((state.iter_num + 1 == self.maxiter) & ~done)

    new_state = BacktrackingLineSearchState(
        iter_num=state.iter_num + 1,
        value=new_value,
        aux=new_aux,
        grad=new_grad,
        params=new_params,
        num_fun_eval=num_fun_eval,
        num_grad_eval=num_grad_eval,
        done=done,
        error=error,
        failed=failed,
    )

    return base.LineSearchStep(stepsize=new_stepsize, state=new_state)
  def __post_init__(self):
    if self.value_and_grad:
      self._value_and_grad_fun = self.fun
    else:
      self._value_and_grad_fun = jax.value_and_grad(
          self.fun, has_aux=self.has_aux
      )
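
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library source). It relies
# on the `run` method inherited from `base.IterativeLineSearch`, which calls
# `init_state` and then repeatedly `update` until `done` or `maxiter`. The
# objective `quadratic` and all values below are assumptions for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  # Simple smooth objective: f(w) = ||w - 1||^2.
  def quadratic(w):
    return jnp.sum((w - 1.0) ** 2)

  ls = BacktrackingLineSearch(fun=quadratic,
                              condition="strong-wolfe",
                              decrease_factor=0.8,
                              maxiter=20)
  params = jnp.zeros(3)
  # `value`, `grad` and `descent_direction` are optional: if omitted, they are
  # recomputed and the negative (conjugated) gradient is used as direction.
  stepsize, state = ls.run(init_stepsize=1.0, params=params)
  print("stepsize:", stepsize, "done:", state.done, "error:", state.error)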