# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps SciPy's optimization routines with PyTree and implicit diff support.
# TODO(fllinares): add support for `LinearConstraint`s.
# TODO(fllinares): add support for methods requiring Hessian / Hessian prods.
# TODO(fllinares): possibly hardcode `dtype` attribute, as likely useless.
# TODO(pedregosa): add 'maxiter' and 'callback' keyword options for all wrappers;
#                  currently only ScipyMinimize exposes these options.
"""
import abc
import dataclasses
from dataclasses import dataclass
from typing import Any
from typing import Callable
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import jax
from jax.config import config
import jax.numpy as jnp
import jax.tree_util as tree_util
from jax.tree_util import register_pytree_node_class
from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src import projection
from jaxopt._src.tree_util import tree_sub
import numpy as onp
import scipy as osp
from scipy.optimize import LbfgsInvHessProduct
@register_pytree_node_class
class LbfgsInvHessProductPyTree(LbfgsInvHessProduct):
"""
Registers the LbfgsInvHessProduct object as a PyTree.
  This object is typically returned by the L-BFGS-B optimizer to efficiently
store the inverse of the Hessian matrix evaluated at the best-fit parameters.
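
  Example:
    A minimal sketch of the PyTree registration (the `sk`/`yk` values below are
    hypothetical)::

      sk = jnp.ones((5, 3))
      yk = jnp.full((5, 3), 2.0)
      op = LbfgsInvHessProductPyTree(sk, yk)
      leaves, treedef = tree_util.tree_flatten(op)  # flattens into (sk, yk)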
"""
def __init__(self, sk, yk):
"""
Construct the operator.
This is the same constructor as the original LbfgsInvHessProduct class,
except that numpy has been replaced by jax.numpy and no call to the
    numpy.ndarray constructor is performed.
"""
if sk.shape != yk.shape or sk.ndim != 2:
raise ValueError('sk and yk must have matching shape, (n_corrs, n)')
n_corrs, n = sk.shape
    self.dtype = jnp.float64 if config.jax_enable_x64 else jnp.float32
self.shape = (n, n)
self.sk = sk
self.yk = yk
self.n_corrs = n_corrs
self.rho = 1 / jnp.einsum('ij,ij->i', sk, yk)
def __repr__(self):
return "LbfgsInvHessProduct(sk={}, yk={})".format(self.sk, self.yk)
def tree_flatten(self):
children = (self.sk, self.yk)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
class ScipyMinimizeInfo(NamedTuple):
"""Named tuple with results for `scipy.optimize.minimize` wrappers."""
fun_val: jnp.ndarray
success: bool
status: int
iter_num: int
hess_inv: Optional[Union[jnp.ndarray, LbfgsInvHessProductPyTree]]
num_fun_eval: int = 0
num_jac_eval: int = 0
num_hess_eval: int = 0
class ScipyRootInfo(NamedTuple):
"""Named tuple with results for `scipy.optimize.root` wrappers."""
fun_val: float
success: bool
status: int
iter_num: int
num_fun_eval: int = 0
class ScipyLeastSquaresInfo(NamedTuple):
"""Named tuple with results for `scipy.optimize.least_squares` wrappers."""
cost_val: float
fun_val: jnp.ndarray
success: bool
status: int
num_fun_eval: int
num_jac_eval: Optional[int]
error: float
class PyTreeTopology(NamedTuple):
"""Stores info to reconstruct PyTree from flattened PyTree leaves.
# TODO(fllinares): more specific type annotations for attributes?
Attributes:
treedef: the PyTreeDef object encoding the structure of the target PyTree.
shapes: an iterable with the shapes of each leaf in the target PyTree.
dtypes: an iterable with the dtypes of each leaf in the target PyTree.
sizes: an iterable with the sizes of each leaf in the target PyTree.
n_leaves: the number of leaves in the target PyTree.
"""
treedef: Any
shapes: Sequence[Any]
dtypes: Sequence[Any]
@property
def sizes(self):
return [int(onp.prod(shape)) for shape in self.shapes]
@property
def n_leaves(self):
return len(self.shapes)
def jnp_to_onp(x_jnp: Any,
dtype: Optional[Any] = onp.float64) -> onp.ndarray:
"""Converts JAX PyTree into repr suitable for scipy.optimize.minimize.
Several of SciPy's optimization routines require inputs and/or outputs to be
  onp.ndarray<float>[n]. Given an input PyTree `x_jnp`, this function flattens
  all its leaves, concatenates the flattened arrays when there is more than one
  leaf and, optionally, casts the result to `dtype`.
  Args:
    x_jnp: a PyTree of jnp.ndarray leaves to be flattened and concatenated.
dtype: if not None, ensure output is a NumPy array of this dtype.
Returns:
A single onp.ndarray<dtype>[n] array, consisting of all leaves of x_jnp
flattened and concatenated. If dtype is None, the output dtype will be
determined by NumPy's casting rules for the concatenate method.
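
  Example:
    A minimal sketch (the pytree below is hypothetical)::

      pytree = {'w': jnp.zeros((2, 3)), 'b': jnp.ones(4)}
      flat = jnp_to_onp(pytree)  # onp.ndarray of shape (10,)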
"""
x_onp = [onp.asarray(leaf, dtype).reshape(-1)
for leaf in tree_util.tree_leaves(x_jnp)]
# NOTE(fllinares): return value must *not* be read-only, I believe.
return onp.concatenate(x_onp)
def make_jac_jnp_to_onp(input_pytree_topology: PyTreeTopology,
output_pytree_topology: PyTreeTopology,
dtype: Optional[Any] = onp.float64) -> Callable:
"""Returns function "flattening" Jacobian for given in/out PyTree topologies.
For a smooth function `fun(x_jnp, *args, **kwargs)` taking an arbitrary
PyTree `x_jnp` as input and returning another arbitrary PyTree `y_jnp` as
output, JAX's transforms such as `jax.jacrev` or `jax.jacfwd` will return a
Jacobian with a PyTree structure reflecting the input and output PyTrees.
However, several of SciPy's optimization routines expect inputs and outputs to
be 1D NumPy arrays and, thus, Jacobians to be 2D NumPy arrays.
Given the Jacobian of `fun(x_jnp, *args, **kwargs)` as provided by JAX,
`jac_jnp_to_onp` will format it to match the Jacobian of
`jnp_to_onp(fun(x_jnp, *args, **kwargs))` w.r.t. `jnp_to_onp(x_jnp)`,
where `jnp_to_onp` is a vectorization operator for arbitrary PyTrees.
Args:
input_pytree_topology: a PyTreeTopology encoding the topology of the input
PyTree.
output_pytree_topology: a PyTreeTopology encoding the topology of the output
PyTree.
dtype: if not None, ensure output is a NumPy array of this dtype.
Returns:
A function "flattening" Jacobian for given input and output PyTree
topologies.
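
  Example:
    A minimal sketch, assuming single-leaf input and output pytrees::

      fun = lambda x: {'out': 2.0 * x['a']}
      x = {'a': jnp.arange(3.0)}
      in_topo = pytree_topology_from_example(x)
      out_topo = pytree_topology_from_example(fun(x))
      jac_pytree = jax.jacfwd(fun)(x)
      jac_2d = make_jac_jnp_to_onp(in_topo, out_topo)(jac_pytree)  # shape (3, 3)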
"""
ravel_index = lambda i, j: j + i * input_pytree_topology.n_leaves
def jac_jnp_to_onp(jac_pytree: Any):
# Builds flattened Jacobian blocks such that `jacs_onp[i][j]` equals the
# Jacobian of vec(i-th leaf of output_pytree) w.r.t.
# vec(j-th leaf of input_pytree), where vec() is the vectorization op.,
# i.e. reshape(input, [-1]).
jacs_leaves = tree_util.tree_leaves(jac_pytree)
jacs_onp = []
for i, output_size in enumerate(output_pytree_topology.sizes):
jacs_onp_i = []
for j, input_size in enumerate(input_pytree_topology.sizes):
jac_leaf = onp.asarray(jacs_leaves[ravel_index(i, j)], dtype)
jac_leaf = jac_leaf.reshape([output_size, input_size])
jacs_onp_i.append(jac_leaf)
jacs_onp.append(jacs_onp_i)
return onp.block(jacs_onp)
return jac_jnp_to_onp
def make_onp_to_jnp(pytree_topology: PyTreeTopology) -> Callable:
"""Returns inverse of `jnp_to_onp` for a specific PyTree topology.
Args:
pytree_topology: a PyTreeTopology encoding the topology of the original
PyTree to be reconstructed.
Returns:
The inverse of `jnp_to_onp` for a specific PyTree topology.
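
  Example:
    A round-trip sketch (the pytree below is hypothetical)::

      pytree = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(3)}
      topology = pytree_topology_from_example(pytree)
      restored = make_onp_to_jnp(topology)(jnp_to_onp(pytree))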
"""
treedef, shapes, dtypes = pytree_topology
split_indices = onp.cumsum(list(pytree_topology.sizes[:-1]))
def onp_to_jnp(x_onp: onp.ndarray) -> Any:
"""Inverts `jnp_to_onp` for a specific PyTree topology."""
flattened_leaves = onp.split(x_onp, split_indices)
x_jnp = [jnp.asarray(leaf.reshape(shape), dtype)
for leaf, shape, dtype in zip(flattened_leaves, shapes, dtypes)]
return tree_util.tree_unflatten(treedef, x_jnp)
return onp_to_jnp
def pytree_topology_from_example(x_jnp: Any) -> PyTreeTopology:
"""Returns a PyTreeTopology encoding the PyTree structure of `x_jnp`."""
leaves, treedef = tree_util.tree_flatten(x_jnp)
shapes = [jnp.asarray(leaf).shape for leaf in leaves]
dtypes = [jnp.asarray(leaf).dtype for leaf in leaves]
return PyTreeTopology(treedef=treedef, shapes=shapes, dtypes=dtypes)
@dataclass(eq=False)
class ScipyWrapper(base.Solver):
"""Wraps over `scipy.optimize` methods with PyTree and implicit diff support.
Attributes:
method: the `method` argument for `scipy.optimize`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
has_aux: whether function `fun` outputs one (False) or more values (True).
When True it will be assumed by default that `fun(...)[0]` is the
objective.
"""
method: Optional[str] = None
dtype: Optional[Any] = onp.float64
jit: bool = True
implicit_diff_solve: Optional[Callable] = None
has_aux: bool = False
def optimality_fun(self, sol, *args, **kwargs):
raise NotImplementedError(
'ScipyWrapper subclasses must implement `optimality_fun` as needed.')
def __post_init__(self):
# Set up implicit diff.
decorator = idf.custom_root(self.optimality_fun,
has_aux=True,
solve=self.implicit_diff_solve)
# pylint: disable=g-missing-from-attributes
self.run = decorator(self.run)
@dataclass(eq=False)
class ScipyMinimize(ScipyWrapper):
"""`scipy.optimize.minimize` wrapper
This wrapper is for unconstrained minimization only.
It supports pytrees and implicit diff.
Attributes:
fun: a smooth function of the form `fun(x, *args, **kwargs)`.
    method: the `method` argument for `scipy.optimize.minimize`.
      Should be one of
        * 'Nelder-Mead'
        * 'Powell'
        * 'CG'
        * 'BFGS'
        * 'Newton-CG'
        * 'L-BFGS-B'
        * 'TNC'
        * 'COBYLA'
        * 'SLSQP'
        * 'trust-constr'
        * 'dogleg'
        * 'trust-ncg'
        * 'trust-exact'
        * 'trust-krylov'
    tol: the `tol` argument for `scipy.optimize.minimize`.
    options: the `options` argument for `scipy.optimize.minimize`.
    maxiter: maximum number of iterations to perform. Depending on the method,
      each iteration may use several function evaluations.
    callback: called after each iteration, as `callback(xk)`, where `xk` is the
      current parameter pytree.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
has_aux: whether function `fun` outputs one (False) or more values (True).
When True it will be assumed by default that `fun(...)[0]` is the
objective.
value_and_grad: See base.make_funs_with_aux for more detail.
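
  Example:
    A minimal usage sketch (the objective below is hypothetical)::

      def fun(x):
        return jnp.sum((x - 1.0) ** 2)

      solver = ScipyMinimize(fun=fun, method='L-BFGS-B')
      params, info = solver.run(jnp.zeros(3))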
"""
fun: Callable = None
callback: Callable = None
tol: Optional[float] = None
options: Optional[Dict[str, Any]] = None
maxiter: int = 500
value_and_grad: Union[bool, Callable] = False
  def optimality_fun(self, sol, *args, **kwargs):
"""Optimality function mapping compatible with `@custom_root`."""
return self._grad_fun(sol, *args, **kwargs)
def _run(self, init_params, bounds, *args, **kwargs):
"""Wraps `scipy.optimize.minimize`."""
# Sets up the "JAX-SciPy" bridge.
pytree_topology = pytree_topology_from_example(init_params)
onp_to_jnp = make_onp_to_jnp(pytree_topology)
# wrap the callback so its arguments are of the same kind as fun
if self.callback is not None:
def scipy_callback(x_onp: onp.ndarray):
x_jnp = onp_to_jnp(x_onp)
return self.callback(x_jnp)
else:
scipy_callback = None
def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]:
x_jnp = onp_to_jnp(x_onp)
value, grads = self._value_and_grad_fun(x_jnp, *args, **kwargs)
return onp.asarray(value, self.dtype), jnp_to_onp(grads, self.dtype)
if bounds is not None:
bounds = osp.optimize.Bounds(lb=jnp_to_onp(bounds[0], self.dtype),
ub=jnp_to_onp(bounds[1], self.dtype))
res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
jac=True,
tol=self.tol,
bounds=bounds,
method=self.method,
callback=scipy_callback,
options=self.options)
params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))
    hess_inv = None
    if hasattr(res, 'hess_inv'):
      if isinstance(res.hess_inv, osp.optimize.LbfgsInvHessProduct):
        hess_inv = LbfgsInvHessProductPyTree(res.hess_inv.sk,
                                             res.hess_inv.yk)
      elif isinstance(res.hess_inv, onp.ndarray):
        hess_inv = jnp.asarray(res.hess_inv)
try:
num_hess_eval = jnp.asarray(res.nhev, base.NUM_EVAL_DTYPE)
except AttributeError:
num_hess_eval = jnp.array(0, base.NUM_EVAL_DTYPE)
info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
iter_num=res.nit,
hess_inv=hess_inv,
num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE),
num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
num_hess_eval=num_hess_eval)
return base.OptStep(params, info)
  def run(self,
init_params: Any,
*args,
**kwargs) -> base.OptStep:
"""Runs the solver.
Args:
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to `fun`.
**kwargs: additional keyword arguments to be passed to `fun`.
Returns:
(params, info).
"""
return self._run(init_params, None, *args, **kwargs)
def __post_init__(self):
super().__post_init__()
self.fun, self._grad_fun, self._value_and_grad_fun = (
base._make_funs_without_aux(self.fun, self.value_and_grad, self.has_aux)
)
# Pre-compile useful functions.
if self.jit:
self.fun = jax.jit(self.fun)
self._grad_fun = jax.jit(self._grad_fun)
self._value_and_grad_fun = jax.jit(self._value_and_grad_fun)
if self.options is None:
self.options = {}
if 'maxiter' in self.options:
raise ValueError("Cannot pass maxiter through options dictionary, use maxiter keyword argument instead.")
self.options['maxiter'] = self.maxiter
@dataclass(eq=False)
class ScipyBoundedMinimize(ScipyMinimize):
"""`scipy.optimize.minimize` wrapper.
This wrapper is for minimization subject to box constraints only.
Attributes:
fun: a smooth function of the form `fun(x, *args, **kwargs)`.
method: the `method` argument for `scipy.optimize.minimize`.
tol: the `tol` argument for `scipy.optimize.minimize`.
options: the `options` argument for `scipy.optimize.minimize`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
has_aux: whether function `fun` outputs one (False) or more values (True).
When True it will be assumed by default that `fun(...)[0]` is the
objective.
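
  Example:
    A minimal usage sketch with box constraints (the objective and bounds below
    are hypothetical)::

      def fun(x):
        return jnp.sum(x ** 2)

      lower, upper = jnp.full(3, 1.0), jnp.full(3, 2.0)
      solver = ScipyBoundedMinimize(fun=fun, method='L-BFGS-B')
      params, info = solver.run(jnp.full(3, 1.5), bounds=(lower, upper))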
"""
def _fixed_point_fun(self, sol, bounds, args, kwargs):
step = tree_sub(sol, self._grad_fun(sol, *args, **kwargs))
return projection.projection_box(step, bounds)
  def optimality_fun(self, sol, bounds, *args, **kwargs):
"""Optimality function mapping compatible with `@custom_root`."""
fp = self._fixed_point_fun(sol, bounds, args, kwargs)
return tree_sub(fp, sol)
  def run(self,
init_params: Any,
bounds: Optional[Any],
*args,
**kwargs) -> base.OptStep:
"""Runs the solver.
Args:
init_params: pytree containing the initial parameters.
bounds: an optional tuple `(lb, ub)` of pytrees with structure identical
to `init_params`, representing box constraints.
*args: additional positional arguments to be passed to `fun`.
**kwargs: additional keyword arguments to be passed to `fun`.
Returns:
(params, info).
"""
return self._run(init_params, bounds, *args, **kwargs)
@dataclass(eq=False)
class ScipyRootFinding(ScipyWrapper):
"""`scipy.optimize.root` wrapper.
It supports pytrees and implicit diff.
Attributes:
optimality_fun: a smooth vector function of the form
`optimality_fun(x, *args, **kwargs)` whose root is to be found. It must
return as output a PyTree with structure identical to x.
method: the `method` argument for `scipy.optimize.root`.
Should be one of
* 'hybr'
* 'lm'
* 'broyden1'
* 'broyden2'
* 'anderson'
* 'linearmixing'
* 'diagbroyden'
* 'excitingmixing'
* 'krylov'
* 'df-sane'
tol: the `tol` argument for `scipy.optimize.root`.
options: the `options` argument for `scipy.optimize.root`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
    has_aux: whether `optimality_fun` outputs one (False) or more values
      (True). When True it will be assumed by default that
      `optimality_fun(...)[0]` is the optimality function.
use_jacrev: whether to compute the Jacobian of `optimality_fun` using
`jax.jacrev` (True) or `jax.jacfwd` (False).
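
  Example:
    A minimal usage sketch (the root-finding problem below is hypothetical)::

      def F(x):
        return jnp.array([x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0,
                          0.5 * (x[1] - x[0]) ** 3 + x[1]])

      solver = ScipyRootFinding(optimality_fun=F, method='hybr')
      params, info = solver.run(jnp.zeros(2))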
"""
optimality_fun: Callable = None
tol: Optional[float] = None
options: Optional[Dict[str, Any]] = None
use_jacrev: bool = True
  def run(self,
init_params: Any,
*args,
**kwargs) -> base.OptStep:
"""Runs the solver.
Args:
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to `fun`.
**kwargs: additional keyword arguments to be passed to `fun`.
Returns:
(params, info).
"""
# Sets up the "JAX-SciPy" bridge.
pytree_topology = pytree_topology_from_example(init_params)
onp_to_jnp = make_onp_to_jnp(pytree_topology)
jac_jnp_to_onp = make_jac_jnp_to_onp(pytree_topology,
pytree_topology,
self.dtype)
def scipy_fun(x_onp: onp.ndarray, scipy_args: Any) -> Tuple[onp.ndarray, onp.ndarray]:
# scipy_args is unused but must appear in the signature since
# the `args` argument passed to osp.optimize.root is not None.
del scipy_args # unused
x_jnp = onp_to_jnp(x_onp)
value_jnp = self.optimality_fun(x_jnp, *args, **kwargs)
jacs_jnp = self._jac_fun(x_jnp, *args, **kwargs)
return jnp_to_onp(value_jnp, self.dtype), jac_jnp_to_onp(jacs_jnp)
    # Argument `args` is unused but must not be None so that SciPy's sanity
    # checks run correctly for optimizers that do not use the Jacobian (such as Broyden).
# See the related issue: https://github.com/google/jaxopt/issues/290
res = osp.optimize.root(scipy_fun, jnp_to_onp(init_params, self.dtype),
args=(None,),
jac=True,
tol=self.tol,
method=self.method,
options=self.options)
params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))
# NOTE: maybe there is a better way to do the following (zramzi)
if isinstance(res, osp.optimize.RootResults):
iter_num = jnp.array(res.iterations)
num_fun_eval = jnp.array(res.function_calls, base.NUM_EVAL_DTYPE)
else:
try:
iter_num = jnp.array(res.nit)
except AttributeError:
iter_num = None
try:
num_fun_eval = jnp.array(res.nfev, base.NUM_EVAL_DTYPE)
except AttributeError:
num_fun_eval = None
info = ScipyRootInfo(fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
iter_num=iter_num,
num_fun_eval=num_fun_eval)
return base.OptStep(params, info)
def __post_init__(self):
super().__post_init__()
    if self.has_aux:
      # Capture the original function first to avoid infinite recursion once
      # `self.optimality_fun` is rebound.
      fun_with_aux = self.optimality_fun
      def optimality_fun(x, *args, **kwargs):
        return fun_with_aux(x, *args, **kwargs)[0]
      self.optimality_fun = optimality_fun
# Pre-compile useful functions.
self._jac_fun = (jax.jacrev(self.optimality_fun) if self.use_jacrev
else jax.jacfwd(self.optimality_fun))
if self.jit:
self.optimality_fun = jax.jit(self.optimality_fun)
self._jac_fun = jax.jit(self._jac_fun)
# NOTE: relative to `scipy.optimize.least_squares`, the functions below absorb
# the squaring of residuals to avoid numerical issues for the gradient of the
# Huber loss at 0.
LS_RHO_FUNS = {
'linear': lambda z: z ** 2,
'soft_l1': lambda z: 2.0 * ((1 + z ** 2) ** 0.5 - 1),
'huber': lambda z: jnp.where(z <= 1, z ** 2, 2.0 * z - 1),
'cauchy': lambda z: jnp.log1p(z ** 2),
'arctan': lambda z: jnp.arctan(z ** 2),
}
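# For instance, SciPy applies its 'soft_l1' rho to the *squared* residual
# (rho(z) = 2 * ((1 + z)**0.5 - 1) with z = r**2), whereas the functions above
# take the raw residual r directly, so LS_RHO_FUNS['soft_l1'](r) equals SciPy's
# rho evaluated at r**2.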
LS_DEFAULT_OPTIONS = {
'ftol': 1e-8, # float
'xtol': 1e-8, # float
'gtol': 1e-8, # float
'x_scale': 1.0, # Any
'f_scale': 1.0, # float
'tr_solver': None, # Optional[str]
'tr_options': {}, # Optional[Dict[str, Any]]
'max_nfev': None, # Optional[int],
}
@dataclass(eq=False)
class ScipyLeastSquares(ScipyWrapper):
"""Wraps over `scipy.optimize.least_squares` with PyTree & imp. diff support.
This solver minimizes::
0.5 * sum(loss(fun(x, *args, **kwargs) ** 2)).
This wrapper is for unconstrained minimization only.
Attributes:
fun: a smooth function of the form ``fun(x, *args, **kwargs)`` computing the
residuals from the model parameters `x`.
loss: the `loss` argument for `scipy.optimize.least_squares`. However,
arbitrary losses specified with a Callable are not yet supported.
options: additional kwargs for `scipy.optimize.least_squares`.
method: the `method` argument for `scipy.optimize.least_squares`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
has_aux: whether function `fun` outputs one (False) or more values (True).
When True it will be assumed by default that `fun(...)[0]` are the
residuals.
use_jacrev: whether to compute the Jacobian of `fun` using `jax.jacrev`
(True) or `jax.jacfwd` (False).
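
  Example:
    A minimal usage sketch (the model, data and initial parameters below are
    hypothetical)::

      def residuals(params, x, y):
        return y - params[0] * jnp.exp(-params[1] * x)

      x = jnp.linspace(0.0, 1.0, 20)
      y = 2.0 * jnp.exp(-0.5 * x)
      solver = ScipyLeastSquares(fun=residuals, loss='soft_l1')
      params, info = solver.run(jnp.ones(2), x, y)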
"""
fun: Callable = None
loss: str = 'linear'
options: Optional[Dict[str, Any]] = None
use_jacrev: bool = True
def _cost_fun(self, params, *args, **kwargs):
residuals = self.fun(params, *args, **kwargs)
# NOTE: `self._rho` includes the squaring of residuals in its definition.
losses = self._rho(residuals / self.options['f_scale'])
return 0.5 * jnp.square(self.options['f_scale']) * jnp.mean(losses)
def optimality_fun(self, sol, *args, **kwargs):
"""Optimality function mapping compatible with `@custom_root`."""
return self._grad_cost_fun(sol, *args, **kwargs)
def _run(self, init_params, bounds, *args, **kwargs):
"""Wraps `scipy.optimize.least_squares`."""
# Sets up the "JAX-SciPy" bridge.
init_output = self.fun(init_params, *args, **kwargs)
input_pytree_topology = pytree_topology_from_example(init_params)
output_pytree_topology = pytree_topology_from_example(init_output)
onp_to_jnp = make_onp_to_jnp(input_pytree_topology)
jac_jnp_to_onp = make_jac_jnp_to_onp(input_pytree_topology,
output_pytree_topology,
self.dtype)
def scipy_fun(x_onp: onp.ndarray) -> onp.ndarray:
x_jnp = onp_to_jnp(x_onp)
value_jnp = self.fun(x_jnp, *args, **kwargs)
return jnp_to_onp(value_jnp, self.dtype)
def scipy_jac(x_onp: onp.ndarray) -> onp.ndarray:
x_jnp = onp_to_jnp(x_onp)
jacs_jnp = self._jac_fun(x_jnp, *args, **kwargs)
return jac_jnp_to_onp(jacs_jnp)
if bounds is not None:
bounds = (jnp_to_onp(bounds[0], self.dtype),
jnp_to_onp(bounds[1], self.dtype))
else:
bounds = (-onp.inf, onp.inf)
res = osp.optimize.least_squares(scipy_fun,
jnp_to_onp(init_params, self.dtype),
jac=scipy_jac,
bounds=bounds,
method=self.method,
loss=self.loss,
**self.options)
params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))
info = ScipyLeastSquaresInfo(cost_val=jnp.asarray(res.cost),
fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE),
num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
error=jnp.asarray(res.optimality))
return base.OptStep(params, info)
def run(self,
init_params: Any,
*args,
**kwargs) -> base.OptStep:
"""Runs the solver.
Args:
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to `fun`.
**kwargs: additional keyword arguments to be passed to `fun`.
Returns:
(params, info).
"""
return self._run(init_params, None, *args, **kwargs)
def __post_init__(self):
super().__post_init__()
    if self.options is None:
      # Copy so that later in-place edits (e.g. of 'x_scale') do not mutate the
      # module-level defaults.
      self.options = dict(LS_DEFAULT_OPTIONS)
    else:
      for k, v in LS_DEFAULT_OPTIONS.items():
        if k not in self.options:
          self.options[k] = v
    if self.has_aux:
      # Capture the original function first to avoid infinite recursion once
      # `self.fun` is rebound.
      fun_with_aux = self.fun
      self.fun = lambda x, *args, **kwargs: fun_with_aux(x, *args, **kwargs)[0]
# Handles PyTree inputs for `x_scale` arg.
if self.options['x_scale'] != 'jac' and not isinstance(
self.options['x_scale'], float):
self.options['x_scale'] = jnp_to_onp(self.options['x_scale'], self.dtype)
# Pre-compile useful functions.
if self.loss not in LS_RHO_FUNS:
raise ValueError(f'`loss` must be one of {LS_RHO_FUNS.keys()}.')
self._rho = LS_RHO_FUNS[self.loss]
self._jac_fun = (jax.jacrev(self.fun) if self.use_jacrev
else jax.jacfwd(self.fun))
self._grad_cost_fun = jax.grad(self._cost_fun)
if self.jit:
self.fun = jax.jit(self.fun)
self._rho = jax.jit(self._rho)
self._jac_fun = jax.jit(self._jac_fun)
self._grad_cost_fun = jax.jit(self._grad_cost_fun)
@dataclass(eq=False)
class ScipyBoundedLeastSquares(ScipyLeastSquares):
"""Wraps over `scipy.optimize.least_squares` with PyTree & imp. diff support.
This solver minimizes::
0.5 * sum(loss(fun(x, *args, **kwargs) ** 2))
subject to bounds[0] <= x <= bounds[1].
This wrapper is for minimization subject to box constraints only.
Attributes:
fun: a smooth function of the form ``fun(x, *args, **kwargs)`` computing the
residuals from the model parameters `x`.
loss: the `loss` argument for `scipy.optimize.least_squares`. However,
arbitrary losses specified with a Callable are not yet supported.
options: additional kwargs for `scipy.optimize.least_squares`.
method: the `method` argument for `scipy.optimize.least_squares`.
dtype: if not None, cast all NumPy arrays to this dtype. Note that some
methods relying on FORTRAN code, such as the `L-BFGS-B` solver for
`scipy.optimize.minimize`, require casting to float64.
jit: whether to JIT-compile JAX-based values and grad evals.
implicit_diff_solve: the linear system solver to use.
has_aux: whether function `fun` outputs one (False) or more values (True).
When True it will be assumed by default that `fun(...)[0]` are the
residuals.
use_jacrev: whether to compute the Jacobian of `fun` using `jax.jacrev`
(True) or `jax.jacfwd` (False).
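
  Example:
    A minimal usage sketch with box constraints (the model, data and bounds
    below are hypothetical)::

      def residuals(params, x, y):
        return y - params[0] * jnp.exp(-params[1] * x)

      x = jnp.linspace(0.0, 1.0, 20)
      y = 2.0 * jnp.exp(-0.5 * x)
      lower, upper = jnp.zeros(2), jnp.full(2, 3.0)
      solver = ScipyBoundedLeastSquares(fun=residuals)
      params, info = solver.run(jnp.ones(2), (lower, upper), x, y)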
"""
def _fixed_point_fun(self, sol, bounds, args, kwargs):
step = tree_sub(sol, self._grad_cost_fun(sol, *args, **kwargs))
return projection.projection_box(step, bounds)
def optimality_fun(self, sol, bounds, *args, **kwargs):
"""Optimality function mapping compatible with `@custom_root`."""
fp = self._fixed_point_fun(sol, bounds, args, kwargs)
return tree_sub(fp, sol)
def run(self,
init_params: Any,
bounds: Optional[Any],
*args,
**kwargs) -> base.OptStep:
"""Runs the solver.
Args:
init_params: pytree containing the initial parameters.
bounds: an optional tuple `(lb, ub)` of pytrees with structure identical
to `init_params`, representing box constraints.
*args: additional positional arguments to be passed to `fun`.
**kwargs: additional keyword arguments to be passed to `fun`.
Returns:
(params, info).
"""
return self._run(init_params, bounds, *args, **kwargs)