from typing import Optional, Tuple

import torch
from geomloss import SamplesLoss
from torch import nn
from torchtyping import TensorType
from transformers.adapters import (
    BartAdapterModel,
    BertAdapterModel,
    RobertaAdapterModel,
)


class AlignmentMixin(nn.Module):
    """Adds an earth-mover (Sinkhorn) alignment loss between the hidden states
    of the adapted model and reference hidden states supplied by the caller
    (typically produced by `produce_original_embeddings` before any adapter is
    activated). Dropout is disabled so the two forward passes differ only
    through the adapters."""

    def __init__(self, config):
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)
        self.earth_mover_loss = SamplesLoss(loss="sinkhorn", p=2)

    @torch.no_grad()
    def produce_original_embeddings(
        self,
        input_ids: TensorType["batch", "seq_len"],
        attention_mask: TensorType["batch", "seq_len"],
        token_type_ids: Optional[TensorType["batch", "seq_len"]] = None,
        position_ids: Optional[TensorType["batch", "seq_len"]] = None,
        head_mask: Optional[TensorType["layers", "heads"]] = None,
    ) -> Tuple[
        TensorType["batch", "seq_len", "hidden_size"],
        TensorType["batch", "seq_len"],
    ]:
        """Runs a reference forward pass in eval mode and returns the final
        hidden states together with the attention mask used to weight them."""
        self.train(False)
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        # Encoder-decoder models (e.g. BART) expose their encoder states under
        # a different key than encoder-only models.
        if "last_hidden_state" in outputs:
            hidden_mat = outputs.last_hidden_state
        else:
            hidden_mat = outputs.encoder_last_hidden_state
        self.train(True)
        return hidden_mat, attention_mask

    def get_weight(
        self, mask: TensorType["batch", "seq_len"]
    ) -> TensorType["batch", "seq_len"]:
        """Turns an attention mask into per-token probability weights."""
        mask = mask.float()
        return mask / mask.sum(dim=1, keepdim=True)

    def forward(
        self,
        input_ids: TensorType["batch", "seq_len"],
        attention_mask: TensorType["batch", "seq_len"],
        original_embedding: Optional[
            TensorType["batch", "seq_len", "hidden_size"]
        ] = None,
        original_mask: Optional[TensorType["batch", "seq_len"]] = None,
        token_type_ids: Optional[TensorType["batch", "seq_len"]] = None,
        position_ids: Optional[TensorType["batch", "seq_len"]] = None,
        head_mask: Optional[TensorType["layers", "heads"]] = None,
        **kwargs
    ):
        if original_embedding is not None:
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            if "last_hidden_state" in outputs:
                hidden_mat = outputs.last_hidden_state
            else:
                hidden_mat = outputs.encoder_last_hidden_state
            # Sinkhorn divergence between the adapted and the reference token
            # clouds, each weighted by its normalised attention mask.
            alignment_loss = self.earth_mover_loss(
                self.get_weight(attention_mask),
                hidden_mat,
                self.get_weight(original_mask),
                original_embedding,
            )
            return (alignment_loss,)
        # Without a reference embedding there is nothing to align, so fall back
        # to the plain adapter-model forward pass.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            **kwargs,
        )


class BartAdapterModelForAlignment(AlignmentMixin, BartAdapterModel):
    def __init__(self, config):
        config.dropout = 0.0
        config.activation_dropout = 0.0
        config.attention_dropout = 0.0
        config.classifier_dropout = 0.0
        super().__init__(config)


class RobertaAdapterModelForAlignment(AlignmentMixin, RobertaAdapterModel):
    def __init__(self, config):
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)


class BertAdapterModelForAlignment(AlignmentMixin, BertAdapterModel):
    def __init__(self, config):
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)
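

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# `roberta-base` checkpoint and the adapter-transformers API (`add_adapter`,
# `train_adapter`); the adapter name "alignment_adapter" and the example
# sentences are purely illustrative, and the reference embeddings are taken
# before the adapter is activated, so the Sinkhorn loss reflects the drift
# introduced by the adapter alone.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = RobertaAdapterModelForAlignment.from_pretrained("roberta-base")

    batch = tokenizer(
        ["an example sentence", "another, slightly longer example sentence"],
        return_tensors="pt",
        padding=True,
    )

    # Reference pass: no adapters active, no gradients (see @torch.no_grad()).
    original_embedding, original_mask = model.produce_original_embeddings(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )

    # Add a fresh adapter and train only its parameters.
    model.add_adapter("alignment_adapter")
    model.train_adapter("alignment_adapter")

    # Adapted pass: returns a 1-tuple holding the alignment loss.
    (alignment_loss,) = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        original_embedding=original_embedding,
        original_mask=original_mask,
    )
    alignment_loss.backward()
    print(f"alignment loss: {alignment_loss.item():.4f}")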