Source code for jaxopt._src.objective

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common objective functions."""

from typing import Tuple

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import loss


class CompositeLinearFunction:
  """A base class to represent composite linear functions.

  These are functions of the form::

    fun(params, *args, **kwargs) =
      subfun(linop(params), *args, **kwargs) + vdot(params, b(*args, **kwargs))

  where ``linop = make_linop(*args, **kwargs)``.
  """

  def b(self, *args, **kwargs):
    """Linear term in the function."""
    return None

  def lipschitz_const(self, hyperparams):
    """Lipschitz-constant of subfun."""
    raise NotImplementedError

  def subfun(self, predictions, *args, **kwargs):
    """To be implemented by the child class."""
    raise NotImplementedError

  def __call__(self, params, *args, **kwargs):
    linop = self.make_linop(*args, **kwargs)
    predictions = linop.matvec(params)
    ret = self.subfun(predictions, *args, **kwargs)
    b = self.b(*args, **kwargs)
    if b is not None:
      ret += jnp.vdot(params, b)
    return ret
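

# Illustrative sketch (not part of the library): a minimal subclass showing how
# ``CompositeLinearFunction`` fits together.  ``subfun`` receives
# ``linop.matvec(params)`` and the optional ``b`` term is added via ``vdot``.
# The class name and the ``data = (A, c)`` convention below are assumptions
# made for this example only.
class _ExampleHalfSumSquares(CompositeLinearFunction):
  """0.5 * ||A params||^2 + <params, c>, with ``data = (A, c)``."""

  def subfun(self, predictions, data):
    return 0.5 * jnp.sum(predictions ** 2)

  def make_linop(self, data):
    return base.LinearOperator(data[0])

  def b(self, data):
    return data[1]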


class LeastSquares(CompositeLinearFunction):
  """Least squares."""

  def subfun(self, predictions, data):
    y = data[1]
    residuals = predictions - y
    return 0.5 * jnp.mean(residuals ** 2)

  def make_linop(self, data):
    """Creates linear operator."""
    return base.LinearOperator(data[0])

  def columnwise_lipschitz_const(self, data):
    """Column-wise Lipschitz constants."""
    linop = self.make_linop(data)
    return linop.column_l2_norms(squared=True) * 1.0

  def __call__(self, W, data):
    r"""Least squares.

    .. math::

      \frac{1}{2n} ||XW - y||_2^2

    Args:
      W: parameters.
      data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
        n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
    Returns:
      objective value.

    Example::

      value = least_squares(W, (X, y))
    """
    return super().__call__(W, data)


least_squares = LeastSquares()
least_squares.__doc__ = least_squares.__call__.__doc__
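

# Illustrative usage sketch (toy shapes assumed for the example): evaluating
# the least-squares objective and its gradient with respect to the parameters.
def _example_least_squares_usage():
  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (n_samples, n_features)
  y = jnp.array([1.0, 2.0, 3.0])                        # (n_samples,)
  W = jnp.zeros(2)                                      # (n_features,)
  value = least_squares(W, (X, y))           # 0.5 * mean((X @ W - y) ** 2)
  grad = jax.grad(least_squares)(W, (X, y))  # gradient w.r.t. W
  return value, grad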


def ridge_regression(params: jnp.ndarray,
                     l2reg: float,
                     data: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
  r"""Ridge regression, i.e., L2-regularized least squares.

  .. math::

    \frac{1}{2n} ||XW - y||_2^2 + 0.5 \cdot \text{l2reg} \cdot ||W||_2^2

  Args:
    params: parameters.
    l2reg: strength of the regularization.
    data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
      n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
  Returns:
    objective value.

  Example::

    value = ridge_regression(params, l2reg, (X, y))
  """
  data_fit = LeastSquares()(params, data)
  ridge = 0.5 * l2reg * jnp.sum(params ** 2)
  return data_fit + ridge
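

# Illustrative sketch (assumed toy data): ``ridge_regression`` is the
# least-squares objective plus an explicit L2 penalty on the parameters.
def _example_ridge_regression_usage():
  X = jnp.array([[1.0, 0.0], [0.0, 1.0]])
  y = jnp.array([1.0, -1.0])
  W = jnp.array([0.5, 0.5])
  l2reg = 0.1
  value = ridge_regression(W, l2reg, (X, y))
  # Equivalent decomposition: data-fit term + 0.5 * l2reg * ||W||_2^2.
  manual = least_squares(W, (X, y)) + 0.5 * l2reg * jnp.sum(W ** 2)
  return value, manual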


_logloss_vmap = jax.vmap(loss.multiclass_logistic_loss)


class MulticlassLogreg(CompositeLinearFunction):
  """Multiclass logistic regression objective."""

  def subfun(self, predictions, data):
    y = data[1]
    return jnp.mean(_logloss_vmap(y, predictions))

  def make_linop(self, data):
    """Creates linear operator."""
    return base.LinearOperator(data[0])

  def columnwise_lipschitz_const(self, data):
    """Column-wise Lipschitz constants."""
    linop = self.make_linop(data)
    return linop.column_l2_norms(squared=True) * 0.5

  def __call__(self, W, data):
    r"""Multiclass logistic regression.

    .. math::

      \frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i, y_i)

    where :math:`\ell` is :func:`multiclass_logistic_loss
    <jaxopt.loss.multiclass_logistic_loss>` and ``X, y = data``.

    Args:
      W: a matrix of shape ``(n_features, n_classes)``.
      data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
        n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
    Returns:
      objective value.

    Example::

      value = multiclass_logreg(W, (X, y))
    """
    return super().__call__(W, data)


multiclass_logreg = MulticlassLogreg()
multiclass_logreg.__doc__ = multiclass_logreg.__call__.__doc__
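

# Illustrative usage sketch (assumed shapes): the multiclass logistic
# regression objective takes a ``(n_features, n_classes)`` weight matrix and
# integer class labels.
def _example_multiclass_logreg_usage():
  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (n_samples, n_features)
  y = jnp.array([0, 2, 1])                              # labels in {0, ..., n_classes - 1}
  W = jnp.zeros((2, 3))                                 # (n_features, n_classes)
  value = multiclass_logreg(W, (X, y))
  grad = jax.grad(multiclass_logreg)(W, (X, y))
  return value, grad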


def multiclass_logreg_with_intercept(
    params: Tuple[jnp.ndarray, jnp.ndarray],
    data: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
  r"""Multiclass logistic regression with intercept.

  .. math::

    \frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i + b, y_i)

  where :math:`\ell` is :func:`multiclass_logistic_loss
  <jaxopt.loss.multiclass_logistic_loss>`, ``W, b = params`` and
  ``X, y = data``.

  Args:
    params: a tuple ``(W, b)``, where ``W`` is a matrix of shape
      ``(n_features, n_classes)`` and ``b`` is a vector of shape
      ``(n_classes,)``.
    data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
      n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
  Returns:
    objective value.
  """
  X, y = data
  W, b = params
  y_pred = jnp.dot(X, W) + b
  return jnp.mean(_logloss_vmap(y, y_pred))
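

# Illustrative sketch (assumed shapes): the intercept variant packs the
# parameters as a pytree ``(W, b)``, so ``jax.grad`` returns a matching tuple.
def _example_multiclass_logreg_with_intercept_usage():
  X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
  y = jnp.array([0, 1])
  params = (jnp.zeros((2, 2)), jnp.zeros(2))  # (W, b)
  value = multiclass_logreg_with_intercept(params, (X, y))
  grad_W, grad_b = jax.grad(multiclass_logreg_with_intercept)(params, (X, y))
  return value, grad_W, grad_b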


def l2_multiclass_logreg(W: jnp.ndarray,
                         l2reg: float,
                         data: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
  r"""L2-regularized multiclass logistic regression.

  .. math::

    \frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i, y_i)
    + 0.5 \cdot \text{l2reg} \cdot ||W||_2^2

  where :math:`\ell` is :func:`multiclass_logistic_loss
  <jaxopt.loss.multiclass_logistic_loss>` and ``X, y = data``.

  Args:
    W: a matrix of shape ``(n_features, n_classes)``.
    l2reg: strength of the L2 regularization.
    data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
      n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
  Returns:
    objective value.
  """
  X, y = data
  y_pred = jnp.dot(X, W)
  return jnp.mean(_logloss_vmap(y, y_pred)) + 0.5 * l2reg * jnp.sum(W ** 2)


def l2_multiclass_logreg_with_intercept(
    params: Tuple[jnp.ndarray, jnp.ndarray],
    l2reg: float,
    data: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
  r"""L2-regularized multiclass logistic regression with intercept.

  .. math::

    \frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i + b, y_i)
    + 0.5 \cdot \text{l2reg} \cdot ||W||_2^2

  where :math:`\ell` is :func:`multiclass_logistic_loss
  <jaxopt.loss.multiclass_logistic_loss>`, ``W, b = params`` and
  ``X, y = data``.

  Args:
    params: a tuple ``(W, b)``, where ``W`` is a matrix of shape
      ``(n_features, n_classes)`` and ``b`` is a vector of shape
      ``(n_classes,)``.
    l2reg: strength of the L2 regularization (applied to ``W`` only).
    data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
      n_features)`` and ``y`` is a vector of shape ``(n_samples,)``.
  Returns:
    objective value.
  """
  X, y = data
  W, b = params
  y_pred = jnp.dot(X, W) + b
  return jnp.mean(_logloss_vmap(y, y_pred)) + 0.5 * l2reg * jnp.sum(W ** 2)
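

# Illustrative sketch (assumed toy data): the L2-regularized variants add
# 0.5 * l2reg * ||W||_2^2 to the data-fit term; the intercept ``b`` is left
# unregularized.
def _example_l2_multiclass_logreg_usage():
  X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
  y = jnp.array([0, 1])
  l2reg = 0.01
  W = jnp.zeros((2, 2))
  value = l2_multiclass_logreg(W, l2reg, (X, y))
  value_with_b = l2_multiclass_logreg_with_intercept((W, jnp.zeros(2)), l2reg,
                                                     (X, y))
  return value, value_with_b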


_binary_logloss_vmap = jax.vmap(loss.binary_logistic_loss)


class BinaryLogreg(CompositeLinearFunction):
  """Binary logistic regression objective."""

  def subfun(self, predictions, data):
    y = data[1]
    return jnp.mean(_binary_logloss_vmap(y, predictions))

  def make_linop(self, data):
    """Creates linear operator."""
    return base.LinearOperator(data[0])

  def columnwise_lipschitz_const(self, data):
    """Column-wise Lipschitz constants."""
    linop = self.make_linop(data)
    return linop.column_l2_norms(squared=True) * 0.25

  def __call__(self, w, data):
    r"""Binary logistic regression.

    .. math::

      \frac{1}{n} \sum_{i=1}^n \ell(w^\top x_i, y_i)

    where :math:`\ell` is :func:`binary_logistic_loss
    <jaxopt.loss.binary_logistic_loss>` and ``X, y = data``.

    Args:
      w: a vector of shape ``(n_features,)``.
      data: a tuple ``(X, y)`` where ``X`` is a matrix of shape ``(n_samples,
        n_features)`` and ``y`` is a vector of shape ``(n_samples,)``,
        containing ``0`` or ``1`` values.
    Returns:
      objective value.

    Example::

      value = binary_logreg(w, (X, y))
    """
    return super().__call__(w, data)


binary_logreg = BinaryLogreg()
binary_logreg.__doc__ = binary_logreg.__call__.__doc__


class MulticlassLinearSvmDual(CompositeLinearFunction):
  """Dual objective function of multiclass linear SVMs."""

  def subfun(self, Xbeta, l2reg, data):
    X, Y = data
    XY = jnp.dot(X.T, Y)  # TODO: avoid storing / computing this matrix.
    # The dual objective is:
    #   fun(beta) = vdot(beta, 1 - Y) - 0.5 / l2reg * ||V(beta)||^2
    # where V(beta) = dot(X.T, Y) - dot(X.T, beta).
    V = XY - Xbeta
    # With opposite sign, as we want to maximize.
    return 0.5 / l2reg * jnp.vdot(V, V)

  def make_linop(self, l2reg, data):
    """Creates linear operator."""
    return base.LinearOperator(data[0].T)

  def columnwise_lipschitz_const(self, l2reg, data):
    """Column-wise Lipschitz constants."""
    linop = self.make_linop(l2reg, data)
    return linop.column_l2_norms(squared=True)

  def b(self, l2reg, data):
    return data[1] - 1


multiclass_linear_svm_dual = MulticlassLinearSvmDual()
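

# Illustrative usage sketch (assumed shapes): ``binary_logreg`` expects a
# weight vector and labels in {0, 1}; ``multiclass_linear_svm_dual`` expects
# dual variables ``beta`` and one-hot labels ``Y`` of shape
# ``(n_samples, n_classes)``.
def _example_binary_logreg_and_svm_dual_usage():
  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (n_samples, n_features)
  y = jnp.array([0.0, 1.0, 1.0])                        # binary labels in {0, 1}
  w = jnp.zeros(2)
  logreg_value = binary_logreg(w, (X, y))

  Y = jax.nn.one_hot(jnp.array([0, 1, 1]), 2)  # (n_samples, n_classes)
  beta = jnp.full_like(Y, 0.5)                 # dual variables
  svm_value = multiclass_linear_svm_dual(beta, 1.0, (X, Y))
  return logreg_value, svm_value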