Binary kernel SVM with intercept.

The dual problem of a binary kernel SVM with an intercept has both box constraints and an equality constraint, which makes it challenging to solve. The state-of-the-art algorithm for this problem is SMO (Sequential Minimal Optimization).

Nevertheless, this example demonstrates how to solve it with projected gradient descent, projecting onto the constraint set with `projection_box_section`.

Since the dual problem is a quadratic program (QP), we also show how to solve it with `BoxOSQP`. Both solutions are compared with the dual coefficients found by scikit-learn's `SVC`.

from absl import app
from absl import flags

import jax
import jax.numpy as jnp
from jaxopt import projection
from jaxopt import ProjectedGradient
from jaxopt import BoxOSQP

import numpy as onp
from sklearn import datasets
from sklearn import preprocessing
from sklearn import svm


# Command-line flags of the example.
# `lam` is the regularization strength; main() turns it into the dual box
# bound C = 1 / lam, hence the requirement that it be strictly positive.
flags.DEFINE_float("lam", 0.5, "Regularization parameter. Must be positive.")
# Stopping tolerance shared by the three solvers (sklearn SVC,
# ProjectedGradient and BoxOSQP).
flags.DEFINE_float("tol", 1e-6, "Tolerance of solvers.")
# Size of the synthetic dataset built with sklearn's make_classification.
flags.DEFINE_integer("num_samples", 30, "Size of train set.")
flags.DEFINE_integer("num_features", 5, "Features dimension.")
# Only forwarded to the ProjectedGradient solver.
flags.DEFINE_bool("verbose", False, "Verbosity.")
FLAGS = flags.FLAGS


def binary_kernel_svm_skl(X, y, C):
  """Reference solution: sklearn's SVC on a precomputed linear kernel.

  Args:
    X: design matrix, one row per training sample.
    y: binary labels (main() passes values in {-1., 1.}).
    C: upper bound on the magnitude of the dual coefficients (C = 1 / lam).

  Returns:
    A dense NumPy vector with one dual coefficient per training sample.
    Entries of samples that are not support vectors are 0.
  """
  print("Solve SVM with sklearn.svm.SVC: ")
  gram = jnp.dot(X, X.T)
  classifier = svm.SVC(kernel="precomputed", C=C, tol=FLAGS.tol)
  classifier.fit(gram, y)
  # SVC only stores coefficients for its support vectors (already multiplied
  # by their labels); scatter them back into a vector aligned with X's rows.
  coefs = onp.zeros(gram.shape[0])
  coefs[classifier.support_] = classifier.dual_coef_[0]
  return coefs


def binary_kernel_svm_pg(X, y, C):
  """Solve the SVM dual with projected gradient descent.

  The dual problem is

      min_beta  0.5 beta^T K beta - beta^T y,   with K = X X^T,
      s.t.      sum(beta) = 0,
                 0 <= beta_i <= C   if y_i = 1,
                -C <= beta_i <= 0   if y_i = -1.

  The feasible set is a box intersected with a hyperplane, which is exactly
  what `projection.projection_box_section` projects onto.

  Args:
    X: design matrix, one row per training sample.
    y: labels compared against 1 to pick each coordinate's box.
    C: box bound (C = 1 / lam).

  Returns:
    The iterate returned by ProjectedGradient after at most 500 iterations.
  """
  print("Solve SVM with Projected Gradient: ")

  n_samples = X.shape[0]

  def dual_objective(beta, X, y):
    # K beta is computed as X (X^T beta), so the n x n Gram matrix is never
    # formed explicitly.
    k_beta = jnp.dot(X, jnp.dot(X.T, beta))
    return 0.5 * jnp.dot(beta, k_beta) - jnp.dot(beta, y)

  # Normal vector of the hyperplane <ones, beta> = 0.
  section_weights = jnp.ones(n_samples)

  def project(beta, C):
    # Label-dependent box: [0, C] for positive samples, [-C, 0] otherwise.
    lower = jnp.where(y == 1, 0, -C)
    upper = jnp.where(y == 1, C, 0)
    return projection.projection_box_section(
        beta, (lower, upper, section_weights, 0.0))

  pg_solver = ProjectedGradient(fun=dual_objective,
                                projection=project,
                                tol=FLAGS.tol,
                                maxiter=500,
                                verbose=FLAGS.verbose)
  # NOTE: the all-ones start does not satisfy sum(beta) = 0; it is kept as
  # the starting point of the original example.
  beta0 = jnp.ones(n_samples)
  pg_result = pg_solver.run(beta0, hyperparams_proj=C, X=X, y=y)
  return pg_result.params


def binary_kernel_svm_osqp(X, y, C):
  """Solve the SVM dual as a quadratic program with BoxOSQP.

  The dual problem is

      min_beta  0.5 beta^T K beta - beta^T y,   with K = X X^T,
      s.t.      sum(beta) = 0,
                 0 <= beta_i <= C   if y_i = 1,
                -C <= beta_i <= 0   if y_i = -1.

  It is handed to BoxOSQP with Q = X X^T and linear term -y. The constraint
  operator A maps beta to the pair (beta, sum(beta)). The first component
  carries the per-coordinate box and the second the equality sum(beta) = 0,
  encoded as lower = upper = 0. Neither Q nor A is materialized; both are
  given as matvecs.

  Args:
    X: design matrix, one row per training sample.
    y: labels whose sign selects each coordinate's box.
    C: box bound (C = 1 / lam).

  Returns:
    The primal solution beta found by BoxOSQP.
  """
  print("Solve SVM with OSQP: ")

  def gram_matvec(X, beta):
    return jnp.dot(X, jnp.dot(X.T, beta))

  def constraint_matvec(_, beta):
    # Identity block (box constraints) and all-ones row (equality).
    return beta, jnp.sum(beta)

  # For y_i = 1 these give [0, C]; for y_i = -1 they give [-C, 0].
  box_lower = -jax.nn.relu(-y * C)
  box_upper = jax.nn.relu(y * C)
  # Bounds must have the same (box, sum) structure as constraint_matvec's
  # output; the sum is pinned to exactly 0.
  lower = (box_lower, 0.)
  upper = (box_upper, 0.)

  qp_solver = BoxOSQP(matvec_Q=gram_matvec, matvec_A=constraint_matvec,
                      tol=FLAGS.tol)
  solution = qp_solver.run(init_params=None,
                           params_obj=(X, -y),
                           params_eq=None,
                           params_ineq=(lower, upper)).params
  # The first entry of the primal pair holds the dual coefficients beta.
  return solution.primal[0]


def print_svm_result(beta, threshold=1e-4):
  """Print dual coefficients followed by the indices of the support vectors.

  Args:
    beta: signed dual coefficients, one per training sample.
    threshold: a sample is reported as a support vector when |beta_i| is
      strictly greater than this value.
  """
  # Each coefficient carries the sign of its sample's label, so support
  # vectors are detected on the magnitude.
  support_mask = jnp.abs(beta) > threshold
  support_indices = onp.asarray(support_mask).nonzero()[0]
  print(f"Beta: {beta}")
  print(f"Support vector indices: {support_indices}")
  print()


def main(argv):
  """Solve the same SVM dual with three solvers and print the coefficients.

  The solvers are BoxOSQP, projected gradient and sklearn's SVC, run on a
  synthetic, row-normalized binary classification dataset.

  Raises:
    app.UsageError: if --lam is not strictly positive.
  """
  del argv  # Unused.

  lam = FLAGS.lam
  # The dual box bound is C = 1 / lam. lam == 0 would divide by zero, and a
  # negative lam gives C < 0, i.e. boxes whose lower bound exceeds the upper
  # bound (an empty feasible set). Reject both before doing any work.
  if lam <= 0:
    raise app.UsageError(f"--lam must be positive, got {lam}.")
  C = 1. / lam

  # Prepare data.
  X, y = datasets.make_classification(n_samples=FLAGS.num_samples,
                                      n_features=FLAGS.num_features,
                                      n_classes=2,
                                      random_state=0)
  X = preprocessing.Normalizer().fit_transform(X)
  y = jnp.array(y * 2. - 1)  # Transform labels from {0, 1} to {-1., 1.}.

  # Compare the obtained dual coefficients.
  beta_fit_osqp = binary_kernel_svm_osqp(X, y, C)
  print_svm_result(beta_fit_osqp)

  beta_fit_pg = binary_kernel_svm_pg(X, y, C)
  print_svm_result(beta_fit_pg)

  beta_fit_skl = binary_kernel_svm_skl(X, y, C)
  print_svm_result(beta_fit_skl)


if __name__ == "__main__":
  # Pin JAX to the CPU backend for this example, then hand control to absl,
  # which parses the command-line flags and calls main().
  jax.config.update("jax_platform_name", "cpu")
  app.run(main)

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery