import torch
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss


def distillation_loss(
    teacher_logits: torch.Tensor,
    student_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    Compute the distillation loss for distilling a BERT-like classifier.

    The result is the unweighted mean of three terms:

    1. the classification loss between the student predictions and
       ``labels`` (problem-specific loss);
    2. the cross-entropy between the student predictions and the
       temperature-softened teacher distribution;
    3. the cosine-embedding loss pulling the student's softened
       distribution towards the teacher's.

    Args:
        teacher_logits: Unnormalized teacher scores. The class dimension is
            dim 1 and the batch dimension is dim 0; the cosine term needs
            2-D ``(batch, num_classes)`` inputs.
        student_logits: Unnormalized student scores, same shape as
            ``teacher_logits``.
        labels: Targets accepted by ``CrossEntropyLoss`` for
            ``student_logits`` (typically class indices of shape
            ``(batch,)``).
        temperature: Softening temperature dividing both sets of logits
            before every term. Must be strictly positive. Defaults to 1.

    Returns:
        A scalar tensor holding the averaged loss.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.

    Note:
        Gradients are not blocked on ``teacher_logits``; run the teacher
        under ``torch.no_grad()`` (or detach its output) if it must stay
        frozen.
    """
    if temperature <= 0:
        raise ValueError(
            f"temperature must be strictly positive, got {temperature}"
        )

    # Temperature scaling. The scaled student *logits* are kept as-is because
    # CrossEntropyLoss applies log-softmax internally; handing it
    # probabilities would apply softmax twice and flatten the loss.
    scaled_student = student_logits / temperature
    scaled_teacher = teacher_logits / temperature

    # Softened distributions: the teacher's is the soft target, and both are
    # compared by the cosine term.
    student_probs = scaled_student.softmax(1)
    teacher_probs = scaled_teacher.softmax(1)

    cross_entropy = CrossEntropyLoss()

    # Classification loss (problem-specific loss)
    loss = cross_entropy(scaled_student, labels)

    # Cross-entropy teacher-student loss (class-probability targets)
    loss = loss + cross_entropy(scaled_student, teacher_probs)

    # Cosine loss. A target of +1 asks for the two distributions to be
    # aligned; new_ones keeps the target on the inputs' device and dtype.
    cosine_target = teacher_probs.new_ones(teacher_probs.size(0))
    loss = loss + CosineEmbeddingLoss()(
        teacher_probs, student_probs, cosine_target
    )

    # Average the three terms and return the result
    return loss / 3