# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hager-Zhang line search algorithm."""
# This is based on:
# [1] W. Hager, H. Zhang, A new conjugate gradient method with guaranteed
# descent and an efficient line search. SIAM J. Optim., Vol. 16, No. 1, pp. 170-192.
# 2005. https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf
#
# Algorithm details are from
# [2] W. Hager, H. Zhang, Algorithm 851: CG_DESCENT, a Conjugate Gradient Method
# with Guaranteed Descent.
# https://www.math.lsu.edu/~hozhang/papers/cg_compare.pdf
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional
import jax
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_scalar_mul
from jaxopt.tree_util import tree_vdot_real
from jaxopt.tree_util import tree_conj


def _failed_nan(value, grad):
return jnp.isnan(value) | jnp.isnan(grad)


class HagerZhangLineSearchState(NamedTuple):
"""Named tuple containing state information."""
iter_num: int
done: bool
low: float
high: float
value: float
error: float
params: Any
grad: Any
failed: bool
aux: Optional[Any] = None
num_fun_eval: int = 0
num_grad_eval: int = 0


@dataclass(eq=False)
class HagerZhangLineSearch(base.IterativeLineSearch):
"""Hager-Zhang line search.
Supports complex variables.
Attributes:
fun: a function of the form ``fun(params, *args, **kwargs)``, where
``params`` are parameters of the model,
``*args`` and ``**kwargs`` are additional arguments.
value_and_grad: if ``False``, ``fun`` should return the function value only.
If ``True``, ``fun`` should return both the function value and the
gradient.
has_aux: if ``False``, ``fun`` should return the function value only.
If ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux``
is a pytree of auxiliary values.
c1: constant used by the Wolfe and Approximate Wolfe condition.
c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
condition.
max_stepsize: upper bound on stepsize.
maxiter: maximum number of line search iterations.
tol: tolerance of the stopping criterion.
verbose: whether to print error on every iteration or not. verbose=True will
automatically disable jit.
jit: whether to JIT-compile the optimization loop (default: "auto").
unroll: whether to unroll the optimization loop (default: "auto").
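    expansion_factor: factor by which the trial step is grown during the
      bracketing phase.
    shrinkage_factor: minimum relative shrinkage of the bracketing interval
      per iteration; if an iteration shrinks the interval by less than this
      factor, a bisection step is used instead.
    approximate_wolfe_threshold: relative tolerance on the function value
      below which the approximate Wolfe conditions are considered.

  Example (a minimal usage sketch; it assumes the generic ``run`` entry point
  inherited from ``base.IterativeLineSearch`` and that the class is exported
  at the ``jaxopt`` top level, neither of which is shown in this file)::

    import jax.numpy as jnp

    from jaxopt import HagerZhangLineSearch

    def fun(w):
      # Simple quadratic objective with minimizer at w = 1.
      return jnp.sum((w - 1.0) ** 2)

    w_init = jnp.zeros(3)
    ls = HagerZhangLineSearch(fun=fun)
    # ``init_stepsize`` is ignored by this line search (see ``init_state``).
    stepsize, ls_state = ls.run(init_stepsize=1.0, params=w_init)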
"""
fun: Callable # pylint:disable=g-bare-generic
value_and_grad: bool = False
has_aux: bool = False
maxiter: int = 30
tol: float = 0.
c1: float = 0.1
c2: float = 0.9
expansion_factor: float = 5.0
shrinkage_factor: float = 0.66
  approximate_wolfe_threshold: float = 1e-6
max_stepsize: float = 1.0
verbose: int = 0
jit: base.AutoOrBoolean = "auto"
unroll: base.AutoOrBoolean = "auto"

  def _value_and_grad_on_line(self, x, c, descent_direction, *args, **kwargs):
z = tree_add_scalar_mul(x, c, descent_direction)
if self.has_aux:
(value, _), grad = self._value_and_grad_fun(z, *args, **kwargs)
else:
value, grad = self._value_and_grad_fun(z, *args, **kwargs)
return value, tree_vdot_real(tree_conj(grad), descent_direction)

  def _satisfies_wolfe_and_approx_wolfe(
self,
c,
value_c,
gd_vdot_c,
value_initial,
grad_initial,
approx_wolfe_threshold_value,
descent_direction):
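    # Both stopping tests are encoded as nonnegative residuals: the returned
    # scalar is zero (or, more loosely, <= tol) exactly when either the Wolfe
    # conditions or the approximate Wolfe conditions of [1] hold at step `c`.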
gd_vdot = tree_vdot_real(tree_conj(grad_initial), descent_direction)
# Armijo condition
# armijo = value_c <= value_initial + self.c1 * c * gd_vdot
armijo = value_c - (value_initial + self.c1 * c * gd_vdot)
armijo_error = jax.lax.max(armijo, 0.)
# Curvature condition
# curvature = gd_vdot_c >= self.c2 * gd_vdot
curvature = self.c2 * gd_vdot - gd_vdot_c
wolfe_error = jax.lax.max(armijo_error, curvature)
# Approximate Wolfe
# approx_wolfe = (2 * self.c1 - 1.) * gd_vdot >= gd_vdot_c
approx_wolfe = gd_vdot_c - (2 * self.c1 - 1.) * gd_vdot
approx_wolfe_error = jax.lax.max(approx_wolfe, 0.)
approx_wolfe_error = jax.lax.max(approx_wolfe_error, curvature)
    # Finally, only enable the approximate Wolfe conditions when the value at
    # `c` is close to the initial value.
approx_wolfe_error = jax.lax.max(
approx_wolfe_error,
(value_c - approx_wolfe_threshold_value))
# We succeed if we either satisfy the Wolfe conditions or the approximate
# Wolfe conditions.
return jax.lax.min(wolfe_error, approx_wolfe_error)

  def _update(
self, x, low, high, middle, approx_wolfe_threshold_value,
descent_direction, fun_args, fun_kwargs):
value_middle, grad_middle = self._value_and_grad_on_line(
x, middle, descent_direction, *fun_args, **fun_kwargs)
# Corresponds to the `update` subroutine in the paper.
# This tries to create a smaller interval contained in `[low, high]`
# from the point `middle` that satisfies the opposite slope condition, where
    # the value at the left endpoint stays within tolerance of the initial value.
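    # The opposite slope condition of [2] requires value(low) to be below the
    # approximate-Wolfe threshold, grad(low) < 0 and grad(high) >= 0, so that
    # [low, high] brackets a local minimizer of the line restriction of `fun`.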
def cond_fn(state):
done, failed, low, middle, high, *_ = state
return jnp.any((middle < high) & (middle > low) & ~done & ~failed)
def body_fn(state):
done, failed, low, middle, high, value_middle, grad_middle, it = state
      # Corresponds to U1 in the paper.
update_right_endpoint = grad_middle >= 0.
new_high = jnp.where(~done & update_right_endpoint, middle, high)
done = done | update_right_endpoint
      # Corresponds to U2 in the paper.
update_left_endpoint = value_middle <= approx_wolfe_threshold_value
# Note that ~done implies grad_middle < 0. which is necessary for this
# check.
new_low = jnp.where(~done & update_left_endpoint, middle, low)
done = done | update_left_endpoint
      # Corresponds to U3 in the paper.
new_high = jnp.where(~done, middle, new_high)
done = done | jnp.isneginf(value_middle)
# TODO(srvasude): Allow this parameter to be varied.
      new_middle = jnp.where(~done, (new_low + new_high) / 2., middle)
new_value_middle, new_grad_middle = self._value_and_grad_on_line(
x, new_middle, descent_direction, *fun_args, **fun_kwargs)
new_value_middle = jnp.where(~done, new_value_middle, value_middle)
new_grad_middle = jnp.where(~done, new_grad_middle, grad_middle)
failed = failed | _failed_nan(new_value_middle, new_grad_middle)
return (done,
failed,
new_low,
new_middle,
new_high,
new_value_middle,
new_grad_middle,
it + 1)
_, failed, final_low, _, final_high, _, _, nit = jax.lax.while_loop(
cond_fn,
body_fn,
((middle >= high) | (middle <= low),
_failed_nan(value_middle, grad_middle),
low,
middle,
high,
value_middle,
grad_middle,
0))
num_fun_grad_calls = nit + 1
return failed, final_low, final_high, num_fun_grad_calls

  def _secant(self, x, low, high, descent_direction, *args, **kwargs):
_, dlow = self._value_and_grad_on_line(
x, low, descent_direction, *args, **kwargs)
_, dhigh = self._value_and_grad_on_line(
x, high, descent_direction, *args, **kwargs)
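    # Secant step: the root of the linear interpolant of the directional
    # derivative through (low, dlow) and (high, dhigh).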
return (low * dhigh - high * dlow) / (dhigh - dlow)

  def _secant2(
self, x, low, high,
approx_wolfe_threshold_value, descent_direction, *args, **kwargs):
# Corresponds to the secant^2 routine in the paper.
c = self._secant(x, low, high, descent_direction, *args, **kwargs)
num_fun_grad_calls = 2
failed, new_low, new_high, num_fun_grad_calls_update = self._update(
x, low, high, c, approx_wolfe_threshold_value,
descent_direction, args, kwargs)
num_fun_grad_calls += num_fun_grad_calls_update
on_left_boundary = jnp.equal(c, new_low)
on_right_boundary = jnp.equal(c, new_high)
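    # If the secant point became one of the new endpoints, take a second
    # secant step on that side using the old and new endpoint, as in the
    # secant^2 routine of [2].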
c = jnp.where(on_right_boundary, self._secant(
x, high, new_high, descent_direction, *args, **kwargs), c)
c = jnp.where(on_left_boundary, self._secant(
x, low, new_low, descent_direction, *args, **kwargs), c)
num_fun_grad_calls += 4
def _reupdate():
return self._update(
x, new_low, new_high, c, approx_wolfe_threshold_value,
descent_direction, args, kwargs)
failed, new_low, new_high, num_fun_grad_calls_update = jax.lax.cond(
on_left_boundary | on_right_boundary,
_reupdate, lambda: (failed, new_low, new_high, 0))
num_fun_grad_calls += num_fun_grad_calls_update
return failed, new_low, new_high, num_fun_grad_calls

  def _bracket(
self, x, c, approx_wolfe_threshold_value,
descent_direction, *args, **kwargs):
# Initial interval that satisfies the opposite slope condition.
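    # Find an initial interval satisfying the opposite slope condition by
    # growing the trial step geometrically (by `expansion_factor`) until the
    # directional derivative becomes nonnegative (B1); if the value rises
    # above the approximate-Wolfe threshold first (B2), the interval is
    # instead refined by `_update` on [0, middle].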
def cond_fn(state):
return jnp.any(~state[0]) & ~jnp.all(state[1])
def body_fn(state):
(done,
failed,
low,
middle,
high,
value_middle,
grad_middle,
best_middle,
num_fun_grad_calls) = state
      # Corresponds to B1 in the paper.
update_right_endpoint = grad_middle >= 0.
new_high = jnp.where(~done & update_right_endpoint, middle, high)
new_low = jnp.where(~done & update_right_endpoint, best_middle, low)
done = done | update_right_endpoint
      # Corresponds to B2 in the paper.
# Note that ~done implies grad_middle < 0. at this point so we omit
# checking that.
reupdate = ~done & (value_middle > approx_wolfe_threshold_value)
def _update_interval():
return self._update(
x,
          0.,
middle,
middle / 2.,
approx_wolfe_threshold_value,
descent_direction, args, kwargs)
new_failed, new_low, new_high, new_num_fun_grad_calls = jax.lax.cond(
reupdate, _update_interval, lambda: (failed, new_low, new_high, 0))
failed = failed | new_failed
done = done | reupdate
# This corresponds to the largest middle value that we have probed
# so far, that also is 'valid' (decreases the function sufficiently).
best_middle = jnp.where(
~done & (value_middle <= approx_wolfe_threshold_value),
middle, best_middle)
# Corresponds to B3 in the paper. Increase the point and recompute.
new_middle = jnp.where(~done, self.expansion_factor * middle, middle)
new_value_middle, new_grad_middle = self._value_and_grad_on_line(
x, new_middle, descent_direction, *args, **kwargs)
new_num_fun_grad_calls += 1
num_fun_grad_calls += new_num_fun_grad_calls
# Terminate on encountering NaNs to avoid an infinite loop.
failed = failed | _failed_nan(new_value_middle, new_grad_middle)
return (done,
failed,
new_low,
new_middle,
new_high,
new_value_middle,
new_grad_middle,
best_middle,
num_fun_grad_calls)
value_c, grad_c = self._value_and_grad_on_line(
x, c, descent_direction, *args, **kwargs)
num_fun_grad_calls = 1
# We have failed if there is a NaN at the right endpoint, or the gradient is
# NaN at the right endpoint (when there is a finite value).
failed = _failed_nan(value_c, grad_c)
    # If the right endpoint is -inf, then we are done, as this is a minimum.
done = jnp.isneginf(value_c)
_, failed, final_low, _, final_high, _, _, _, new_num_fun_grad_calls = jax.lax.while_loop(
cond_fn,
body_fn,
(done,
failed,
jnp.array(0.),
c,
c,
value_c,
grad_c,
jnp.array(0.),
0))
num_fun_grad_calls += new_num_fun_grad_calls
return failed, final_low, final_high, num_fun_grad_calls

  def init_state(
self,
init_stepsize: float,
params: Any,
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list = [],
fun_kwargs: dict = {},
) -> HagerZhangLineSearchState:
"""Initialize the line search state.
Args:
init_stepsize: initial step size value. This is ignored by the linesearch.
params: current parameters.
value: current function value (recomputed if None).
grad: current gradient (recomputed if None).
descent_direction: ignored.
fun_args: additional positional arguments to be passed to ``fun``.
fun_kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
state
"""
del init_stepsize
if value is None or grad is None:
if self.has_aux:
(value, _), grad = self._value_and_grad_fun(
params, *fun_args, **fun_kwargs
)
else:
value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
num_fun_eval = 1
num_grad_eval = 1
else:
num_fun_eval = 0
num_grad_eval = 0
if descent_direction is None:
descent_direction = tree_scalar_mul(-1, tree_conj(grad))
approx_wolfe_threshold_value = (
value + self.approximate_wolfe_threshold * jnp.abs(value))
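    # This is phi(0) + epsilon_k with epsilon_k = epsilon * |phi(0)| from [1]:
    # the approximate Wolfe conditions are only trusted at points whose value
    # stays below this threshold.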
# Create initial interval.
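    # The bracketing phase always starts from a unit step along the descent
    # direction; `init_stepsize` is intentionally ignored (see the docstring).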
failed, low, high, num_fun_grad_calls = self._bracket(
params,
jnp.ones_like(value),
approx_wolfe_threshold_value,
descent_direction,
*fun_args,
**fun_kwargs
)
num_fun_eval += num_fun_grad_calls
num_grad_eval += num_fun_grad_calls
value_low, grad_low = self._value_and_grad_on_line(
params, low, descent_direction, *fun_args, **fun_kwargs)
value_high, grad_high = self._value_and_grad_on_line(
params, high, descent_direction, *fun_args, **fun_kwargs)
num_fun_eval += 2
num_grad_eval += 2
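    # The provisional step is whichever bracket endpoint has the smaller
    # function value; the Wolfe / approximate-Wolfe error is measured there.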
best_point = jnp.where(value_low < value_high, low, high)
gd_vdot_best_point = jnp.where(value_low < value_high, grad_low, grad_high)
value_best_point = jnp.minimum(value_low, value_high)
error = self._satisfies_wolfe_and_approx_wolfe(
best_point,
value_best_point,
gd_vdot_best_point,
value,
grad,
approx_wolfe_threshold_value,
descent_direction)
done = error <= self.tol
return HagerZhangLineSearchState(
iter_num=jnp.asarray(0),
low=low,
high=high,
error=error,
done=done,
value=value,
aux=None, # we do not need to have aux in the initial state
params=params,
grad=grad,
failed=failed,
num_fun_eval=jnp.array(num_fun_eval, base.NUM_EVAL_DTYPE),
num_grad_eval=jnp.array(num_grad_eval, base.NUM_EVAL_DTYPE))

  def update(
self,
stepsize: float,
state: NamedTuple,
params: Any,
value: Optional[float] = None,
grad: Optional[Any] = None,
descent_direction: Optional[Any] = None,
fun_args: list = [],
fun_kwargs: dict = {},
) -> base.LineSearchStep:
"""Performs one iteration of Hager-Zhang line search.
Args:
stepsize: current estimate of the step size.
state: named tuple containing the line search state.
params: current parameters.
value: current function value (recomputed if None).
grad: current gradient (recomputed if None).
descent_direction: descent direction (negative gradient if None).
fun_args: additional positional arguments to be passed to ``fun``.
fun_kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
(params, state)
"""
if value is None or grad is None:
if self.has_aux:
(value, _), grad = self._value_and_grad_fun(
params, *fun_args, **fun_kwargs
)
else:
value, grad = self._value_and_grad_fun(params, *fun_args, **fun_kwargs)
new_num_fun_eval = state.num_fun_eval + 1
new_num_grad_eval = state.num_grad_eval + 1
if descent_direction is None:
descent_direction = tree_scalar_mul(-1, tree_conj(grad))
approx_wolfe_threshold_value = (
value + self.approximate_wolfe_threshold * jnp.abs(value))
failed, new_low, new_high, num_fun_grad_calls = self._secant2(
params,
state.low,
state.high,
approx_wolfe_threshold_value,
descent_direction,
*fun_args, **fun_kwargs)
new_num_fun_eval += num_fun_grad_calls
new_num_grad_eval += num_fun_grad_calls
failed = state.failed | failed
new_low = jnp.where(state.done, state.low, new_low)
new_high = jnp.where(state.done, state.high, new_high)
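    # Bisection fallback: if secant2 did not shrink the bracket by at least a
    # factor of `shrinkage_factor`, refine it from its midpoint instead.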
def _reupdate():
c = (new_low + new_high) / 2.
return self._update(
params, new_low, new_high, c, approx_wolfe_threshold_value,
descent_direction, fun_args, fun_kwargs)
failed, new_low, new_high, num_fun_grad_calls = jax.lax.cond(
~state.done & ((new_high - new_low) >
(self.shrinkage_factor * (state.high - state.low))),
_reupdate, lambda: (failed, new_low, new_high, 0))
new_num_fun_eval += num_fun_grad_calls
new_num_grad_eval += num_fun_grad_calls
    # Check the Wolfe and approximate Wolfe conditions at the new bracket
    # endpoints and update the error.
value_low, grad_low = self._value_and_grad_on_line(
params, new_low, descent_direction, *fun_args, **fun_kwargs)
value_high, grad_high = self._value_and_grad_on_line(
params, new_high, descent_direction, *fun_args, **fun_kwargs)
new_num_fun_eval += 2
new_num_grad_eval += 2
best_point = jnp.where(value_low < value_high, new_low, new_high)
gd_vdot_best_point = jnp.where(value_low < value_high, grad_low, grad_high)
value_best_point = jnp.minimum(value_low, value_high)
new_stepsize = jnp.where(state.done, stepsize, best_point)
new_params = tree_add_scalar_mul(params, best_point, descent_direction)
if self.has_aux:
(new_value, new_aux), new_grad = self._value_and_grad_fun(
new_params, *fun_args, **fun_kwargs)
else:
new_value, new_grad = self._value_and_grad_fun(
new_params, *fun_args, **fun_kwargs)
new_aux = None
new_num_fun_eval += 1
new_num_grad_eval += 1
error = jnp.where(state.done, state.error,
self._satisfies_wolfe_and_approx_wolfe(
best_point,
value_best_point,
gd_vdot_best_point,
value,
grad,
approx_wolfe_threshold_value,
descent_direction))
done = state.done | (error <= self.tol)
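    # Exhausting the iteration budget without satisfying the stopping
    # criterion is reported through `failed` rather than raising an error.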
failed = failed | ((state.iter_num + 1 == self.maxiter) & ~done)
new_state = HagerZhangLineSearchState(
iter_num=state.iter_num + 1,
value=new_value,
grad=new_grad,
aux=new_aux,
params=new_params,
low=new_low,
high=new_high,
error=error,
done=done,
failed=failed,
num_fun_eval=new_num_fun_eval,
num_grad_eval=new_num_grad_eval)
return base.LineSearchStep(stepsize=new_stepsize, state=new_state)

  def __post_init__(self):
if self.value_and_grad:
self._value_and_grad_fun = self.fun
else:
self._value_and_grad_fun = jax.value_and_grad(
self.fun, has_aux=self.has_aux
)