Created January 18, 2022 17:55
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss


def distillation_loss(
    teacher_logits: Tensor,
    student_logits: Tensor,
    labels: Tensor,
    temperature: float = 1.0,
) -> Tensor:
    """
    The distillation loss for distilling a BERT-like model.
    The loss combines a classification term, a teacher-student cross-entropy term
    and a cosine term computed from (teacher_logits), (student_logits) and (labels).
    The (temperature) can be given, otherwise it is set to 1 by default.
    """
    # Classification loss (problem-specific loss), computed on the raw student logits
    # because CrossEntropyLoss applies the softmax internally
    loss = CrossEntropyLoss()(student_logits, labels)
    # Temperature-scaled softmax of both distributions
    soft_student = (student_logits / temperature).softmax(dim=1)
    soft_teacher = (teacher_logits / temperature).softmax(dim=1)
    # Cross-entropy teacher-student loss: the teacher's softened distribution is the
    # target (probability targets require torch >= 1.10; the input stays in logit space)
    loss = loss + CrossEntropyLoss()(student_logits / temperature, soft_teacher)
    # Cosine loss, pulling the two softened distributions into the same direction
    loss = loss + CosineEmbeddingLoss()(
        soft_teacher,
        soft_student,
        torch.ones(teacher_logits.size(0), device=teacher_logits.device),
    )
    # Average the three terms and return the result
    return loss / 3
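As a quick sanity check, the function can be exercised on random tensors; this is a minimal sketch, assuming a batch of 4 examples and 3 classes (the shapes, the temperature value and the variable names here are illustrative only, not part of the gist above):

# Minimal usage sketch (assumed shapes: batch of 4 examples, 3 classes)
batch_size, num_classes = 4, 3
teacher_logits = torch.randn(batch_size, num_classes)
student_logits = torch.randn(batch_size, num_classes, requires_grad=True)
labels = torch.randint(0, num_classes, (batch_size,))
loss = distillation_loss(teacher_logits, student_logits, labels, temperature=2.0)
loss.backward()  # the loss is a scalar tensor, so it back-propagates to the student logits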