
gpt2_xl_perplexities.py
@thesephist · Created September 4, 2023 20:23
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Use the larger model only when a GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ppl_model_name = 'gpt2-xl' if device == 'cuda' else 'gpt2'
ppl_tokenizer = GPT2Tokenizer.from_pretrained(ppl_model_name)
load_opts = {
    'device_map': 'auto',
    'torch_dtype': torch.float16,
} if torch.cuda.is_available() else {}
ppl_model = GPT2LMHeadModel.from_pretrained(ppl_model_name, **load_opts)
if not load_opts:
    # device_map='auto' already places the model; only move it manually on CPU.
    ppl_model = ppl_model.to(device)

def perplexities(text: str, stride: int = 128):
    tokenizer, model = ppl_tokenizer, ppl_model

    def tokenize(text: str) -> torch.LongTensor:
        return tokenizer(tokenizer.bos_token + text, return_tensors='pt').input_ids[0].to(device)

    def token_list(tokens: torch.LongTensor) -> List[str]:
        # Decode each token id to its string form, one entry per token.
        return tokenizer.batch_decode(tokens.unsqueeze(1))

    max_length = model.config.n_positions
    input_ids = tokenize(text).to(device).unsqueeze(0)
    seq_len = input_ids.size(1)
    top_k = 10

    tokens = []
    # Slide a window of up to max_length tokens across the sequence, advancing
    # by `stride` tokens so long texts still fit in the model's context window.
    for begin_loc in range(0, max(1, seq_len - max_length + stride), stride):
        end_loc = min(begin_loc + max_length, seq_len - 1)
        span_input_ids = input_ids[:, begin_loc:end_loc]
        # Targets are the inputs shifted one position to the left.
        target_ids = input_ids[:, begin_loc+1:end_loc+1]

        with torch.no_grad():
            outputs = model(span_input_ids, labels=target_ids)
        logits = outputs.logits
        log_probs = F.log_softmax(logits, dim=-1)
        probs = F.softmax(logits, dim=-1)
        # Probability the model assigned to each actual next token.
        target_log_probs = log_probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
        target_probs = probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
        # The model's top-k most likely next tokens at each position.
        greedy_log_probs, greedy_tokens = log_probs.topk(top_k, dim=2)
        greedy_probs = torch.exp(greedy_log_probs)

        # Skip tokens that overlap with the previous window so each token is
        # reported exactly once.
        for tok, predicted_toks, log_prob, prob in list(zip(
            token_list(target_ids[0]),
            [
                zip(topk_log_probs, topk_probs, token_list(topk_tokens))
                for topk_log_probs, topk_probs, topk_tokens
                in zip(
                    greedy_log_probs[0].tolist(),
                    greedy_probs[0].tolist(),
                    greedy_tokens[0],
                )
            ],
            target_log_probs[0].tolist(),
            target_probs[0].tolist(),
        ))[max_length - stride if begin_loc > 0 else 0:]:
            tokens.append({
                'token': tok,
                'predicted_tokens': [{
                    'token': predicted_tok,
                    'log_prob': predicted_log_prob,
                    'prob': predicted_prob,
                } for predicted_log_prob, predicted_prob, predicted_tok in predicted_toks],
                'log_prob': log_prob,
                'prob': prob,
            })
    return tokens
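
# Example usage — a minimal sketch. The sample text, the __main__ guard, and the
# printed summary are illustrative assumptions, not part of the original gist.
if __name__ == '__main__':
    import math

    sample_text = 'The quick brown fox jumps over the lazy dog.'
    results = perplexities(sample_text)
    for entry in results:
        top = entry['predicted_tokens'][0]
        print(f"{entry['token']!r}\tp={entry['prob']:.4f}\t"
              f"top-1: {top['token']!r} (p={top['prob']:.4f})")
    # Overall perplexity is the exponential of the negative mean token log-probability.
    ppl = math.exp(-sum(e['log_prob'] for e in results) / len(results))
    print(f'perplexity: {ppl:.2f}')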