import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Pick the larger model only when a GPU is available; fall back to base GPT-2 on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ppl_model_name = 'gpt2-xl' if device == 'cuda' else 'gpt2'
ppl_tokenizer = GPT2Tokenizer.from_pretrained(ppl_model_name)
load_opts = {
    'device_map': 'auto',
    'torch_dtype': torch.float16,
} if torch.cuda.is_available() else {}
ppl_model = GPT2LMHeadModel.from_pretrained(ppl_model_name, **load_opts)
if not load_opts:
    # device_map='auto' already places the model on the GPU; only move it manually on CPU.
    ppl_model = ppl_model.to(device)
ppl_model.eval()


def perplexities(text: str, stride: int = 128):
    tokenizer, model = ppl_tokenizer, ppl_model

    def tokenize(text: str) -> torch.LongTensor:
        # Prepend BOS so the first real token also gets a prediction.
        return tokenizer(tokenizer.bos_token + text, return_tensors='pt').input_ids[0].to(device)

    def token_list(tokens: torch.LongTensor) -> List[str]:
        # Decode each token id individually to preserve token boundaries.
        return tokenizer.batch_decode(tokens.unsqueeze(1))

    max_length = model.config.n_positions
    input_ids = tokenize(text).unsqueeze(0)
    seq_len = input_ids.size(1)
    top_k = 10

    tokens = []
    # Slide a window of at most max_length tokens over the sequence, advancing by
    # `stride`, so texts longer than the model's context window are still covered.
    for begin_loc in range(0, max(1, seq_len - max_length + stride), stride):
        end_loc = min(begin_loc + max_length, seq_len - 1)
        span_input_ids = input_ids[:, begin_loc:end_loc]
        # Targets are the inputs shifted one position to the left.
        target_ids = input_ids[:, begin_loc + 1:end_loc + 1]

        with torch.no_grad():
            outputs = model(span_input_ids)
            logits = outputs.logits
            log_probs = F.log_softmax(logits, dim=-1)
            probs = F.softmax(logits, dim=-1)

            # Probability the model assigned to the token that actually came next.
            target_log_probs = log_probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)
            target_probs = probs.gather(2, target_ids.unsqueeze(2)).squeeze(2)

            # The model's own top-k candidates at each position.
            greedy_log_probs, greedy_tokens = log_probs.topk(top_k, dim=2)
            greedy_probs = torch.exp(greedy_log_probs)

        # On overlapping windows, skip tokens already emitted by the previous window.
        skip = max_length - stride if begin_loc > 0 else 0
        for tok, predicted_toks, log_prob, prob in list(zip(
            token_list(target_ids[0]),
            [
                zip(topk_log_probs, topk_probs, token_list(topk_tokens))
                for topk_log_probs, topk_probs, topk_tokens in zip(
                    greedy_log_probs[0].tolist(),
                    greedy_probs[0].tolist(),
                    greedy_tokens[0],
                )
            ],
            target_log_probs[0].tolist(),
            target_probs[0].tolist(),
        ))[skip:]:
            tokens.append({
                'token': tok,
                'predicted_tokens': [{
                    'token': pred_tok,
                    'log_prob': pred_log_prob,
                    'prob': pred_prob,
                } for pred_log_prob, pred_prob, pred_tok in predicted_toks],
                'log_prob': log_prob,
                'prob': prob,
            })

    return tokens
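
A minimal usage sketch of the function above: the sample sentence is only illustrative (not from the original), and each returned entry pairs an observed token with the probability the model assigned to it plus the model's top-10 candidates at that position. Since perplexity is the exponential of the negative mean per-token log-probability, the entries can also be folded into a single score.

import math

entries = perplexities("The quick brown fox jumps over the lazy dog.")
for entry in entries[:5]:
    best = entry['predicted_tokens'][0]
    print(f"{entry['token']!r}  p={entry['prob']:.4f}  "
          f"top-1={best['token']!r} (p={best['prob']:.4f})")

# Sequence-level perplexity from the per-token log-probs.
overall_ppl = math.exp(-sum(e['log_prob'] for e in entries) / len(entries))
print(f"perplexity: {overall_ppl:.2f}")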