jaxopt.loss.multiclass_logistic_loss

jaxopt.loss.multiclass_logistic_loss(label, logits)

Multiclass logistic loss (also known as softmax cross-entropy) for a single example.
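Assuming the standard definition of the multiclass logistic loss, with K = n_classes, integer label y and logits z, the loss and its gradient with respect to the logits can be written as:

```latex
\ell(y, z) = -\log \frac{\exp(z_y)}{\sum_{k=0}^{K-1} \exp(z_k)}
           = \log \sum_{k=0}^{K-1} \exp(z_k) - z_y

\frac{\partial \ell}{\partial z_k} = \mathrm{softmax}(z)_k - \mathbb{1}[k = y]
```

Because the log-sum-exp term is at least \max_k z_k \ge z_y, the loss is always nonnegative.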

Parameters
  • label (int) – ground-truth integer label, in the range 0 to n_classes - 1 (inclusive).

  • logits (Array) – scores produced by the model, of shape (n_classes,).

Return type

float

Returns

scalar loss value
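The sketch below re-implements the same formula in plain JAX so its behaviour can be checked without jaxopt. The name multiclass_logistic_loss_sketch is hypothetical and not part of jaxopt; the example assumes only that jax is installed. It also checks the gradient against softmax minus one-hot and shows how a per-example loss like this one can be applied to a batch.

```python
import jax
import jax.numpy as jnp


def multiclass_logistic_loss_sketch(label, logits):
    """Hypothetical stand-in: -log(softmax(logits)[label]) for one example.

    label: integer in [0, n_classes - 1]; logits: array of shape (n_classes,).
    """
    logits = jnp.asarray(logits)
    # logsumexp avoids overflow from exponentiating large scores directly.
    return jax.nn.logsumexp(logits) - logits[label]


# Single example: 3 classes, ground-truth class 2.
logits = jnp.array([1.0, 2.0, 3.0])
loss = multiclass_logistic_loss_sketch(2, logits)

# Gradient w.r.t. the logits (argnums=1) should equal softmax(logits) - one_hot(label).
grad = jax.grad(multiclass_logistic_loss_sketch, argnums=1)(2, logits)
expected_grad = jax.nn.softmax(logits) - jax.nn.one_hot(2, logits.shape[0])

# The signature handles one example, so jax.vmap maps it over a batch.
batch_labels = jnp.array([0, 2])
batch_logits = jnp.array([[2.0, 0.0, 0.0],
                          [1.0, 2.0, 3.0]])
batch_losses = jax.vmap(multiclass_logistic_loss_sketch)(batch_labels, batch_logits)
mean_loss = jnp.mean(batch_losses)
```

Since the documented signature takes one integer label and a 1-D logits array, wrapping the loss in jax.vmap (and averaging) is one straightforward way to obtain a batch objective; the same pattern is expected to work with jaxopt.loss.multiclass_logistic_loss itself.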