@remi-or
Created January 18, 2022 17:55
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss


def distillation_loss(
    teacher_logits: Tensor,
    student_logits: Tensor,
    labels: Tensor,
    temperature: float = 1.0,
) -> Tensor:
    """
    The distillation loss for distilling a BERT-like model.
    Combines a hard-label classification loss, a soft-label teacher-student
    cross-entropy loss, and a cosine loss between the two soft distributions,
    all computed from (teacher_logits), (student_logits) and (labels).
    The (temperature) softens the distributions; it defaults to 1.
    """
    # Temperature scaling and softmax to get the soft distributions
    soft_student = (student_logits / temperature).softmax(dim=1)
    soft_teacher = (teacher_logits / temperature).softmax(dim=1)
    # Classification loss (problem-specific loss) on the raw student logits;
    # CrossEntropyLoss applies log_softmax internally, so it expects logits,
    # not probabilities
    loss = CrossEntropyLoss()(student_logits, labels)
    # Cross-entropy teacher-student loss with soft targets
    # (probability targets require PyTorch >= 1.10)
    loss = loss + CrossEntropyLoss()(student_logits / temperature, soft_teacher)
    # Cosine loss pulling the two soft distributions together
    loss = loss + CosineEmbeddingLoss()(
        soft_teacher,
        soft_student,
        torch.ones(teacher_logits.size(0), device=teacher_logits.device),
    )
    # Average the three terms and return
    return loss / 3
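
A minimal usage sketch follows; the batch size, class count, and temperature below are made-up values for illustration, not part of the original gist.

# Hypothetical shapes for illustration: batch of 8 examples, 2 classes
batch_size, num_classes = 8, 2
teacher_logits = torch.randn(batch_size, num_classes)
student_logits = torch.randn(batch_size, num_classes, requires_grad=True)
labels = torch.randint(0, num_classes, (batch_size,))
loss = distillation_loss(teacher_logits, student_logits, labels, temperature=2.0)
loss.backward()  # gradients flow back into the student's logits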