import enum
import os
import string
from typing import Dict, List, Mapping, Optional

import marisa_trie
import torch


class GenerationType(enum.Enum):
    """How a class label is rendered as the decoder target text.

    See ``get_gen_type_attributes`` for how each member is interpreted.
    """

    # Raw class name, not shortened.
    ALL_TOKENS = "all"
    # Shortest word prefix of the class name that tokenizes to exactly 3 tokens.
    THREE_TOKENS = "gen3"
    # Same, exactly 2 tokens.
    TWO_TOKENS = "gen2"
    # Same, exactly 1 token.
    ONE_TOKEN = "gen1"
    # One uppercase letter per class ("A", "B", ...) in ``class_names`` order.
    OPTION_ID = "abcd"


def get_model_full_name(
    model_ckpt, gen_type: GenerationType, epochs: int, data_version: str = "0"
) -> str:
    """Build a run name ``<ckpt basename>_<gen type value>_ep<epochs>_dt<data_version>``.

    Only the part of ``model_ckpt`` after its last ``/`` is used.
    """
    model_base_name = model_ckpt.split("/")[-1]
    return f"{model_base_name}_{gen_type.value}_ep{epochs}_dt{data_version}"


def get_outdir(model_name: str) -> str:
    """Return the output directory for ``model_name``."""
    # NOTE: "classifers" is misspelled, but it is kept as-is because renaming it
    # would change where outputs are written/read.
    return os.path.join("./classifers/", model_name)


def get_prediction_name(model_name: str) -> str:
    """Return the name used for ``model_name``'s predictions."""
    return f"{model_name}_predicted"


class MarisaTrie:
    """Prefix trie over token-id sequences, backed by ``marisa_trie.Trie``.

    Every token id is mapped to one Unicode character, so a token sequence
    becomes a string key and prefix queries become string-prefix queries.
    Code points 55000..64999 are skipped in that mapping, presumably to stay
    clear of the UTF-16 surrogate range (0xD800-0xDFFF), which cannot be
    encoded into a key.

    Args:
        sequences: Allowed token-id sequences. Any iterable is accepted; it is
            materialized once. ``None`` means an empty trie.
        cache_fist_branch: If true, precompute the allowed tokens at depth 0
            and depth 1. All sequences must then share the same first token.
        max_token_id: Number of token ids that can be encoded
            (ids ``0 .. max_token_id - 1``).
    """

    def __init__(
        self,
        sequences: Optional[List[List[int]]] = None,
        cache_fist_branch=True,
        max_token_id=256001,
    ):
        # Use ``None`` rather than a shared mutable ``[]`` default. Materialize
        # the input because it is traversed up to three times below; a
        # generator would otherwise be exhausted after the first pass.
        sequences = [] if sequences is None else list(sequences)

        self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
            [chr(i) for i in range(65000, max_token_id + 10000)]
            if max_token_id >= 55000
            else []
        )
        # int2char has exactly max_token_id entries, so this is its inverse.
        self.char2int = {char: i for i, char in enumerate(self.int2char)}

        self.cache_fist_branch = cache_fist_branch
        if self.cache_fist_branch:
            self.zero_iter = list({seq[0] for seq in sequences if len(seq) > 0})
            if len(self.zero_iter) > 1:
                raise ValueError(
                    "cache_fist_branch requires all sequences to share the same "
                    f"first token, got {sorted(self.zero_iter)}"
                )
            self.first_iter = list({seq[1] for seq in sequences if len(seq) > 1})

        self.trie = marisa_trie.Trie(
            "".join(self.int2char[i] for i in sequence) for sequence in sequences
        )

    def get(self, prefix_sequence: List[int]) -> List[int]:
        """Return the distinct token ids that can follow ``prefix_sequence``.

        ``prefix_sequence`` may be a list of ints or any object with a
        ``tolist()`` method (e.g. a 1-D torch tensor).
        """
        if hasattr(prefix_sequence, "tolist"):
            # Compare and index with plain ints rather than 0-d tensors.
            prefix_sequence = prefix_sequence.tolist()

        if self.cache_fist_branch:
            if len(prefix_sequence) == 0:
                return self.zero_iter
            if len(prefix_sequence) == 1 and self.zero_iter == prefix_sequence:
                return self.first_iter

        key = "".join(self.int2char[i] for i in prefix_sequence)
        return list(
            {
                self.char2int[entry[len(key)]]
                for entry in self.trie.keys(key)
                if len(entry) > len(key)
            }
        )

    def __iter__(self):
        """Yield every stored sequence as a list of token ids."""
        for sequence in self.trie.iterkeys():
            yield [self.char2int[char] for char in sequence]

    def __len__(self):
        return len(self.trie)

    def __getitem__(self, value):
        return self.get(value)


def map_class_name(tokenizer, class_raw_name, num_tokens=None, delim=" "):
    """Shorten a class name to the shortest word prefix with ``num_tokens`` tokens.

    ``class_raw_name`` is split on ``delim``. Candidate prefixes of 1, 2, ...
    words are re-joined with single spaces (whatever ``delim`` was) and
    tokenized with ``tokenizer.tokenize``.

    Returns:
        ``class_raw_name`` unchanged when ``num_tokens`` is None. Otherwise,
        the first candidate whose token count equals ``num_tokens``.

    Raises:
        ValueError: if no word prefix has exactly ``num_tokens`` tokens.
    """
    if num_tokens is None:
        return class_raw_name
    class_raw_words = class_raw_name.split(delim)
    for i in range(1, len(class_raw_words) + 1):
        class_name_candidate = " ".join(class_raw_words[:i])
        if len(tokenizer.tokenize(class_name_candidate)) == num_tokens:
            return class_name_candidate
    raise ValueError(
        f"Cannot find class name at the specificed num_tokens: {class_raw_name}, {num_tokens}"
    )


def create_class_text_map(tokenizer, class_raw_names, num_tokens, delim="_"):
    """Map each raw class name to its target text via ``map_class_name``."""
    return {
        raw_name: map_class_name(tokenizer, raw_name, num_tokens, delim=delim)
        for raw_name in class_raw_names
    }


def get_task_prefix(gen_type, class_text_map):
    """Build the instruction text that is prepended to every query.

    For the token-count modes (ALL / THREE / TWO / ONE), the class texts are
    listed quoted and comma-separated, followed by ``". query: "``. For
    OPTION_ID, one ``"<letter>: <name>"`` line is listed per class, followed
    by ``"\\nquery: "``.
    """
    task_prefix = "Classify query intent into one of the following categories: "
    if gen_type in (
        GenerationType.ALL_TOKENS,
        GenerationType.THREE_TOKENS,
        GenerationType.TWO_TOKENS,
        GenerationType.ONE_TOKEN,
    ):
        classes = [f"'{x}'" for x in class_text_map.values()]
        task_prefix += ", ".join(classes)
        task_prefix += ". query: "
    elif gen_type == GenerationType.OPTION_ID:
        # The displayed name drops the last "_"-separated word of the raw class
        # name. NOTE(review): presumably that word is a suffix not meant for
        # the prompt; confirm against the dataset's class names.
        classes = [
            f"{val}: {' '.join(key.split('_')[:-1])}"
            for key, val in class_text_map.items()
        ]
        task_prefix += "\n" + "\n".join(classes)
        task_prefix += "\nquery: "
    return task_prefix


# Generation types that truncate class names to a fixed number of tokens.
_FIXED_TOKEN_COUNTS: Dict[GenerationType, int] = {
    GenerationType.THREE_TOKENS: 3,
    GenerationType.TWO_TOKENS: 2,
    GenerationType.ONE_TOKEN: 1,
}


def get_gen_type_attributes(gen_type, tokenizer, class_names):
    """Derive the label map, decoding length and task prefix for ``gen_type``.

    Returns:
        ``(class_text_map, max_decoding_length, task_prefix)``, where
        ``class_text_map`` maps each raw class name to its target text.

    Raises:
        ValueError: for an unknown ``gen_type``, or for OPTION_ID with more
            classes than there are uppercase letters.
    """
    if gen_type in _FIXED_TOKEN_COUNTS:
        num_tokens = _FIXED_TOKEN_COUNTS[gen_type]
        class_text_map = create_class_text_map(tokenizer, class_names, num_tokens)
        max_decoding_length = num_tokens
    elif gen_type == GenerationType.ALL_TOKENS:
        class_text_map = create_class_text_map(tokenizer, class_names, None)
        # One extra position on top of the longest encoded class text
        # (presumably for the decoder start token).
        max_decoding_length = max(
            len([0] + tokenizer.encode(x)) for x in class_text_map.values()
        )
    elif gen_type == GenerationType.OPTION_ID:
        letters = string.ascii_uppercase
        if len(class_names) > len(letters):
            raise ValueError(
                f"OPTION_ID supports at most {len(letters)} classes, "
                f"got {len(class_names)}"
            )
        class_text_map = {
            raw_name: letters[i] for i, raw_name in enumerate(class_names)
        }
        max_decoding_length = 2
    else:
        raise ValueError(f"Non-existent `gen_type`: {gen_type}")

    task_prefix = get_task_prefix(gen_type, class_text_map)
    return class_text_map, max_decoding_length, task_prefix


def convert_to_features(
    example_batch,
    class_text_map: Mapping[str, str],
    task_prefix: str,
    input_max_length=512,
    label_max_length=16,
    query_key="query",
    label_key="expected_single",
    class_names=None,
    tokenizer=None,
):
    """Tokenize one example into model inputs and, optionally, labels.

    The input text is ``task_prefix + example_batch[query_key]``. The query is
    interpolated into an f-string, so this appears to expect a single
    (non-batched) example.

    When ``label_key`` is truthy, ``example_batch[label_key]`` is used as an
    index into ``class_names``. The resulting name is mapped through
    ``class_text_map`` and tokenized as the target. ``tokenizer`` is required,
    and so is ``class_names`` when labels are produced.

    Side effect: writes ``"input_text"`` (and ``"target_text"``) into
    ``example_batch``.

    Returns:
        Dict with ``inputs``, ``input_ids``, ``attention_mask`` and, when
        labels are produced, ``labels``. All sequences are padded to a fixed
        length.
    """
    q = example_batch[query_key]
    example_batch["input_text"] = f"{task_prefix}{q}"
    input_encodings = tokenizer(
        example_batch["input_text"],
        padding="max_length",
        max_length=input_max_length,
        truncation=True,
    )
    encodings = {
        "inputs": example_batch["input_text"],
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
    }
    if label_key:
        label = class_text_map[class_names[example_batch[label_key]]]
        example_batch["target_text"] = f"{label}"
        target_encodings = tokenizer(
            example_batch["target_text"],
            padding="max_length",
            max_length=label_max_length,
            truncation=True,
        )
        # NOTE(review): padding positions keep the tokenizer's pad id rather
        # than -100, so a loss that ignores only -100 would also train on pads.
        # Confirm this is intended.
        encodings["labels"] = target_encodings["input_ids"]
    return encodings


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak. This is a workaround to avoid storing
    too many tensors that are not needed.
    """
    # Keep only argmax ids of the first element (assumed to be the LM logits;
    # TODO confirm) instead of the full vocabulary-sized logits.
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels


def build_prefix_allowed_tokens_fn(allowed_sequences):
    """Returns a function that provides next allowed tokens based on the prefix `seq`.

    The returned ``fn(batch_id, seq)`` ignores ``batch_id``. ``seq`` may be a
    list of ints or a tensor; ``MarisaTrie.get`` converts it with ``tolist()``.
    """
    t = MarisaTrie(allowed_sequences)

    def fn(unused_batch_id, seq):
        return t.get(seq)

    return fn


def process_golden_labels(example_batch, class_text_map, class_names, label_key="Label"):
    """Attach ``golden_labels``: the mapped texts of the example's gold class(es).

    ``example_batch[label_key]`` indexes into ``class_names``. The resulting
    name is split on ``","``, and each stripped part is mapped through
    ``class_text_map``. ``example_batch`` is modified in place and returned.
    """

    def fn(expected_text):
        return [class_text_map[y.strip()] for y in expected_text.split(",")]

    example_batch["golden_labels"] = fn(class_names[example_batch[label_key]])
    return example_batch


def compute_metrics(preds):
    """Placeholder metrics hook: always returns an empty dict."""
    return {}