Loss and objective functions

Loss functions

Regression

jaxopt.loss.huber_loss(target, pred[, delta])

Huber loss.

Regression losses are of the form loss(float: target, float: pred) -> float, where target is the ground truth and pred is the model’s output.
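The Huber loss is quadratic for residuals up to a threshold and linear beyond it, which makes it less sensitive to outliers than the squared loss. The sketch below writes the textbook definition in jax.numpy. It is not jaxopt's source: the parameter name delta matches the signature above, but its default of 1.0 is an assumption.

```python
import jax.numpy as jnp


def huber_loss(target, pred, delta=1.0):
    """Sketch of the Huber loss (default delta=1.0 is an assumption)."""
    abs_diff = jnp.abs(target - pred)
    # Quadratic branch for small residuals, linear branch beyond delta;
    # both branches agree in value and slope at abs_diff == delta.
    return jnp.where(abs_diff > delta,
                     delta * (abs_diff - 0.5 * delta),
                     0.5 * abs_diff ** 2)


print(huber_loss(0.0, 0.5))  # small residual: 0.5 * 0.5**2
print(huber_loss(0.0, 3.0))  # large residual: 1.0 * (3.0 - 0.5)
```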

Binary classification

jaxopt.loss.binary_logistic_loss(label, logit)

Binary logistic loss.

jaxopt.loss.binary_sparsemax_loss(label, logit)

Binary sparsemax loss.

jaxopt.loss.binary_hinge_loss(label, score)

Binary hinge loss.

jaxopt.loss.binary_perceptron_loss(label, score)

Binary perceptron loss.

Binary classification losses are of the form loss(int: label, float: score) -> float, where label is the ground truth (0 or 1) and score is the model’s output.
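Under the {0, 1} label convention above, the logistic, hinge, and perceptron losses can be sketched in a few lines of jax.numpy. These are standalone reimplementations of the standard formulas, not jaxopt's code. Mapping labels to {-1, +1} inside the hinge and perceptron sketches is our assumption about how the convention is handled.

```python
import jax
import jax.numpy as jnp


def binary_logistic_loss(label, logit):
    # label 1: softplus(-logit) = -log sigmoid(logit)
    # label 0: softplus(logit)  = -log(1 - sigmoid(logit))
    return jax.nn.softplus(jnp.where(label == 1, -logit, logit))


def binary_hinge_loss(label, score):
    # Map {0, 1} to {-1, +1}, then penalize margins below 1.
    signed_label = 2.0 * label - 1.0
    return jnp.maximum(0.0, 1.0 - signed_label * score)


def binary_perceptron_loss(label, score):
    # Same as the hinge loss but without the margin: only errors cost.
    signed_label = 2.0 * label - 1.0
    return jnp.maximum(0.0, -signed_label * score)


print(binary_logistic_loss(1, 0.0), binary_hinge_loss(0, 0.5))
```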

The following utility functions are useful for the binary sparsemax loss.


Sparse plus function.


Sparse sigmoid function.
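Assuming the usual definitions, sparse plus is a piecewise-quadratic counterpart of softplus, and sparse sigmoid is its derivative. The sketch below writes both in jax.numpy. It also includes a binary sparsemax loss that is assumed, by analogy with the logistic loss, to apply sparse plus where the logistic loss applies softplus. The function names and that relationship are our assumptions, not taken from jaxopt's source.

```python
import jax
import jax.numpy as jnp


def sparse_plus(x):
    """Piecewise-quadratic counterpart of softplus (assumed definition)."""
    # 0 for x <= -1, (x + 1)**2 / 4 on (-1, 1), and x for x >= 1.
    quadratic = (x + 1.0) ** 2 / 4.0
    return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, quadratic))


def sparse_sigmoid(x):
    """Derivative of sparse_plus: (x + 1) / 2 clipped to [0, 1]."""
    return jnp.clip(0.5 * (x + 1.0), 0.0, 1.0)


def binary_sparsemax_loss(label, logit):
    # Assumed analogue of the binary logistic loss with softplus replaced by
    # sparse_plus: label 1 penalizes low logits, label 0 penalizes high ones.
    return sparse_plus(jnp.where(label == 1, -logit, logit))


# Sanity check: the gradient of sparse_plus matches sparse_sigmoid.
print(jax.grad(sparse_plus)(0.3), sparse_sigmoid(0.3))
```

Unlike softplus, sparse plus is exactly zero for inputs at or below -1, so a confidently correct prediction incurs exactly zero loss.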

Multiclass classification

jaxopt.loss.multiclass_logistic_loss(label, ...)

Multiclass logistic loss.

jaxopt.loss.multiclass_sparsemax_loss(label, ...)

Multiclass sparsemax loss.

jaxopt.loss.multiclass_hinge_loss(label, scores)

Multiclass hinge loss.


Multiclass perceptron loss.

Multiclass classification losses are of the form loss(int: label, jnp.ndarray: scores) -> float, where label is the ground truth (between 0 and n_classes - 1) and scores is an array of size n_classes.
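With an integer label and a score vector, the multiclass logistic loss is the softmax cross-entropy. A common multiclass hinge loss is the Crammer-Singer form. The sketch below implements both in jax.numpy. Treating jaxopt's multiclass_hinge_loss as the Crammer-Singer variant is an assumption, and neither function is jaxopt's own code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multiclass_logistic_loss(label, scores):
    # Softmax cross-entropy: -log softmax(scores)[label].
    return logsumexp(scores) - scores[label]


def multiclass_hinge_loss(label, scores):
    # Crammer-Singer hinge: add a margin of 1 to every wrong class, then
    # compare the largest augmented score with the true-class score.
    # The loss is zero iff the true class beats every other class by >= 1.
    one_hot = jax.nn.one_hot(label, scores.shape[0])
    return jnp.max(scores + 1.0 - one_hot) - jnp.dot(scores, one_hot)


scores = jnp.array([1.0, 2.0, 3.0])
print(multiclass_logistic_loss(2, scores))
print(multiclass_hinge_loss(0, scores))
```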

Applying loss functions on a batch

All loss functions above are pointwise: each operates on a single sample. To apply one to a batch, vectorize it with jax.vmap(loss) and then reduce the per-sample values with jnp.mean or jnp.sum.
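A batched loss following this recipe might look like the following. To stay self-contained, it uses a small sketch of the binary hinge loss instead of importing jaxopt. With jaxopt installed, jaxopt.loss.binary_hinge_loss has the same (label, score) signature and could be passed to jax.vmap in its place.

```python
import jax
import jax.numpy as jnp


def binary_hinge_loss(label, score):
    # Pointwise sketch: one label in {0, 1} and one scalar score.
    signed_label = 2.0 * label - 1.0
    return jnp.maximum(0.0, 1.0 - signed_label * score)


labels = jnp.array([1, 0, 1])
scores = jnp.array([2.0, 0.5, -1.0])

# vmap maps the pointwise loss over the leading axis of both arguments,
# producing one loss value per sample.
per_sample = jax.vmap(binary_hinge_loss)(labels, scores)
mean_loss = jnp.mean(per_sample)
sum_loss = jnp.sum(per_sample)
print(per_sample, mean_loss, sum_loss)
```

jnp.mean keeps the loss scale independent of batch size, which is usually what you want when comparing runs or tuning step sizes. jnp.sum matches the plain sum over samples.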

Objective functions

Composite linear functions


Least squares.


Binary logistic regression.


Multiclass logistic regression.


Dual objective function of multiclass linear SVMs.

Composite linear objective functions can be used with block coordinate descent.
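A composite linear objective depends on its parameters W only through the product X @ W. The sketch below writes least squares and binary logistic regression in that form. It then checks the identity that makes coordinate updates cheap: changing one coordinate of W moves the predictions along a single column of X. The function names, the (X, y) data tuple, and the 0.5 * mean scaling are assumptions for illustration, not jaxopt's definitions.

```python
import jax
import jax.numpy as jnp


def least_squares_sketch(W, data):
    """Hypothetical: 0.5 * mean squared residual, a function of X @ W only."""
    X, y = data
    residuals = X @ W - y
    return 0.5 * jnp.mean(residuals ** 2)


def binary_logreg_sketch(W, data):
    """Hypothetical: mean binary logistic loss of the logits X @ W, y in {0, 1}."""
    X, y = data
    logits = X @ W
    return jnp.mean(jax.nn.softplus(jnp.where(y == 1, -logits, logits)))


X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])
W = jnp.zeros(2)

# Changing coordinate j of W by `step` shifts the predictions X @ W by
# step * X[:, j], so a coordinate update needs one column, not a full product.
j, step = 1, 0.5
pred = X @ W
W_new = W.at[j].add(step)
pred_new = pred + step * X[:, j]
print(least_squares_sketch(W, (X, y)), binary_logreg_sketch(W, (X, y)))
```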

Other functions

jaxopt.objective.ridge_regression(params, ...)

Ridge regression, i.e., L2-regularized least squares.
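Ridge regression adds an L2 penalty to least squares, so its minimizer has a closed form. The sketch below assumes the objective 0.5 * mean((X @ w - y)**2) + 0.5 * l2reg * ||w||**2. The l2reg and data parameters after params are hypothetical. The sketch uses jax.grad to check that the gradient vanishes at the closed-form solution.

```python
import jax
import jax.numpy as jnp


def ridge_regression_sketch(params, l2reg, data):
    """Hypothetical signature and scaling; not jaxopt's definition."""
    X, y = data
    residuals = X @ params - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)


X = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 0.5])
l2reg = 0.1
n = X.shape[0]

# The gradient is X^T (X w - y) / n + l2reg * w; setting it to zero gives
# (X^T X / n + l2reg * I) w = X^T y / n.
w_star = jnp.linalg.solve(X.T @ X / n + l2reg * jnp.eye(2), X.T @ y / n)
grad_at_w_star = jax.grad(ridge_regression_sketch)(w_star, l2reg, (X, y))
print(w_star, grad_at_w_star)
```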


Multiclass logistic regression with intercept.

jaxopt.objective.l2_multiclass_logreg(W, ...)

L2-regularized multiclass logistic regression.
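L2-regularized multiclass logistic regression can be sketched as the mean softmax cross-entropy of the logits X @ W plus 0.5 * l2reg * ||W||_F**2. That scaling of the penalty is an assumption. In this hypothetical sketch, W has shape (n_features, n_classes) and data is an (X, y) tuple with integer labels y.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def l2_multiclass_logreg_sketch(W, l2reg, data):
    """Mean softmax cross-entropy of X @ W plus an assumed 0.5 * l2reg * ||W||_F^2."""
    X, y = data  # X: (n_samples, n_features); y: int labels in [0, n_classes)
    logits = X @ W  # (n_samples, n_classes)
    true_logits = jnp.take_along_axis(logits, y[:, None], axis=1)[:, 0]
    data_term = jnp.mean(logsumexp(logits, axis=1) - true_logits)
    return data_term + 0.5 * l2reg * jnp.sum(W ** 2)


X = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 1.0, 1.0],
               [1.0, 1.0, 0.0],
               [2.0, 0.0, 1.0]])
y = jnp.array([0, 1, 2, 1])
W0 = jnp.zeros((3, 3))  # 3 features, 3 classes

# At W = 0 every class is equally likely, so the data term equals log(3)
# and the penalty is zero.
value = l2_multiclass_logreg_sketch(W0, 0.1, (X, y))
grad = jax.grad(l2_multiclass_logreg_sketch)(W0, 0.1, (X, y))
print(value, grad.shape)
```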


L2-regularized multiclass logistic regression with intercept.
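With an intercept, the parameters become a pair (W, b) and the logits are X @ W + b. The sketch below assumes that only W is penalized. Setting l2reg to 0 gives the unregularized "multiclass logistic regression with intercept" listed earlier. The function name and argument layout are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multiclass_logreg_with_intercept_sketch(params, l2reg, data):
    """Hypothetical sketch: params = (W, b); only W is penalized (assumption)."""
    W, b = params
    X, y = data  # X: (n_samples, n_features); y: int labels in [0, n_classes)
    logits = X @ W + b  # b has shape (n_classes,) and broadcasts over samples
    true_logits = jnp.take_along_axis(logits, y[:, None], axis=1)[:, 0]
    data_term = jnp.mean(logsumexp(logits, axis=1) - true_logits)
    return data_term + 0.5 * l2reg * jnp.sum(W ** 2)


W = jnp.array([[0.5, -0.2, 0.1],
               [0.3, 0.0, -0.4]])
b = jnp.array([0.1, 0.2, -0.3])
X = jnp.array([[1.0, 2.0], [0.5, -1.0], [0.0, 1.0]])
y = jnp.array([2, 0, 1])

# l2reg = 0 gives the unregularized objective.
plain = multiclass_logreg_with_intercept_sketch((W, b), 0.0, (X, y))
# Softmax ignores a common shift of all class scores, so b + 5 gives the same loss.
shifted = multiclass_logreg_with_intercept_sketch((W, b + 5.0), 0.0, (X, y))
print(plain, shifted)
```

Because of that shift invariance, the intercept is only determined up to a constant added to every class. An unpenalized b therefore does not have a unique optimum.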