jaxopt.objective.multiclass_logreg_with_intercept

jaxopt.objective.multiclass_logreg_with_intercept(params, data)

Objective function of multiclass logistic regression with an intercept.

\[\frac{1}{n} \sum_{i=1}^n \ell(W^\top x_i + b, y_i)\]

where \(\ell\) is multiclass_logistic_loss, (W, b) = params, (X, y) = data, \(x_i\) is the \(i\)-th row of X, and \(n\) is the number of samples.
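The formula can be cross-checked against a plain NumPy sketch. This sketch assumes that \(\ell(z, y) = \log \sum_k \exp(z_k) - z_y\), the usual softmax cross-entropy on the logits \(z = W^\top x_i + b\). The name multiclass_logreg_np is hypothetical and is not part of jaxopt.

```python
import numpy as np

def multiclass_logreg_np(params, data):
    """Hypothetical NumPy sketch of (1/n) * sum_i ell(W^T x_i + b, y_i)."""
    W, b = params
    X, y = data
    # Row i of `logits` is W^T x_i + b, so the shape is (n_samples, n_classes).
    logits = X @ W + b
    # Numerically stable log-sum-exp per row: shift by the row maximum first.
    m = logits.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    # ell(z_i, y_i) = logsumexp(z_i) - z_i[y_i], then average over samples.
    per_sample = lse - logits[np.arange(len(y)), y]
    return float(per_sample.mean())
```

With W and b set to zero, every logit is 0, so each per-sample loss is \(\log(\text{n\_classes})\) and so is the mean. This gives a quick sanity check for any implementation of the objective.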

Parameters
  • params (Tuple[Array, Array]) – a tuple (W, b), where W is a matrix of shape (n_features, n_classes) and b is a vector of shape (n_classes,).

  • data (Tuple[Array, Array]) – a tuple (X, y), where X is a matrix of shape (n_samples, n_features) and y is a vector of shape (n_samples,) holding integer class labels in {0, ..., n_classes - 1}.
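The shape contract between params and data can be made explicit with a small validation helper. check_logreg_inputs is hypothetical and not part of jaxopt. The label-range check is an assumption based on the loss indexing each row of logits by its label.

```python
import numpy as np

def check_logreg_inputs(params, data):
    """Hypothetical helper: check (W, b) and (X, y) against the documented shapes.

    Returns (n_samples, n_features, n_classes), or raises ValueError.
    """
    W, b = params
    X, y = data
    n_samples, n_features = X.shape  # raises ValueError if X is not 2-D
    if W.ndim != 2 or W.shape[0] != n_features:
        raise ValueError(f"W must have shape ({n_features}, n_classes), got {W.shape}")
    n_classes = W.shape[1]
    if b.shape != (n_classes,):
        raise ValueError(f"b must have shape ({n_classes},), got {b.shape}")
    if y.shape != (n_samples,):
        raise ValueError(f"y must have shape ({n_samples},), got {y.shape}")
    # Assumption: labels are integer class indices in [0, n_classes).
    if not np.issubdtype(y.dtype, np.integer) or y.min() < 0 or y.max() >= n_classes:
        raise ValueError(f"y must contain integer labels in [0, {n_classes})")
    return n_samples, n_features, n_classes
```

For example, W = np.zeros((3, 4)), b = np.zeros(4), X of shape (5, 3) and y = np.array([0, 1, 2, 3, 0]) pass the check and give (5, 3, 4). A b of shape (3,) would be rejected because it does not match the n_classes column count of W.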

Return type

float

Returns

the objective value, i.e. the mean of the per-sample losses over the n samples.