.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/constrained/multiclass_linear_svm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_constrained_multiclass_linear_svm.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_constrained_multiclass_linear_svm.py:


Multiclass linear SVM (without intercept).
==========================================

This quadratic program can be solved either with OSQP or with block coordinate
descent.

Reference:
  Crammer, K. and Singer, Y., 2001. On the algorithmic implementation of
  multiclass kernel-based vector machines.
  Journal of machine learning research, 2(Dec), pp.265-292.

.. GENERATED FROM PYTHON SOURCE LINES 26-144

.. code-block:: default


    from absl import app
    from absl import flags

    import jax
    import jax.numpy as jnp

    from jaxopt import BlockCoordinateDescent
    from jaxopt import OSQP
    from jaxopt import objective
    from jaxopt import projection
    from jaxopt import prox

    from sklearn import datasets
    from sklearn import preprocessing
    from sklearn import svm


    flags.DEFINE_float("tol", 1e-5, "Tolerance of solvers.")
    flags.DEFINE_float("l2reg", 1000.,
                       "Regularization parameter. Must be positive.")
    flags.DEFINE_integer("num_samples", 20, "Size of train set.")
    flags.DEFINE_integer("num_features", 5, "Features dimension.")
    flags.DEFINE_integer("num_classes", 3, "Number of classes.")
    flags.DEFINE_bool("verbose", False, "Verbosity.")
    FLAGS = flags.FLAGS


    def multiclass_linear_svm_skl(X, y, l2reg):
        """Fit the reference solution with ``sklearn.svm.LinearSVC``.

        Args:
          X: feature matrix passed to ``LinearSVC.fit``.
          y: class labels passed to ``LinearSVC.fit`` (not one-hot encoded).
          l2reg: regularization strength; scikit-learn's ``C`` is set to
            ``1.0 / l2reg``.

        Returns:
          The transpose of the fitted ``coef_`` matrix.
        """
        print("Solve multiclass SVM with sklearn.svm.LinearSVC:")
        svc = svm.LinearSVC(loss="hinge", dual=True,
                            multi_class="crammer_singer",
                            C=1.0 / l2reg, fit_intercept=False,
                            tol=FLAGS.tol, max_iter=100*1000).fit(X, y)
        return svc.coef_.T


    def multiclass_linear_svm_bcd(X, Y, l2reg):
        """Solve the multiclass SVM dual with block coordinate descent.

        Args:
          X: feature matrix; ``X.shape[0]`` is used as the number of rows of
            the dual variable.
          Y: label matrix; ``Y.shape[-1]`` is used as the number of columns of
            the dual variable.
          l2reg: regularization strength forwarded to the dual objective.

        Returns:
          The dual variable found by the solver (``sol.params``), with shape
          ``(X.shape[0], Y.shape[-1])``.
        """
        print("Block coordinate descent solution:")

        # Set up parameters.
        block_prox = prox.make_prox_from_projection(
            projection.projection_simplex)
        fun = objective.multiclass_linear_svm_dual
        data = (X, Y)
        # Start from the uniform value 1 / Y.shape[-1] in every entry.
        beta_init = jnp.ones((X.shape[0], Y.shape[-1])) / Y.shape[-1]

        # Run solver.
        bcd = BlockCoordinateDescent(fun=fun, block_prox=block_prox,
                                     maxiter=10*1000, tol=FLAGS.tol)
        # Use the `l2reg` argument (not FLAGS.l2reg) so that the function
        # honours the value supplied by its caller.
        sol = bcd.run(beta_init, hyperparams_prox=None, l2reg=l2reg, data=data)
        return sol.params


    def multiclass_linear_svm_osqp(X, Y, l2reg):
        """Solve the multiclass SVM dual as a quadratic program with OSQP.

        Args:
          X: feature matrix.
          Y: label matrix; the dual variable has the same shape as ``Y``.
          l2reg: regularization strength; scales the quadratic and linear
            terms by ``1. / l2reg``.

        Returns:
          The primal variable of the OSQP solution (``sol.primal``), i.e. the
          dual variable ``beta`` of the SVM problem.
        """
        # We solve the problem
        #
        #   minimize 0.5/l2reg beta X X.T beta - (1. - Y)^T beta
        #            - 1./l2reg (Y^T X) X^T beta
        #   under        beta >= 0
        #         sum_i beta_i = 1
        #
        print("OSQP solution:")

        def matvec_Q(X, beta):
            # Product with the quadratic term, without forming X X^T.
            return 1./l2reg * jnp.dot(X, jnp.dot(X.T, beta))

        linear_part = - (1. - Y) - 1./l2reg * jnp.dot(X, jnp.dot(X.T, Y))

        def matvec_A(_, beta):
            # Equality constraint operator: sum of beta over the last axis.
            return jnp.sum(beta, axis=-1)

        def matvec_G(_, beta):
            # Inequality constraint operator: -beta <= 0, i.e. beta >= 0.
            return -beta

        b = jnp.ones(X.shape[0])
        h = jnp.zeros_like(Y)

        osqp = OSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, matvec_G=matvec_G,
                    tol=FLAGS.tol, maxiter=10*1000)
        # The first element of each pair is handed to the matching matvec;
        # matvec_A and matvec_G ignore it, hence None.
        hyper_params = dict(params_obj=(X, linear_part),
                            params_eq=(None, b),
                            params_ineq=(None, h))

        # The second element returned by `run` is not needed here.
        sol, _ = osqp.run(init_params=None, **hyper_params)
        return sol.primal


    def main(argv):
        """Generate a synthetic dataset and print the weights of each solver."""
        del argv

        # Generate data.
        num_samples = FLAGS.num_samples
        num_features = FLAGS.num_features
        num_classes = FLAGS.num_classes

        X, y = datasets.make_classification(n_samples=num_samples,
                                            n_features=num_features,
                                            n_informative=3,
                                            n_classes=num_classes,
                                            random_state=0)
        X = preprocessing.Normalizer().fit_transform(X)
        Y = preprocessing.LabelBinarizer().fit_transform(y)
        Y = jnp.array(Y)

        l2reg = FLAGS.l2reg

        # Compare against sklearn. The OSQP and BCD solvers return the dual
        # variable, which is mapped to weights via X^T (Y - beta) / l2reg.
        W_osqp = multiclass_linear_svm_osqp(X, Y, l2reg)
        W_fit_osqp = jnp.dot(X.T, (Y - W_osqp)) / l2reg
        print(W_fit_osqp)
        print()

        W_bcd = multiclass_linear_svm_bcd(X, Y, l2reg)
        W_fit_bcd = jnp.dot(X.T, (Y - W_bcd)) / l2reg
        print(W_fit_bcd)
        print()

        W_skl = multiclass_linear_svm_skl(X, y, l2reg)
        print(W_skl)
        print()


    if __name__ == "__main__":
        jax.config.update("jax_platform_name", "cpu")
        app.run(main)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_constrained_multiclass_linear_svm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: multiclass_linear_svm.py <multiclass_linear_svm.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: multiclass_linear_svm.ipynb <multiclass_linear_svm.ipynb>`


.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_