Skip to content

Instantly share code, notes, and snippets.

@lmassaron
Created April 29, 2022 21:33
Show Gist options
  • Save lmassaron/6ca33ee6df8db8d8b7ef78c7d7241e06 to your computer and use it in GitHub Desktop.
Save lmassaron/6ca33ee6df8db8d8b7ef78c7d7241e06 to your computer and use it in GitHub Desktop.
def poly1_cross_entropy(logits, labels, epsilon=1.0):
    """Poly-1 cross-entropy loss.

    Adds the leading polynomial correction ``epsilon * (1 - pt)`` to the
    softmax cross-entropy, where ``pt`` is the softmax probability mass
    assigned to the labelled class(es).

    Args:
        logits: Unnormalized scores; the class axis is the last axis.
        labels: Target distribution with the same shape as ``logits``
            (e.g. one-hot). Cast to the dtype of ``logits``, so integer
            one-hot labels are accepted.
        epsilon: Weight of the first polynomial term; ``0`` gives plain
            softmax cross-entropy.

    Returns:
        Per-example loss with the last (class) axis reduced, i.e. shape
        ``[batch]`` for ``[batch, num_classes]`` inputs.
    """
    logits = tf.convert_to_tensor(logits)
    # Integer labels would otherwise fail the dtype checks of both the
    # elementwise multiply and softmax_cross_entropy_with_logits.
    labels = tf.cast(labels, logits.dtype)

    # Probability mass the model puts on the target class(es); shape [batch].
    pt = tf.reduce_sum(labels * tf.nn.softmax(logits, axis=-1), axis=-1)
    # Keyword arguments: the TF1 / compat.v1 signature starts with a
    # `_sentinel` parameter that rejects positional calls.
    ce = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return ce + epsilon * (1 - pt)
def poly1_focal_loss(logits, labels, epsilon=1.0, gamma=2.0):
    """Poly-1 sigmoid focal loss.

    Computes the (un-alpha-weighted) sigmoid focal loss
    ``FL = (1 - pt)**gamma * CE`` and adds the Poly-1 correction
    ``epsilon * (1 - pt)**(gamma + 1)``. The loss is computed elementwise,
    so nothing is reduced.

    The original snippet called an undefined ``focal_loss(pt, gamma)``
    helper. It is replaced here by the standard sigmoid focal loss without
    alpha class balancing. NOTE(review): if a different focal-loss variant
    (e.g. with alpha) was intended, add it here.

    Args:
        logits: Unnormalized per-class scores.
        labels: Targets with the same shape as ``logits`` (0/1 for hard
            labels). Cast to the dtype of ``logits``.
        epsilon: Weight of the Poly-1 correction term; ``0`` gives the
            plain focal loss.
        gamma: Focusing parameter of the focal loss.

    Returns:
        Elementwise loss with the same shape as ``logits``.
    """
    logits = tf.convert_to_tensor(logits)
    # Integer labels would otherwise fail dtype checks in the arithmetic and
    # in sigmoid_cross_entropy_with_logits.
    labels = tf.cast(labels, logits.dtype)

    p = tf.math.sigmoid(logits)
    # Probability assigned to the true outcome of each binary decision.
    pt = labels * p + (1 - labels) * (1 - p)
    # Numerically stable CE from logits; equals -log(pt) for 0/1 labels.
    ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    fl = tf.math.pow(1 - pt, gamma) * ce
    return fl + epsilon * tf.math.pow(1 - pt, gamma + 1)
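# Reference: "PolyLoss: A Polynomial Expansion Perspective of Classification
# Loss Functions" (Leng et al., ICLR 2022), https://arxiv.org/pdf/2204.12511.pdf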
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment