Loss and objective functions
Loss functions
Regression
- Huber loss.
Regression losses are of the form loss(target: float, pred: float) -> float, where target is the ground truth and pred is the model's output.
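As an illustration of this signature, here is a minimal sketch of a Huber loss written in JAX. The function name and the delta parameter are illustrative choices, not necessarily the library's actual API:

```python
import jax.numpy as jnp

def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
    """Huber loss: quadratic for small errors, linear for large ones."""
    abs_err = jnp.abs(target - pred)
    quadratic = 0.5 * abs_err ** 2
    linear = delta * (abs_err - 0.5 * delta)
    return jnp.where(abs_err <= delta, quadratic, linear)
```

For an error of 0.5 (below delta) the loss is quadratic, 0.5 * 0.25 = 0.125; for an error of 2.0 it switches to the linear branch, 1.0 * (2.0 - 0.5) = 1.5.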
Binary classification
- Binary logistic loss.
- Binary sparsemax loss.
- Binary hinge loss.
- Binary perceptron loss.
Binary classification losses are of the form loss(label: int, score: float) -> float, where label is the ground truth (0 or 1) and score is the model's output.
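For example, the binary logistic loss with a {0, 1} label can be written with a numerically stable softplus. This is a sketch of the standard formula, not necessarily the library's exact implementation:

```python
import jax.nn
import jax.numpy as jnp

def binary_logistic_loss(label: int, score: float) -> float:
    # log(1 + exp(score)) - label * score, with label in {0, 1}.
    # softplus keeps the computation stable for large |score|.
    return jax.nn.softplus(score) - label * score
```

At score = 0 the loss is log(2) regardless of the label; it decays toward 0 as the score agrees strongly with the label.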
The following utility functions are useful for the binary sparsemax loss.
- Sparse plus function.
- Sparse sigmoid function.
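A sketch of these two utilities, following one common definition from the sparsemax literature (the exact form used by the library may differ):

```python
import jax.numpy as jnp

def sparse_plus(x):
    # Smoothed relu: 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), x for x >= 1.
    return jnp.where(x <= -1.0, 0.0,
                     jnp.where(x >= 1.0, x, (x + 1.0) ** 2 / 4.0))

def sparse_sigmoid(x):
    # Derivative of sparse_plus: a piecewise-linear sigmoid that is
    # exactly 0 below -1 and exactly 1 above 1.
    return jnp.clip((x + 1.0) / 2.0, 0.0, 1.0)
```

Unlike the ordinary sigmoid, sparse_sigmoid reaches 0 and 1 exactly, which is what produces sparsity in the sparsemax loss.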
Multiclass classification
- Multiclass logistic loss.
- Multiclass sparsemax loss.
- Multiclass hinge loss.
- Multiclass perceptron loss.
Multiclass classification losses are of the form loss(label: int, scores: jnp.ndarray) -> float, where label is the ground truth (an integer between 0 and n_classes - 1) and scores is an array of size n_classes.
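For instance, the multiclass logistic (cross-entropy) loss with an integer label reduces to a logsumexp over the scores. A minimal sketch, not necessarily the library's exact implementation:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def multiclass_logistic_loss(label: int, scores: jnp.ndarray) -> float:
    # Cross-entropy with an integer label: logsumexp(scores) - scores[label].
    return logsumexp(scores) - scores[label]
```

With uniform scores over n_classes classes, the loss equals log(n_classes), the entropy of a uniform guess.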
Applying loss functions on a batch
All loss functions above are pointwise, meaning that they operate on a single sample. To apply them to a batch, use jax.vmap(loss) followed by a reduction such as jnp.mean or jnp.sum.
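The vmap-then-reduce pattern looks like this. The hinge loss here is a hand-written sketch used only to have a pointwise loss to vectorize:

```python
import jax
import jax.numpy as jnp

def binary_hinge_loss(label: int, score: float) -> float:
    # Map the {0, 1} label to a sign in {-1, +1}, then apply the hinge.
    sign = 2.0 * label - 1.0
    return jnp.maximum(0.0, 1.0 - sign * score)

labels = jnp.array([0, 1, 1])
scores = jnp.array([-2.0, 0.5, 3.0])

# Vectorize the pointwise loss over the batch, then reduce.
batch_losses = jax.vmap(binary_hinge_loss)(labels, scores)
mean_loss = jnp.mean(batch_losses)
```

The same pattern works for the multiclass losses; for those, only the label axis and the leading scores axis are mapped over.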
Objective functions
Composite linear functions
- Least squares.
- Binary logistic regression.
- Multiclass logistic regression.
- Dual objective function of multiclass linear SVMs.
Composite linear objective functions can be used with block coordinate descent.
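As an illustration, a composite linear objective such as least squares is a function of the parameters and a data batch, with the model linear in the parameters. The signature and the (X, y) data layout below are assumptions for the sketch, not necessarily the library's API:

```python
import jax.numpy as jnp

def least_squares(W, data):
    # 0.5 * mean squared residual of the linear model X @ W.
    X, y = data
    residuals = X @ W - y
    return 0.5 * jnp.mean(residuals ** 2)
```

Because the objective depends on W only through the linear map X @ W, block coordinate descent can update one coordinate (or block) of W at a time while keeping the residuals cheap to maintain.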
Other functions
- Ridge regression, i.e. L2-regularized least squares.
- Multiclass logistic regression with intercept.
- L2-regularized multiclass logistic regression.
- L2-regularized multiclass logistic regression with intercept.
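For example, ridge regression adds an L2 penalty on the weights to the least squares term. The name, signature, and the placement of the regularization strength are assumptions for this sketch:

```python
import jax.numpy as jnp

def ridge_regression(W, data, l2reg=1.0):
    # Least squares term plus an L2 penalty on the weights.
    X, y = data
    residuals = X @ W - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(W ** 2)
```

Setting l2reg = 0.0 recovers plain least squares; larger values shrink W toward zero.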