jaxopt.objective.l2_multiclass_logreg_with_intercept

jaxopt.objective.l2_multiclass_logreg_with_intercept(params, l2reg, data)

L2-regularized multiclass logistic regression with intercept.

\[\frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i + b, y_i) + 0.5 \cdot \text{l2reg} \cdot ||W||_2^2\]

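where \(\ell\) is multiclass_logistic_loss, \((W, b)\) = params and \((X, y)\) = data. Here \(||W||_2^2\) denotes the sum of the squared entries of W. Only W is penalized; the intercept b is not regularized.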
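To make the formula concrete, here is a minimal sketch that re-implements it in plain JAX. It is not the library's source: the helper names `multiclass_logistic_loss` and `objective_sketch` are hypothetical, and \(\ell(z, y)\) is taken to be the usual softmax cross-entropy \(\operatorname{logsumexp}(z) - z_y\) with integer labels.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def multiclass_logistic_loss(logits, label):
    # Hypothetical helper: softmax cross-entropy for one sample,
    # logsumexp(z) - z[y], with y an integer class label.
    return logsumexp(logits) - logits[label]


def objective_sketch(params, l2reg, data):
    """Hypothetical re-implementation of the formula above."""
    W, b = params        # W: (n_features, n_classes), b: (n_classes,)
    X, y = data          # X: (n_samples, n_features), y: (n_samples,) ints
    logits = X @ W + b   # row i equals W^T x_i + b
    losses = jax.vmap(multiclass_logistic_loss)(logits, y)
    # Mean loss plus 0.5 * l2reg * (sum of squared entries of W);
    # the intercept b does not appear in the penalty.
    return jnp.mean(losses) + 0.5 * l2reg * jnp.sum(W ** 2)


# With W = 0 and b = 0 every class gets probability 1/3, so the
# objective equals log(3) whatever the value of l2reg.
X = jnp.ones((4, 2))
y = jnp.array([0, 1, 2, 0])
params = (jnp.zeros((2, 3)), jnp.zeros(3))
value = objective_sketch(params, 1.0, (X, y))
print(value)  # approximately log(3) = 1.0986...
```

The all-zero starting point is a handy sanity check: the penalty vanishes and the data term reduces to \(\log(\text{n\_classes})\).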

Parameters
  • params (Tuple[Array, Array]) – a tuple (W, b), where W is a matrix of shape (n_features, n_classes) and b is a vector of shape (n_classes,).

  • l2reg (float) – regularization strength, i.e. the coefficient in the 0.5 · l2reg · ||W||_2^2 penalty.

  • data (Tuple[Array, Array]) – a tuple (X, y), where X is a matrix of shape (n_samples, n_features) and y is a vector of shape (n_samples,) holding integer class labels in [0, n_classes).

Return type

float

Returns

objective value.
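Because the objective is convex and differentiable in (W, b), any gradient-based method can minimize it. The sketch below uses plain gradient descent with `jax.grad` on a self-contained re-implementation of the formula (the function `objective` and the toy data are hypothetical, written with `jax.nn.log_softmax` instead of the library function); the step size 0.1 and 200 iterations are ad hoc choices for this toy problem.

```python
import jax
import jax.numpy as jnp


def objective(params, l2reg, data):
    # Hypothetical re-implementation of the documented formula.
    W, b = params
    X, y = data
    log_probs = jax.nn.log_softmax(X @ W + b, axis=-1)
    # Negative log-probability of each sample's true (integer) class.
    nll = -jnp.take_along_axis(log_probs, y[:, None], axis=1)[:, 0]
    return jnp.mean(nll) + 0.5 * l2reg * jnp.sum(W ** 2)


# Toy data: three well-separated 2-D clusters, 20 points each.
key = jax.random.PRNGKey(0)
centers = jnp.array([[2.0, 0.0], [-2.0, 2.0], [-2.0, -2.0]])
y = jnp.repeat(jnp.arange(3), 20)
X = centers[y] + 0.3 * jax.random.normal(key, (60, 2))

params = (jnp.zeros((2, 3)), jnp.zeros(3))  # (W, b)
l2reg = 1e-2
grad_fn = jax.jit(jax.grad(objective))      # gradient w.r.t. params only

initial = objective(params, l2reg, (X, y))
for _ in range(200):
    grads = grad_fn(params, l2reg, (X, y))
    # Gradient step on both W and b; only W feels the l2reg term.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
final = objective(params, l2reg, (X, y))

W, b = params
accuracy = jnp.mean(jnp.argmax(X @ W + b, axis=1) == y)
print(initial, final, accuracy)
```

Note that the gradient with respect to b does not depend on l2reg, since the intercept does not appear in the penalty term.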