Skip to content

Instantly share code, notes, and snippets.

View visheratin's full-sized avatar

Alexander Visheratin visheratin

View GitHub Profile
import torch.nn as nn
from transformers import AutoModel
class SiameseNet(nn.Module):
def __init__(self):
super(SiameseNet, self).__init__()
# DistilBERT initialization
self.tf_layer = AutoModel.from_pretrained('/jupyter/models/distilbert')
# uncomment these lines for freezing DistilBERT weights
# for p in self.tf_layer.parameters():
import torch
import torch.nn.functional as F
class TripletLoss(torch.nn.Module):
def __init__(self, margin=2.0):
    """Set up the loss module and remember its margin.

    Args:
        margin: Value stored on ``self.margin`` for later use by the
            module's other methods. Defaults to ``2.0``.
    """
    # Initialise the parent module before assigning any attributes.
    super().__init__()
    self.margin = margin
def forward(self, anchor, negative, positive):
neg_dist = F.pairwise_distance(anchor, negative, keepdim = True)