
@manueldeprada
Last active July 27, 2023 10:34
fake_transformer.py — a minimal fake seq2seq model for testing Hugging Face generate()

Revisions

  1. manueldeprada revised this gist Jul 27, 2023. 1 changed file with 11 additions and 4 deletions.
    15 changes: 11 additions & 4 deletions fake_transformer.py
@@ -1,4 +1,4 @@
-from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer
+from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer, BatchEncoding
 from transformers.modeling_outputs import Seq2SeqLMOutput
 import torch
@@ -59,14 +59,18 @@ def forward(
         attentions = () if output_attentions else None
         hidden_states = () if output_hidden_states else None

+        output = output.log()
+        #replace -inf with -1e9
+        #output[output == float('-inf')] = -1e9
+
         # Create Seq2SeqLMOutput object
         if not return_dict:
             att, hidd = attentions if attentions is not None else (), hidden_states if hidden_states is not None else ()
             return (output,) + past_key_values + hidd + att
         else:
             return Seq2SeqLMOutput(
                 loss=None,
-                logits=output.log(),
+                logits=output,
                 past_key_values=past_key_values,
                 decoder_hidden_states=hidden_states,
                 decoder_attentions=attentions,
@@ -77,14 +81,17 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}

+    def _reorder_cache(self, past, beam_idx):
+        return past
+

 class FakeTokenizer(PreTrainedTokenizer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def __call__(self, text, **kwargs):
-        return {"input_ids": [0]}
+        return BatchEncoding({"input_ids": torch.tensor([[0],[0]])})

     def batch_decode(self, token_ids, **kwargs):
         return [str(ids) for ids in token_ids]
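
With this latest revision applied, the gist's model and tokenizer can be exercised roughly as follows. This is a minimal sketch, not part of the gist itself: it assumes the file is saved as fake_transformer.py and a mid-2023 transformers release, and it relies on FakeTransformer() falling back to a default FakeTransformerConfig (vocab_size=4, bos=0, eos=3) while FakeTokenizer ignores its input text.

import torch
from fake_transformer import FakeTransformer, FakeTokenizer

model = FakeTransformer()      # default FakeTransformerConfig()
tokenizer = FakeTokenizer()    # stub: always returns input_ids [[0], [0]]

batch = tokenizer("any text")  # BatchEncoding holding a batch of two BOS tokens
out = model.generate(
    batch["input_ids"],
    max_length=4,
    do_sample=True,
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.batch_decode(out.sequences))  # stringified token-id tensors
print(torch.exp(out.scores[0][0, :]))         # step-1 distribution, roughly [0, 0.75, 0.25, 0]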
  2. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion fake_transformer.py
@@ -1,4 +1,4 @@
-from transformers import PreTrainedModel, PretrainedConfig
+from transformers import PreTrainedModel, PretrainedConfig, PreTrainedTokenizer
 from transformers.modeling_outputs import Seq2SeqLMOutput
 import torch
@@ -77,6 +77,17 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}

+
+class FakeTokenizer(PreTrainedTokenizer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, text, **kwargs):
+        return {"input_ids": [0]}
+
+    def batch_decode(self, token_ids, **kwargs):
+        return [str(ids) for ids in token_ids]
+
 if __name__ == "__main__":
  3. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 7 additions and 9 deletions.
    16 changes: 7 additions & 9 deletions fake_transformer.py
@@ -10,6 +10,8 @@ def __init__(self, vocab_size=4, **kwargs):
         super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
         self.vocab_size = vocab_size
         self.max_new_tokens = 4
+        self.do_sample = True
+        self.num_beams = 1


 class FakeTransformer(PreTrainedModel):
@@ -76,18 +78,14 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}

-def main():
-
+if __name__ == "__main__":
     config = FakeTransformerConfig()
     model = FakeTransformer(config)

     input_ids = torch.tensor([[0]] * 10)
     output = model.generate(input_ids, max_length=4, do_sample=True, return_dict_in_generate=True, output_scores=True)
     print(f"generated tokens: {output.sequences}")
-    print(f"probs at step 1: {torch.exp(output.scores[0][0,:])}")
-    print(f"probs at step 2: {torch.exp(output.scores[1][0,:])}")
-    print(f"probs at step 3: {torch.exp(output.scores[2][0,:])}")
-
-
-if __name__ == "__main__":
-    main()
+    print(f"probs at step 1: {torch.exp(output.scores[0][0, :])}")
+    print(f"probs at step 2: {torch.exp(output.scores[1][0, :])}")
+    print(f"probs at step 3: {torch.exp(output.scores[2][0, :])}")
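
The do_sample, num_beams, and max_new_tokens attributes stored on the config are the values transformers copies into the model's default GenerationConfig, so after this revision generate() can in principle be called without any sampling flags. A hedged sketch, assuming a mid-2023 transformers release and the gist saved as fake_transformer.py:

import torch
from fake_transformer import FakeTransformer, FakeTransformerConfig

model = FakeTransformer(FakeTransformerConfig())
# No do_sample/max_length arguments: the defaults come from the model config.
out = model.generate(torch.tensor([[0]]), return_dict_in_generate=True, output_scores=True)
print(out.sequences)  # a sampled sequence such as tensor([[0, 1, 2, 3]]); eos (3) ends generation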
  4. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions fake_transformer.py
@@ -9,6 +9,7 @@ class FakeTransformerConfig(PretrainedConfig):
     def __init__(self, vocab_size=4, **kwargs):
         super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
         self.vocab_size = vocab_size
+        self.max_new_tokens = 4


 class FakeTransformer(PreTrainedModel):
  5. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion fake_transformer.py
@@ -14,7 +14,7 @@ def __init__(self, vocab_size=4, **kwargs):
 class FakeTransformer(PreTrainedModel):
     config_class = FakeTransformerConfig

-    def __init__(self, config):
+    def __init__(self, config=None):
         config = FakeTransformerConfig() if config is None else config
         super().__init__(config)
         self.fake_param = torch.nn.Parameter(torch.tensor(1.0)) # need at least one parameter to be a valid model
  6. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions fake_transformer.py
@@ -15,6 +15,7 @@ class FakeTransformer(PreTrainedModel):
     config_class = FakeTransformerConfig

     def __init__(self, config):
+        config = FakeTransformerConfig() if config is None else config
         super().__init__(config)
         self.fake_param = torch.nn.Parameter(torch.tensor(1.0)) # need at least one parameter to be a valid model

  7. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion fake_transformer.py
@@ -75,7 +75,7 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}


-def main():
+def main():
     config = FakeTransformerConfig()
     model = FakeTransformer(config)

  8. manueldeprada revised this gist Jul 17, 2023. 1 changed file with 6 additions and 2 deletions.
    8 changes: 6 additions & 2 deletions fake_transformer.py
@@ -74,8 +74,8 @@ def forward(
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}


-if __name__ == "__main__":
+def main():
     config = FakeTransformerConfig()
     model = FakeTransformer(config)

@@ -85,3 +85,7 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
     print(f"probs at step 1: {torch.exp(output.scores[0][0,:])}")
     print(f"probs at step 2: {torch.exp(output.scores[1][0,:])}")
     print(f"probs at step 3: {torch.exp(output.scores[2][0,:])}")
+
+
+if __name__ == "__main__":
+    main()
  9. manueldeprada created this gist Jul 16, 2023.
    87 changes: 87 additions & 0 deletions fake_transformer.py
@@ -0,0 +1,87 @@
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput
import torch


class FakeTransformerConfig(PretrainedConfig):
    model_type = "FakeTransformer"

    def __init__(self, vocab_size=4, **kwargs):
        super().__init__(pad_token_id=-1, eos_token_id=3, bos_token_id=0, **kwargs)
        self.vocab_size = vocab_size


class FakeTransformer(PreTrainedModel):
    config_class = FakeTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.fake_param = torch.nn.Parameter(torch.tensor(1.0)) # need at least one parameter to be a valid model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        return_dict=True,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        seq_len = input_ids.shape[1]
        batch_size = input_ids.shape[0]

        # Placeholder for output probabilities
        output = torch.zeros(batch_size, seq_len, self.config.vocab_size)

        if seq_len >= 1:
            output[:, 0, 1] = 0.75
            output[:, 0, 2] = 0.25

        if seq_len >= 2:
            output[:, 1, 1] = 0.10
            output[:, 1, 2] = 0.90

        if seq_len >= 3:
            output[:, 2, :] = 0.0
            output[:, 2, 3] = 1.0

        if not past_key_values: # when using past_key_values, only last token logits are generated
            output = output[:, -1, :].unsqueeze(1)

        # Placeholder for past_key_values, attentions, hidden_states
        past_key_values = () if past_key_values is None else past_key_values
        attentions = () if output_attentions else None
        hidden_states = () if output_hidden_states else None

        # Create Seq2SeqLMOutput object
        if not return_dict:
            att, hidd = attentions if attentions is not None else (), hidden_states if hidden_states is not None else ()
            return (output,) + past_key_values + hidd + att
        else:
            return Seq2SeqLMOutput(
                loss=None,
                logits=output.log(),
                past_key_values=past_key_values,
                decoder_hidden_states=hidden_states,
                decoder_attentions=attentions,
                encoder_last_hidden_state=None,
                encoder_hidden_states=None,
                encoder_attentions=None,
            )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}


if __name__ == "__main__":
    config = FakeTransformerConfig()
    model = FakeTransformer(config)

    input_ids = torch.tensor([[0]] * 10)
    output = model.generate(input_ids, max_length=4, do_sample=True, return_dict_in_generate=True, output_scores=True)
    print(f"generated tokens: {output.sequences}")
    print(f"probs at step 1: {torch.exp(output.scores[0][0,:])}")
    print(f"probs at step 2: {torch.exp(output.scores[1][0,:])}")
    print(f"probs at step 3: {torch.exp(output.scores[2][0,:])}")
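
As a sanity check on the hard-coded distribution above (a standalone sketch, not part of the gist): forward() only looks at the current sequence length, so the per-step probabilities do not depend on which token was actually sampled, and the four possible sequences can be enumerated directly.

# Per-step distributions as written in forward():
# step 1 -> P(1)=0.75, P(2)=0.25; step 2 -> P(1)=0.10, P(2)=0.90; step 3 -> P(3)=1.0 (eos).
step1 = {1: 0.75, 2: 0.25}
step2 = {1: 0.10, 2: 0.90}
eos = 3

total = 0.0
for t1, p1 in step1.items():
    for t2, p2 in step2.items():
        p = p1 * p2
        total += p
        print(f"sequence [0, {t1}, {t2}, {eos}]: p = {p:.3f}")
print(f"total probability: {total:.3f}")  # 0.075 + 0.675 + 0.025 + 0.225 = 1.000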