@manueldeprada · Last active July 27, 2023
A Toy Transformer for Debugging Generation Algorithms in HuggingFace🤗
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
import torch

class FakeTransformerConfig(PretrainedConfig):
    model_type = "FakeTransformer"

    def __init__(self, vocab_size=4, **kwargs):
        # Vocabulary of 4 tokens: 0 = BOS, 1 and 2 = ordinary tokens, 3 = EOS.
        super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
        self.vocab_size = vocab_size


class FakeTransformer(PreTrainedModel):
    config_class = FakeTransformerConfig

    def __init__(self, config=None):
        config = FakeTransformerConfig() if config is None else config
        super().__init__(config)
        self.fake_param = torch.nn.Parameter(torch.tensor(1.0))  # need at least one parameter to be a valid model
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        return_dict=True,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        batch_size, seq_len = input_ids.shape

        # Hard-coded next-token distributions, indexed by position in the sequence:
        # after BOS, P(1)=0.75 and P(2)=0.25; after the first generated token,
        # P(1)=0.10 and P(2)=0.90; after the second, EOS (token 3) with probability 1.
        output = torch.zeros(batch_size, seq_len, self.config.vocab_size, device=input_ids.device)
        if seq_len >= 1:
            output[:, 0, 1] = 0.75
            output[:, 0, 2] = 0.25
        if seq_len >= 2:
            output[:, 1, 1] = 0.10
            output[:, 1, 2] = 0.90
        if seq_len >= 3:
            output[:, 2, :] = 0.0
            output[:, 2, 3] = 1.0
        if not past_key_values:  # without a cache the full prefix is passed in, but generate() only reads the last position
            output = output[:, -1, :].unsqueeze(1)

        # Placeholders for past_key_values, attentions and hidden_states
        past_key_values = () if past_key_values is None else past_key_values
        attentions = () if output_attentions else None
        hidden_states = () if output_hidden_states else None

        if not return_dict:
            hidd = hidden_states if hidden_states is not None else ()
            att = attentions if attentions is not None else ()
            return (output.log(),) + past_key_values + hidd + att
        # log turns the hand-written probabilities into logits; zero entries become -inf and are never sampled
        return Seq2SeqLMOutput(
            loss=None,
            logits=output.log(),
            past_key_values=past_key_values,
            decoder_hidden_states=hidden_states,
            decoder_attentions=attentions,
            encoder_last_hidden_state=None,
            encoder_hidden_states=None,
            encoder_attentions=None,
        )
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Always feed the full prefix; the toy model recomputes everything from scratch.
        return {"input_ids": input_ids}


def main():
    config = FakeTransformerConfig()
    model = FakeTransformer(config)
    input_ids = torch.tensor([[0]] * 10)  # a batch of 10 sequences, each starting with BOS
    output = model.generate(input_ids, max_length=4, do_sample=True, return_dict_in_generate=True, output_scores=True)
    print(f"generated tokens: {output.sequences}")
    print(f"probs at step 1: {torch.exp(output.scores[0][0, :])}")
    print(f"probs at step 2: {torch.exp(output.scores[1][0, :])}")
    print(f"probs at step 3: {torch.exp(output.scores[2][0, :])}")


if __name__ == "__main__":
    main()
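
As a quick sanity check without sampling noise, here is a minimal sketch of running greedy decoding on the same toy model (assuming the FakeTransformer and FakeTransformerConfig classes above are in scope); since the hard-coded distributions peak at token 1, then token 2, then EOS, greedy decoding should produce [0, 1, 2, 3] for every sequence in the batch.

def greedy_sanity_check():
    # Sketch: deterministic check of the toy model with greedy decoding (do_sample=False).
    model = FakeTransformer(FakeTransformerConfig())
    input_ids = torch.tensor([[0]] * 2)  # two sequences starting with BOS
    sequences = model.generate(input_ids, max_length=4, do_sample=False)
    # Expected: every row is [0, 1, 2, 3] (argmax per step: 0.75 -> token 1, 0.90 -> token 2, 1.0 -> EOS)
    print(f"greedy sequences: {sequences}")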