Loss and objective functions

Loss functions

Regression

jaxopt.loss.huber_loss(target, pred[, delta])

Huber loss.

Regression losses are of the form loss(float: target, float: pred) -> float, where target is the ground truth and pred is the model’s output.
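A common definition of the Huber loss with threshold delta is quadratic for small residuals and linear for large ones, with matching value and slope where the two pieces meet. The sketch below assumes that convention, with delta defaulting to 1.0. The name huber_loss_sketch is hypothetical, and the exact scaling used by jaxopt.loss.huber_loss should be checked against its source.

```python
import jax.numpy as jnp


def huber_loss_sketch(target, pred, delta=1.0):
    # Hypothetical stand-in for a Huber loss; assumes the common
    # 0.5 * r**2 / delta * (|r| - 0.5 * delta) convention.
    abs_diff = jnp.abs(target - pred)
    # Quadratic for |r| <= delta, linear with slope delta beyond it.
    return jnp.where(abs_diff > delta,
                     delta * (abs_diff - 0.5 * delta),
                     0.5 * abs_diff ** 2)


print(huber_loss_sketch(1.0, 1.5))  # quadratic region: 0.5 * 0.5**2 = 0.125
print(huber_loss_sketch(1.0, 4.0))  # linear region: 1.0 * (3.0 - 0.5) = 2.5
```

Near zero the loss behaves like squared error; for outliers it grows only linearly, which limits their influence on the fit.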

Binary classification

jaxopt.loss.binary_logistic_loss(label, logit)

Binary logistic loss.

jaxopt.loss.binary_sparsemax_loss(label, logit)

Binary sparsemax loss.

jaxopt.loss.binary_hinge_loss(label, score)

Binary hinge loss.

jaxopt.loss.binary_perceptron_loss(label, score)

Binary perceptron loss.

Binary classification losses are of the form loss(int: label, float: score) -> float, where label is the ground truth (0 or 1) and score is the model’s output.
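The logistic, hinge, and perceptron losses can all be written in terms of the signed label y = 2 * label - 1 in {-1, +1} and the margin y * score. The following is a minimal sketch of the textbook definitions under that assumption; the *_sketch names and the signed helper are hypothetical and are not the jaxopt implementations.

```python
import jax
import jax.numpy as jnp


def signed(label):
    # Map a {0, 1} label to a {-1, +1} sign.
    return 2.0 * label - 1.0


def binary_logistic_sketch(label, logit):
    # log(1 + exp(-y * logit)), computed stably with softplus.
    return jax.nn.softplus(-signed(label) * logit)


def binary_hinge_sketch(label, score):
    # Penalizes any margin y * score below 1.
    return jax.nn.relu(1.0 - signed(label) * score)


def binary_perceptron_sketch(label, score):
    # Penalizes only a wrong sign (margin below 0).
    return jax.nn.relu(-signed(label) * score)


print(binary_logistic_sketch(1, 0.0))    # log(2), about 0.693
print(binary_hinge_sketch(1, 0.5))       # 0.5: correct sign, margin below 1
print(binary_perceptron_sketch(1, 0.5))  # 0.0: correct sign
print(binary_perceptron_sketch(0, 0.5))  # 0.5: wrong sign
```

The hinge loss keeps penalizing a correct prediction until the margin reaches 1, while the perceptron loss is already zero once the sign is right.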

The following utility functions are useful for the binary sparsemax loss.

jaxopt.loss.sparse_plus(x)

Sparse plus function.

jaxopt.loss.sparse_sigmoid(x)

Sparse sigmoid function.
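A sketch of how these pieces fit together, assuming the usual piecewise definitions: sparse_plus is 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), and x for x >= 1, a continuously differentiable analogue of softplus; sparse_sigmoid is its derivative, clip((x + 1) / 2, 0, 1). The binary sparsemax loss is then formed the same way the logistic loss is formed from softplus. All function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def sparse_plus_sketch(x):
    # Assumed piecewise form: 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), x for x >= 1.
    return jnp.where(x <= -1.0, 0.0,
                     jnp.where(x >= 1.0, x, 0.25 * (x + 1.0) ** 2))


def sparse_sigmoid_sketch(x):
    # Derivative of sparse_plus_sketch: clip((x + 1) / 2, 0, 1).
    return jnp.clip(0.5 * (x + 1.0), 0.0, 1.0)


def binary_sparsemax_sketch(label, logit):
    # Same pattern as the logistic loss, with softplus swapped for sparse_plus.
    signed_label = 2.0 * label - 1.0
    return sparse_plus_sketch(-signed_label * logit)


# jax.grad of sparse_plus agrees with sparse_sigmoid at every point below.
for x in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    print(x, jax.grad(sparse_plus_sketch)(x), sparse_sigmoid_sketch(x))
```

Unlike the logistic loss, this loss is exactly zero for a sufficiently confident correct prediction: binary_sparsemax_sketch(1, 2.0) evaluates sparse_plus at -2.0, which is 0.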

Multiclass classification

jaxopt.loss.multiclass_logistic_loss(label, ...)

Multiclass logistic loss.

jaxopt.loss.multiclass_sparsemax_loss(label, ...)

Multiclass sparsemax loss.

jaxopt.loss.multiclass_hinge_loss(label, scores)

Multiclass hinge loss.

jaxopt.loss.multiclass_perceptron_loss(...)

Multiclass perceptron loss.

Multiclass classification losses are of the form loss(int: label, jnp.ndarray: scores) -> float, where label is the ground truth (between 0 and n_classes - 1) and scores is an array of size n_classes.
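Assuming the standard formulations (cross-entropy on unnormalized scores for the logistic loss, the Crammer-Singer margin for the hinge loss), the pointwise multiclass losses reduce to a few lines. This is a sketch with hypothetical names, not the jaxopt source; the sparsemax variant, which needs a projection onto the probability simplex, is omitted.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multiclass_logistic_sketch(label, scores):
    # Cross-entropy on unnormalized scores: logsumexp(scores) - scores[label].
    return logsumexp(scores) - scores[label]


def multiclass_hinge_sketch(label, scores):
    # Crammer-Singer margin: add 1 to every wrong class, then compare the
    # largest resulting score with the true class score.
    one_hot = jax.nn.one_hot(label, scores.shape[0])
    return jnp.max(scores + 1.0 - one_hot) - scores[label]


def multiclass_perceptron_sketch(label, scores):
    # Zero exactly when the true class attains the maximum score.
    return jnp.max(scores) - scores[label]


scores = jnp.array([2.0, 0.5, -1.0])
print(multiclass_hinge_sketch(0, scores))       # 0.0: class 0 wins by a margin >= 1
print(multiclass_hinge_sketch(1, scores))       # 2.5 = (2.0 + 1.0) - 0.5
print(multiclass_perceptron_sketch(1, scores))  # 1.5 = 2.0 - 0.5
```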

Applying loss functions on a batch

All loss functions above are pointwise: they operate on a single sample. To apply one to a batch, vectorize it with jax.vmap(loss) and reduce the per-sample results with jnp.mean or jnp.sum.
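A minimal sketch of this batching pattern, using a hand-written pointwise multiclass logistic loss as a hypothetical stand-in (jaxopt.loss.multiclass_logistic_loss could be passed to jax.vmap in the same way). jax.vmap maps over the leading axis of every argument, so labels has shape (batch,) and scores has shape (batch, n_classes).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multiclass_logistic_sketch(label, scores):
    # Pointwise loss: one integer label, one score vector of length n_classes.
    return logsumexp(scores) - scores[label]


# vmap adds a leading batch axis to both arguments.
batched_loss = jax.vmap(multiclass_logistic_sketch)


def mean_loss(labels, scores):
    # Reduce the per-sample losses, shape (batch,), to a scalar.
    return jnp.mean(batched_loss(labels, scores))


labels = jnp.array([0, 2])
scores = jnp.zeros((2, 3))
print(batched_loss(labels, scores))  # each entry equals log(3)
print(mean_loss(labels, scores))
# The scalar mean can be differentiated with respect to the whole score matrix.
grads = jax.grad(mean_loss, argnums=1)(labels, scores)
print(grads.shape)  # (2, 3)
```

Using jnp.mean rather than jnp.sum keeps the loss scale independent of the batch size, which usually makes step sizes easier to reuse across batch sizes.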

Objective functions

Composite linear functions

jaxopt.objective.least_squares

Least squares.

jaxopt.objective.binary_logreg

Binary logistic regression.

jaxopt.objective.multiclass_logreg

Multiclass logistic regression.

jaxopt.objective.multiclass_linear_svm_dual

Dual objective function of multiclass linear SVMs.

Composite linear objective functions can be used with block coordinate descent (jaxopt.BlockCoordinateDescent).
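Reading "composite linear" as an objective that depends on the parameters only through the product X @ params, the benefit for coordinate descent is that updating one coordinate moves X @ params along a single column of X, so the residual can be updated in O(n_samples) instead of being recomputed. The sketch below is a plain cyclic coordinate descent on unregularized least squares, written from scratch to illustrate that idea; coordinate_descent_lstsq is a hypothetical helper and not jaxopt.BlockCoordinateDescent.

```python
import jax
import jax.numpy as jnp


def coordinate_descent_lstsq(X, y, n_epochs=100):
    """Hypothetical helper: cyclic coordinate descent on 0.5 * ||X @ w - y||^2."""
    n_features = X.shape[1]
    w = jnp.zeros(n_features)
    # The objective only sees w through X @ w, so track the residual directly.
    residual = X @ w - y
    col_sq_norms = jnp.sum(X ** 2, axis=0)
    for _ in range(n_epochs):
        for j in range(n_features):
            # Exact minimizer of the objective along coordinate j.
            delta = -jnp.dot(X[:, j], residual) / col_sq_norms[j]
            w = w.at[j].add(delta)
            # Changing w[j] moves X @ w only along column j: O(n_samples) update.
            residual = residual + delta * X[:, j]
    return w


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (20, 3))
w_true = jax.random.normal(key_w, (3,))
y = X @ w_true
w_hat = coordinate_descent_lstsq(X, y)
print(jnp.max(jnp.abs(w_hat - w_true)))  # should be close to 0
```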

Other functions

jaxopt.objective.ridge_regression(params, ...)

Ridge regression, i.e., L2-regularized least squares.
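A sketch of the ridge objective, assuming the scaling 0.5 * mean squared residual + 0.5 * l2reg * ||w||^2 (the constants in jaxopt.objective.ridge_regression may differ). Because the objective is quadratic, its minimizer also has a closed form, which gives an easy sanity check: the gradient computed by jax.grad should vanish there. ridge_objective and ridge_closed_form are hypothetical names.

```python
import jax
import jax.numpy as jnp


def ridge_objective(w, l2reg, X, y):
    # Assumed scaling: half mean squared residual plus half the L2 penalty.
    residuals = X @ w - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(w ** 2)


def ridge_closed_form(l2reg, X, y):
    # The gradient X.T @ (X @ w - y) / n + l2reg * w vanishes at the solution of
    # (X.T @ X / n + l2reg * I) w = X.T @ y / n.
    n, d = X.shape
    A = X.T @ X / n + l2reg * jnp.eye(d)
    return jnp.linalg.solve(A, X.T @ y / n)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (30, 4))
y = jax.random.normal(key_y, (30,))
w_star = ridge_closed_form(0.1, X, y)
grad_at_star = jax.grad(ridge_objective)(w_star, 0.1, X, y)
print(jnp.max(jnp.abs(grad_at_star)))  # close to 0
# Stronger regularization shrinks the solution toward zero.
print(jnp.linalg.norm(ridge_closed_form(10.0, X, y)) < jnp.linalg.norm(w_star))
```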

jaxopt.objective.multiclass_logreg_with_intercept(...)

Multiclass logistic regression with intercept.

jaxopt.objective.l2_multiclass_logreg(W, ...)

L2-regularized multiclass logistic regression.

jaxopt.objective.l2_multiclass_logreg_with_intercept(...)

L2-regularized multiclass logistic regression with intercept.
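The with-intercept variants add a per-class bias b to the logits X @ W. A minimal sketch, assuming params is a (W, b) pair, that the data term is the mean multiclass logistic loss, and that only W is penalized; none of these details are confirmed by the entries above, and the function name is hypothetical. jax.grad differentiates with respect to the whole (W, b) tuple and returns gradients with the same structure, so a few plain gradient-descent steps can be written with jax.tree_util.tree_map.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def l2_multiclass_logreg_with_intercept_sketch(params, l2reg, X, y):
    W, b = params  # W: (n_features, n_classes), b: (n_classes,)
    logits = X @ W + b  # (n_samples, n_classes); b broadcasts over samples
    # Per-sample multiclass logistic loss: logsumexp(logits) - logit of true class.
    nll = logsumexp(logits, axis=1) - logits[jnp.arange(X.shape[0]), y]
    # Assumption: only W is penalized; the intercept b is left unregularized.
    return jnp.mean(nll) + 0.5 * l2reg * jnp.sum(W ** 2)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (50, 5))
y = jax.random.randint(key_y, (50,), 0, 4)
params = (jnp.zeros((5, 4)), jnp.zeros(4))


def objective(p):
    return l2_multiclass_logreg_with_intercept_sketch(p, 0.1, X, y)


# Zero parameters give uniform class probabilities, so this equals log(4), about 1.386.
initial_value = objective(params)
grad_fn = jax.grad(objective)  # returns a (dW, db) tuple matching params
for _ in range(100):
    grads = grad_fn(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)
print(initial_value, objective(params))
```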