# Source code for jaxopt._src.isotonic

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Isotonic Regression."""

import warnings
import numpy as onp
import jax
import jax.numpy as jnp


# pylint: disable=g-import-not-at-top
try:
  from numba import njit
except ImportError:
  # Numba is an optional dependency.  Without it, `njit` becomes a
  # pass-through decorator so decorated functions still run as plain Python.
  NUMBA_AVAILABLE = False

  def njit(func):
    """Stand-in for `numba.njit` that returns `func` unchanged."""
    return func

else:
  NUMBA_AVAILABLE = True


@njit
def _isotonic_l2_pav_numba(y):
  """Runs the pool-adjacent-violators (PAV) algorithm on a 1d array.

  Two neighbouring blocks are left apart only when the value of the left block
  is strictly greater than the value of the right block; otherwise they are
  pooled and replaced by their mean.  The returned values are therefore
  non-increasing.

  Args:
    y: 1d array of values to fit.

  Returns:
    An array with one fitted value per entry of ``y``, cast to ``y.dtype``.
  """
  n = y.shape[0]
  # Every element starts out as its own block.
  target = onp.arange(n)
  # c[i]: number of elements in the block that starts at index i.
  c = onp.ones(n)
  # sums[i]: sum of the inputs in the block that starts at index i.
  sums = onp.zeros(n)
  # sol[i]: value (mean) of the block that starts at index i.
  sol = onp.zeros(n)

  # target describes a list of blocks.  At any time, if [i..j] (inclusive) is
  # an active block, then target[i] := j and target[j] := i.

  for i in range(n):
    sol[i] = y[i]
    sums[i] = y[i]

  i = 0
  while i < n:
    # k is the first index of the block that follows the block starting at i.
    k = target[i] + 1
    if k == n:
      break
    if sol[i] > sol[k]:
      # The two blocks are already in strictly decreasing order: move on.
      i = k
      continue
    # Order violated (left value <= right value): start pooling from block i.
    sum_y = sums[i]
    sum_c = c[i]
    while True:
      # We are within an increasing subsequence.
      prev_y = sol[k]
      sum_y += sums[k]
      sum_c += c[k]
      k = target[k] + 1
      if k == n or prev_y > sol[k]:
        # Non-singleton increasing subsequence is finished,
        # update first entry.
        sol[i] = sum_y / sum_c
        sums[i] = sum_y
        c[i] = sum_c
        # Record the merged block [i..k-1] at both of its ends.
        target[i] = k - 1
        target[k - 1] = i
        if i > 0:
          # Backtrack if we can.  This makes the algorithm
          # single-pass and ensures O(n) complexity.
          # (i - 1 is the last index of the previous block, so target[i - 1]
          # is that block's first index.)
          i = target[i - 1]
        # Otherwise, restart from the same point.
        break

  # Reconstruct the solution.
  # Only the first index of each block holds its value so far; copy it to the
  # remaining indices of the block.
  i = 0
  while i < n:
    k = target[i] + 1
    sol[i + 1 : k] = sol[i]
    i = k
  return sol.astype(y.dtype)


def _isotonic_l2_pav_callback(y):
  """Callback target for `jax.pure_callback`: runs PAV on a NumPy array.

  Recent JAX releases hand `jax.Array` operands (rather than `numpy.ndarray`)
  to callback functions, which a Numba-compiled function cannot type.
  `onp.asarray` is a no-op for NumPy inputs, so older releases are unaffected.

  Args:
    y: 1d array-like received from the callback machinery.

  Returns:
    The output of `_isotonic_l2_pav_numba` for `y`.
  """
  return _isotonic_l2_pav_numba(onp.asarray(y))


@jax.custom_jvp
def _isotonic_l2_pav(y):
  """Calls the PAV solver on `y` through `jax.pure_callback`.

  The function is wrapped in `jax.custom_jvp`; its derivative rule is
  registered separately with `_isotonic_l2_pav.defjvp`.

  Args:
    y: 1d array passed to the solver.

  Returns:
    An array with the same shape and dtype as `y`.
  """
  if not NUMBA_AVAILABLE:
    warnings.warn(
        "Numba could not be imported. Code will run much more slowly."
        " To install, run 'pip install numba'."
    )
  # Define the expected shape & dtype of output.
  shape_dtype = jax.ShapeDtypeStruct(shape=y.shape, dtype=y.dtype)
  sol = jax.pure_callback(_isotonic_l2_pav_callback, shape_dtype, y)
  return sol

def isotonic_l2_pav(y, y_min=-jnp.inf, y_max=jnp.inf, increasing=True):
  r"""Solves an isotonic regression problem using PAV.

  Args:
    y: input to isotonic regression, a 1d-array.
    y_min: Lower bound on the lowest predicted value.
    y_max: Upper bound on the highest predicted value.
    increasing: Order of the constraints:
      If True, it solves
      :math:`\mathop{\mathrm{arg\,min}}_{v_1 \leq ... \leq v_n} \|v - y\|^2`.
      If False, it solves
      :math:`\mathop{\mathrm{arg\,min}}_{v_1 \geq ... \geq v_n} \|v - y\|^2`.

  Returns:
    The solution, an array of the same size as y.
  """
  # The underlying solver produces a non-increasing fit, so the increasing
  # problem is solved by negating both its input and its output.
  sign = -1 if increasing else 1
  sol = _isotonic_l2_pav(y * sign) * sign
  # Enforce the bounds by clipping the fitted values.
  sol = jnp.clip(sol, y_min, y_max)
  return sol
def _jvp_isotonic_l2_jax_pav(solution, vector, eps=1e-8):
  """Applies the Jacobian of the isotonic regression solution to `vector`.

  Consecutive entries of `solution` whose absolute difference is at most `eps`
  are treated as one block.  Every entry of the output is the mean of `vector`
  over the block that entry belongs to.  Block sums are computed with
  `jax.ops.segment_sum`, so memory use is linear in the input length.

  Args:
    solution: 1d array, the output of the isotonic regression.
    vector: 1d array of the same length as `solution`.
    eps: tolerance under which two neighbouring entries of `solution` are
      considered equal.

  Returns:
    A 1d array of the same length as `vector`.
  """
  x = solution
  n = x.size
  # same_block[i] is True when x[i] belongs to the same block as x[i - 1];
  # the padded first entry is False, so index 0 always opens a block.
  same_block = jnp.pad(jnp.absolute(jnp.diff(x)) <= eps, (1, 0))
  # Number the blocks 0, 1, ... from left to right.
  block_id = jnp.cumsum(~same_block) - 1
  # At most n blocks exist, which gives a static segment count.
  block_sum = jax.ops.segment_sum(vector, block_id, num_segments=n)
  block_size = jax.ops.segment_sum(jnp.ones_like(x), block_id, num_segments=n)
  # Unused segment slots have size 0; clamp them so the division stays finite.
  block_mean = block_sum / jnp.maximum(block_size, 1)
  return block_mean[block_id]


@_isotonic_l2_pav.defjvp
def _isotonic_l2_pav_jvp(primals, tangents):
  """Jacobian-vector product of isotonic_l2_pav.

  See Section 5 of Fast Differentiable Sorting and Ranking
  Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
  ICML 2020
  arXiv:2002.08871

  Args:
    primals: 1-tuple holding the input `y`.
    tangents: 1-tuple holding the tangent associated with `y`.

  Returns:
    A pair `(primal_out, tangent_out)`.
  """
  (y,) = primals
  (vector,) = tangents

  primal_out = _isotonic_l2_pav(y)
  tangent_out = _jvp_isotonic_l2_jax_pav(primal_out, vector)
  return primal_out, tangent_out