jaxopt.objective.l2_multiclass_logreg

jaxopt.objective.l2_multiclass_logreg(W, l2reg, data)

L2-regularized multiclass logistic regression.

\[\frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i, y_i) + 0.5 \cdot \text{l2reg} \cdot ||W||_2^2\]

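where \(\ell\) is multiclass_logistic_loss, (X, y) = data, \(x_i\) is the i-th row of X, n is the number of samples, and \(||W||_2^2\) is the sum of the squared entries of W.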
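To make the formula concrete, here is a minimal plain-NumPy restatement of it, useful for cross-checking values by hand. It is a sketch, not jaxopt's implementation: the helper name l2_multiclass_logreg_reference is hypothetical, and the loss is written out as log-sum-exp of the logits minus the logit of the true class (softmax cross-entropy), assuming y holds integer class labels.

```python
import numpy as np

def l2_multiclass_logreg_reference(W, l2reg, data):
    """Hypothetical NumPy restatement of the objective (not jaxopt's code)."""
    X, y = data
    logits = X @ W  # shape (n_samples, n_classes); row i equals W^T x_i
    # Numerically stable log-sum-exp over the class axis.
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    # Per-sample multiclass logistic loss: logsumexp(z_i) - z_i[y_i].
    losses = lse - logits[np.arange(X.shape[0]), y]
    # Penalty: 0.5 * l2reg * (sum of squared entries of W).
    penalty = 0.5 * l2reg * np.sum(W ** 2)
    return float(losses.mean() + penalty)

# Illustrative data: 5 samples, 3 features, 4 classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
y = np.array([0, 1, 2, 3, 0])

# With W = 0 every logit is 0, so each loss is log(4) and the penalty is 0.
print(l2_multiclass_logreg_reference(np.zeros((3, 4)), 0.1, (X, y)))  # ≈ 1.3863
```

At W = 0 the data term reduces to log(n_classes), which gives a quick sanity check on any implementation of this objective.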

Parameters
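  • W (Array) – a matrix of shape (n_features, n_classes).

  • l2reg (float) – the regularization strength, i.e. the coefficient in the penalty term 0.5 · l2reg · ||W||_2^2.

  • data (Tuple[Array, Array]) – a tuple (X, y) where X is a matrix of shape (n_samples, n_features) and y is a vector of shape (n_samples,) holding integer class labels in [0, n_classes).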

Return type

float

Returns

objective value.