import torch
import torch.nn.functional as F
from typing import List
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# The gist assumed `device` was defined elsewhere; define it here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Only load the larger model when a GPU is available.
ppl_model_name = 'gpt2-xl' if device == 'cuda' else 'gpt2'
ppl_tokenizer = GPT2Tokenizer.from_pretrained(ppl_model_name)
load_opts = {
    'device_map': 'auto',
    'torch_dtype': torch.float16,
} if torch.cuda.is_available() else {}
# With device_map='auto', accelerate places the model itself, and calling
# .to() on such a model errors, so only move the model explicitly on CPU.
ppl_model = GPT2LMHeadModel.from_pretrained(ppl_model_name, **load_opts)
if not load_opts:
    ppl_model = ppl_model.to(device)


def perplexities(text: str, stride: int = 128):
    tokenizer, model = ppl_tokenizer, ppl_model

    def tokenize(text: str) -> torch.LongTensor:
        # Prepend BOS so the first real token also gets a prediction.
        return tokenizer(tokenizer.bos_token + text, return_tensors='pt').input_ids[0].to(device)

    def token_list(tokens: torch.LongTensor) -> List[str]:
        # Decode each token id to its string form, one id per entry.
        return tokenizer.batch_decode(tokens.unsqueeze(1))

    max_length = model.config.n_positions
    input_ids = tokenize(text).unsqueeze(0)
    seq_len = input_ids.size(1)
    top_k = 10

    tokens = []
    # Slide a window of up to max_length tokens across the sequence, advancing
    # by `stride` so tokens in long texts keep a long left context.
    for begin_loc in range(0, max(1, seq_len - max_length + stride), stride):
        end_loc = min(begin_loc + max_length, seq_len - 1)
        span_input_ids = input_ids[:, begin_loc:end_loc]
        # Targets are the inputs shifted left by one position.
        target_ids = input_ids[:, begin_loc + 1:end_loc + 1]

        with torch.no_grad():
            # Logits are aligned to targets manually below, so there is no
            # need to pass `labels` and have the model compute a loss.
            outputs = model(span_input_ids)
            logits = outputs.logits
            log_probs = F.log_softmax(logits, dim=-1)
            probs = F.softmax(logits, dim=-1)
            # Probability the model assigned to each actual next token.
            target_log_probs = log_probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
            target_probs = probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
            # The model's top-k alternative predictions at every position.
            greedy_log_probs, greedy_tokens = log_probs.topk(top_k, dim=2)
            greedy_probs = torch.exp(greedy_log_probs)

        # After the first window, skip the prefix that overlaps the previous
        # window so every token is emitted exactly once.
        skip = max_length - stride if begin_loc > 0 else 0
        for tok, predicted_toks, log_prob, prob in list(zip(
            token_list(target_ids[0]),
            [
                zip(topk_log_probs, topk_probs, token_list(topk_tokens))
                for topk_log_probs, topk_probs, topk_tokens in zip(
                    greedy_log_probs[0].tolist(),
                    greedy_probs[0].tolist(),
                    greedy_tokens[0],
                )
            ],
            target_log_probs[0].tolist(),
            target_probs[0].tolist(),
        ))[skip:]:
            tokens.append({
                'token': tok,
                'predicted_tokens': [{
                    'token': predicted_tok,
                    'log_prob': predicted_log_prob,
                    'prob': predicted_prob,
                } for predicted_log_prob, predicted_prob, predicted_tok in predicted_toks],
                'log_prob': log_prob,
                'prob': prob,
            })

    return tokens
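

# A minimal usage sketch, not part of the original gist: the sample sentence
# and the printed report format are illustrative assumptions. Each entry pairs
# the actual token (with the probability the model gave it) against the
# model's own top-1 prediction at that position.
if __name__ == '__main__':
    for entry in perplexities('The quick brown fox jumps over the lazy dog.'):
        top = entry['predicted_tokens'][0]
        print(f"{entry['token']!r}\tp={entry['prob']:.4f}\t"
              f"top-1 prediction: {top['token']!r} (p={top['prob']:.4f})")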