
@yuchenlin
Last active September 29, 2021 04:51

Revisions

  1. yuchenlin revised this gist Sep 29, 2021. 1 changed file with 6 additions and 12 deletions.
    18 changes: 6 additions & 12 deletions encoding.py
    @@ -27,6 +27,7 @@ def __init__(self, logger):
             self.initial_memory_path = ""
             self.data_args = None
             self.dim_vector = 2*768
    +        self.memory_index_sorted_ids = []

         def set_up_data_args(self, args):
             self.data_args = Namespace(
    @@ -58,25 +59,18 @@ def set_up_initial_memory(self):
                 self.memory_examples[item[2]] = (item[0], item[1][0:1], item[2])
             self.logger.info(f"Set up the initial memory with {len(self.memory_examples)} examples.")
             initial_memory_example_ids = sorted(list(self.memory_examples.keys()))
    -        self.memory_index_sorted_ids = initial_memory_example_ids
             vectors = self.get_representation(initial_memory_example_ids)
             self.update_index(initial_memory_example_ids, vectors)


         def update_index(self, example_ids, vectors):
             assert len(example_ids) == len(vectors)
             if not self.memory_index:
    -            self.init_index(vectors)
    -        else:
    -            # TODO: update faiss index?
    -            pass
    -
    -
    -    def init_index(self, vectors):
    -        self.memory_index = faiss.IndexFlatL2(self.dim_vector)
    +            self.memory_index = faiss.IndexFlatL2(self.dim_vector)
    +        self.memory_index_sorted_ids += example_ids
             vectors = np.array(vectors)
             self.memory_index.add(vectors)


         def query_examples(self, query_vectors, k=5):
             """
    @@ -102,8 +96,8 @@ def get_representation(self, example_ids):
             examples = self.get_examples_by_ids(example_ids)

             data_manager, _ = self.cl_utils.get_dataloader(self.data_args, examples, mode="train", is_training=False)
    -        all_vectors = []
    -        bart_model = self.bart_model.module
    +        all_vectors = []
    +        bart_model = self.bart_model if self.cl_utils.n_gpu == 1 else self.bart_model.module
             bart_model.eval()
             for batch in tqdm(data_manager.dataloader):
                 # self.logger.info(f"len(batch)={len(batch)}")
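
    After this revision, update_index both creates the FAISS index on first use and appends to it on later calls, while memory_index_sorted_ids records example ids in insertion order so that row i of the index can be mapped back to an example. A minimal standalone sketch of that behavior (illustrative names and random vectors; sorted_ids stands in for memory_index_sorted_ids; assumes only numpy and faiss):

        import faiss
        import numpy as np

        dim = 2 * 768                   # concatenated encoder+decoder vector size, as in IndexManager
        index = faiss.IndexFlatL2(dim)  # exact L2 index; vectors can be added with no training step
        sorted_ids = []                 # stand-in for memory_index_sorted_ids

        def update_index(example_ids, vectors):
            assert len(example_ids) == len(vectors)
            sorted_ids.extend(example_ids)                   # row i of the index <-> sorted_ids[i]
            index.add(np.asarray(vectors, dtype="float32"))  # faiss expects float32 (n, dim) arrays

        update_index(["ex-1", "ex-2"], np.random.rand(2, dim))
        update_index(["ex-3"], np.random.rand(1, dim))
        print(index.ntotal)  # 3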
  2. yuchenlin revised this gist Sep 29, 2021. 1 changed file with 183 additions and 57 deletions.
    240 changes: 183 additions & 57 deletions encoding.py
    @@ -1,63 +1,189 @@
    -def get_representation(self, example_ids):
    -    examples = self.get_examples_by_ids(example_ids)
    -
    -    # Note that we need to make sure this is a SequentialSampler instead of a RandomSampler
    -    data_manager, _ = self.cl_utils.get_dataloader(self.data_args, examples, mode="train", is_training=False)
    -    all_tensors = []
    -    for batch in tqdm(data_manager.dataloader):
    -        # self.logger.info(f"len(batch)={len(batch)}")
    -        if self.cl_utils.use_cuda:
    -            # print(type(batch[0]), batch[0])
    -            batch = [b.to(torch.device("cuda")) for b in batch]
    -        pad_token_id = self.tokenizer.pad_token_id
    -        batch[0], batch[1] = trim_batch(
    -            batch[0], pad_token_id, batch[1])
    -        batch[2], batch[3] = trim_batch(
    -            batch[2], pad_token_id, batch[3])
    -
    -        ## Encode the input text with BART-encoder
    -        input_ids = batch[0]
    -        input_attention_mask = batch[1]
    -        encoder_outputs = self.bart_model.model.encoder(
    -            input_ids, input_attention_mask)
    -        x = encoder_outputs[0]
    -        x = x[:, 0, :]
    -        input_vectors = x.detach().cpu().numpy()
    -
    -        self.logger.info(f"input_vectors.shape = {input_vectors.shape}")
    -
    -        ## Encode the output text with BART-decoder
    -
    -        output_ids = batch[2]
    -        output_attention_mask = batch[3]
    -
    -        decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
    -            self.bart_model.model.config,
    -            input_ids,
    -            decoder_input_ids=output_ids,
    -            decoder_padding_mask=output_attention_mask,
    -            causal_mask_dtype=self.bart_model.model.shared.weight.dtype,
    -        )
    -        decoder_outputs = self.bart_model.model.decoder(
    -            decoder_input_ids,
    -            encoder_outputs[0],
    -            input_attention_mask,
    -            decoder_padding_mask,
    -            decoder_causal_mask=causal_mask,
    -            decoder_cached_states=None,
    -            use_cache=False
    +from argparse import Namespace
    +
    +from numpy.core.defchararray import index
    +from semanticdebugger.debug_algs.cl_simple_alg import ContinualFinetuning
    +from tqdm import tqdm
    +import torch
    +from semanticdebugger.models.utils import trim_batch
    +import json
    +from semanticdebugger.debug_algs import run_lifelong_finetune
    +
    +from transformers.modeling_bart import _prepare_bart_decoder_inputs
    +import numpy as np
    +import faiss
    +
    +
    +class IndexManager():
    +
    +    def __init__(self, logger):
    +        super().__init__()
    +        self.logger = logger
    +        self.name = "index_manager"
    +        self.memory_index = None
    +        self.memory_examples = {}
    +        self.bart_model = None
    +        self.tokenizer = None
    +        self.cl_utils = ContinualFinetuning(logger=logger)
    +        self.initial_memory_path = ""
    +        self.data_args = None
    +        self.dim_vector = 2*768
    +
    +    def set_up_data_args(self, args):
    +        self.data_args = Namespace(
    +            bug_stream_json_path="",
    +            pass_pool_jsonl_path="",
    +            sampled_upstream_json_path="",
    +            # pass_sample_size=args.pass_sample_size,
    +            do_lowercase=args.do_lowercase,
    +            append_another_bos=args.append_another_bos,
    +            max_input_length=args.max_input_length,
    +            max_output_length=args.max_output_length,
    +            task_name=args.task_name,
    +            train_batch_size=args.train_batch_size,
    +            predict_batch_size=args.predict_batch_size,
    +            num_beams=args.num_beams,
    +            max_timecode=args.max_timecode,
    +            accumulate_eval_freq=-1,
    +            use_sampled_upstream=False,
             )
    -        y = decoder_outputs[0]
    -        y = y[:, 0, :]
    -        output_vectors = y.detach().cpu().numpy()
    -
    -        self.logger.info(f"output_vectors.shape = {output_vectors.shape}")
    +
    +    def set_up_initial_memory(self):
    +        assert self.bart_model is not None
    +
    +        with open(self.initial_memory_path) as f:
    +            initial_memory_examples = [json.loads(line) for line in set(f.read().splitlines())][:500]  # TODO: only for debugging
    +        initial_memory_examples = self.cl_utils.upstream_data_formatter(initial_memory_examples)
    +        for item in initial_memory_examples:
    +            # Note that we only keep the first answer here now.
    +            self.memory_examples[item[2]] = (item[0], item[1][0:1], item[2])
    +        self.logger.info(f"Set up the initial memory with {len(self.memory_examples)} examples.")
    +        initial_memory_example_ids = sorted(list(self.memory_examples.keys()))
    +        self.memory_index_sorted_ids = initial_memory_example_ids
    +        vectors = self.get_representation(initial_memory_example_ids)
    +        self.update_index(initial_memory_example_ids, vectors)
    +
    +
    +    def update_index(self, example_ids, vectors):
    +        assert len(example_ids) == len(vectors)
    +        if not self.memory_index:
    +            self.init_index(vectors)
    +        else:
    +            # TODO: update faiss index?
    +            pass
    +
    +
    +    def init_index(self, vectors):
    +        self.memory_index = faiss.IndexFlatL2(self.dim_vector)
    +        vectors = np.array(vectors)
    +        self.memory_index.add(vectors)

    -        # Concatenate
    -        vectors = np.concatenate([input_vectors, output_vectors], axis=1)
    +    def query_examples(self, query_vectors, k=5):
    +        """
    +        Returns samples from buffer using K-nearest neighbour approach
    +        """
    +        retrieved_examples = []
    +        for query_vector in query_vectors:
    +            D, I = self.memory_index.search(np.array([query_vector]), k)
    +            retrieved_examples.append([self.memory_index_sorted_ids[int(eid)] for eid in I[0]])
    +        return retrieved_examples
    +
    +    def set_up_model(self, model, tokenizer):
    +        self.bart_model = model
    +        self.tokenizer = tokenizer
    +
    +    def get_examples_by_ids(self, example_ids):
    +        return [self.memory_examples[eid] for eid in example_ids]
    +
    +
    +    def get_representation(self, example_ids):
    +        examples = self.get_examples_by_ids(example_ids)
    +
    +        data_manager, _ = self.cl_utils.get_dataloader(self.data_args, examples, mode="train", is_training=False)
    +        all_vectors = []
    +        bart_model = self.bart_model.module
    +        bart_model.eval()
    +        for batch in tqdm(data_manager.dataloader):
    +            # self.logger.info(f"len(batch)={len(batch)}")
    +            if self.cl_utils.use_cuda:
    +                # print(type(batch[0]), batch[0])
    +                batch = [b.to(torch.device("cuda")) for b in batch]
    +            pad_token_id = self.tokenizer.pad_token_id
    +            batch[0], batch[1] = trim_batch(
    +                batch[0], pad_token_id, batch[1])
    +            batch[2], batch[3] = trim_batch(
    +                batch[2], pad_token_id, batch[3])
    +
    +            ## Encode the input text with BART-encoder
    +            input_ids = batch[0]
    +            input_attention_mask = batch[1]
    +            encoder_outputs = bart_model.model.encoder(
    +                input_ids, input_attention_mask)
    +            x = encoder_outputs[0]
    +            x = x[:, 0, :]
    +            input_vectors = x.detach().cpu().numpy()
    +
    +            # self.logger.info(f"input_vectors.shape = {input_vectors.shape}")
    +
    +            ## Encode the output text with BART-decoder
    +
    +            output_ids = batch[2]
    +            output_attention_mask = batch[3]
    +
    +            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
    +                bart_model.model.config,
    +                input_ids,
    +                decoder_input_ids=output_ids,
    +                decoder_padding_mask=output_attention_mask,
    +                causal_mask_dtype=bart_model.model.shared.weight.dtype,
    +            )
    +            decoder_outputs = bart_model.model.decoder(
    +                decoder_input_ids,
    +                encoder_outputs[0],
    +                input_attention_mask,
    +                decoder_padding_mask,
    +                decoder_causal_mask=causal_mask,
    +                decoder_cached_states=None,
    +                use_cache=False
    +            )
    +            y = decoder_outputs[0]
    +            y = y[:, 0, :]
    +            output_vectors = y.detach().cpu().numpy()
    +
    +            del batch
    +            del encoder_outputs
    +            del decoder_outputs
    +            # self.logger.info(f"output_vectors.shape = {output_vectors.shape}")
    +
    +            # concatenate the vectors
    +            vectors = np.concatenate([input_vectors, output_vectors], axis=1)
    +
    +            # self.logger.info(f"vectors.shape = {vectors.shape}")
    +            all_vectors += list(vectors)
    +        return all_vectors
    +
    +
    +if __name__ == '__main__':
    +    parser = run_lifelong_finetune.get_cli_parser()
    +    args = parser.parse_args()

    -        self.logger.info(f"vectors.shape = {vectors.shape}")
    -        all_tensors += list(vectors)
    +    debugging_alg, data_args, base_model_args, debugger_args, logger = run_lifelong_finetune.setup_args(
    +        args)
    +    args.predict_batch_size = 8
    +    index_manager = IndexManager(logger=logger)
    +    index_manager.set_up_data_args(args)
    +    index_manager.cl_utils.load_base_model(base_model_args)
    +    index_manager.set_up_model(model=index_manager.cl_utils.base_model, tokenizer=index_manager.cl_utils.tokenizer)
    +    index_manager.initial_memory_path = "exp_results/data_streams/mrqa.nq_train.memory.jsonl"
    +
    +    index_manager.set_up_initial_memory()

    -    return all_tensors
    +    query_ids = list(index_manager.memory_examples.keys())[:3]
    +    tensors = index_manager.get_representation(example_ids=query_ids)
    +    print(query_ids, tensors)
    +    retrieved_ids = index_manager.query_examples(tensors)
    +    print(retrieved_ids)
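
    For reference, the k-NN lookup that query_examples performs can be reproduced in isolation: faiss's index.search returns distances D and row indices I, and the rows are translated back to example ids through the insertion-order id list. A self-contained sketch with made-up ids and random vectors:

        import faiss
        import numpy as np

        dim = 2 * 768
        index = faiss.IndexFlatL2(dim)
        sorted_ids = ["ex-a", "ex-b", "ex-c"]  # insertion order == row order in the index
        index.add(np.random.rand(3, dim).astype("float32"))

        query = np.random.rand(1, dim).astype("float32")
        D, I = index.search(query, 2)              # D and I both have shape (1, 2)
        print([sorted_ids[int(i)] for i in I[0]])  # e.g. ['ex-b', 'ex-a']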
  3. yuchenlin created this gist Sep 29, 2021.
    63 changes: 63 additions & 0 deletions encoding.py
    @@ -0,0 +1,63 @@
    def get_representation(self, example_ids):
        examples = self.get_examples_by_ids(example_ids)

        # Note that we need to make sure this is a SequentialSampler instead of a RandomSampler
        data_manager, _ = self.cl_utils.get_dataloader(self.data_args, examples, mode="train", is_training=False)
        all_tensors = []
        for batch in tqdm(data_manager.dataloader):
            # self.logger.info(f"len(batch)={len(batch)}")
            if self.cl_utils.use_cuda:
                # print(type(batch[0]), batch[0])
                batch = [b.to(torch.device("cuda")) for b in batch]
            pad_token_id = self.tokenizer.pad_token_id
            batch[0], batch[1] = trim_batch(
                batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(
                batch[2], pad_token_id, batch[3])

            ## Encode the input text with BART-encoder
            input_ids = batch[0]
            input_attention_mask = batch[1]
            encoder_outputs = self.bart_model.model.encoder(
                input_ids, input_attention_mask)
            x = encoder_outputs[0]
            x = x[:, 0, :]
            input_vectors = x.detach().cpu().numpy()

            self.logger.info(f"input_vectors.shape = {input_vectors.shape}")

            ## Encode the output text with BART-decoder

            output_ids = batch[2]
            output_attention_mask = batch[3]

            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.bart_model.model.config,
                input_ids,
                decoder_input_ids=output_ids,
                decoder_padding_mask=output_attention_mask,
                causal_mask_dtype=self.bart_model.model.shared.weight.dtype,
            )
            decoder_outputs = self.bart_model.model.decoder(
                decoder_input_ids,
                encoder_outputs[0],
                input_attention_mask,
                decoder_padding_mask,
                decoder_causal_mask=causal_mask,
                decoder_cached_states=None,
                use_cache=False
            )
            y = decoder_outputs[0]
            y = y[:, 0, :]
            output_vectors = y.detach().cpu().numpy()

            self.logger.info(f"output_vectors.shape = {output_vectors.shape}")

            # Concatenate
            vectors = np.concatenate([input_vectors, output_vectors], axis=1)

            self.logger.info(f"vectors.shape = {vectors.shape}")
            all_tensors += list(vectors)

        return all_tensors
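
    The representation built above is the encoder's position-0 hidden state concatenated with the decoder's position-0 hidden state, which is why IndexManager (revision 2 above) sets dim_vector = 2*768 for BART-base. A shape-only sketch of the concatenation step with dummy arrays:

        import numpy as np

        batch_size, hidden = 4, 768
        input_vectors = np.random.rand(batch_size, hidden)   # encoder states at position 0
        output_vectors = np.random.rand(batch_size, hidden)  # decoder states at position 0
        vectors = np.concatenate([input_vectors, output_vectors], axis=1)
        assert vectors.shape == (batch_size, 2 * hidden)     # matches dim_vector = 2*768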