Skip to content

Instantly share code, notes, and snippets.

@s3nh
Last active January 24, 2024 18:08
Show Gist options
  • Save s3nh/cfbbf43f5e9e3cfe8c3e4e2f0d550b80 to your computer and use it in GitHub Desktop.
Save s3nh/cfbbf43f5e9e3cfe8c3e4e2f0d550b80 to your computer and use it in GitHub Desktop.

Revisions

  1. s3nh revised this gist Jan 24, 2024. No changes.
  2. s3nh renamed this gist Jan 24, 2024. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  3. s3nh created this gist Jan 24, 2024.
    114 changes: 114 additions & 0 deletions gistfile1.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,114 @@
    import torch
    from chromadb import EmbeddingFunction
    from typing import List, Dict, Union
    from typing import Any, TypeVar


    INSTRUCTIONS = {
    "qa": {
    "query": "Represent this query for retrieving relevant documents: ",
    "key": "Represent this document for retrieval: ",
    },
    "icl": {
    "query": "Convert this example into vector to look for useful examples: ",
    "key": "Convert this example into vector for retrieval: ",
    },
    "chat": {
    "query": "Embed this dialogue to find useful historical dialogues: ",
    "key": "Embed this historical dialogue for retrieval: ",
    },
    "lrlm": {
    "query": "Embed this text chunk for finding useful historical chunks: ",
    "key": "Embed this historical text chunk for retrieval: ",
    },
    "tool": {
    "query": "Transform this user request for fetching helpful tool descriptions: ",
    "key": "Transform this tool description for retrieval: "
    },
    "convsearch": {
    "query": "Encode this query and context for searching relevant passages: ",
    "key": "Encode this passage for retrieval: ",
    },
    }


    class CustomCFG:
    model_name: str = '../assets/llm-embedder'
    local_files_only: bool = True
    max_length: int = 512
    padding: bool = True
    truncation: bool = True
    return_tensors: str = 'pt'
    chunk_size: int = 16
    pad_token: str = "PAD "
    model_half: bool = False
    instruction: str = 'convsearch'

    class CustomEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
    self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    self.tokenizer = self.load_tokenizer()
    self.retriever = self.load_retriever()
    self.starter: str= 'Represent this sentence for searching relevant passages :'

    def __call__(self, texts: Union[str, List]):
    chunk_text = self.batch_chunk(texts)
    embeddings = self.batch_processing(chunk_text)
    return embeddings.tolist()

    def tokenize(self, texts):
    batch_dict = self.tokenizer(texts, #max_length = RetrieverCFG.max_length,
    padding = CustomCFG.padding,
    truncation = CustomCFG.truncation,
    return_tensors = CustomCFG.return_tensors).to(self.device)
    return batch_dict

    def average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def load_retriever(self):
    retriever = AutoModel.from_pretrained(
    pretrained_model_name_or_path = CustomCFG.model_name,
    local_files_only = True,
    trust_remote_code=True
    )
    if CustomCFG.model_half:
    torch.set_default_dtype(torch.half)
    return retriever.to(self.device).half().eval()
    else:
    return retriever.to(self.device).eval()

    def process(self):
    if 'llm-embedder' in CustomCFG.model_name:
    instruction = INSTRUCTIONS[CustomCFG.instruction]
    self.input_text = [instruction["key"] + query for query in self.input_text]

    batch_dict = self.tokenizer(self.input_text, #max_length = RetrieverCFG.max_length,
    padding = CustomCFG.padding,
    truncation = CustomCFG.truncation,
    return_tensors = CustomCFG.return_tensors).to(self.device)

    with torch.no_grad():
    outputs = self.retriever(**batch_dict)
    if 'xge' in CustomCFG.model_name:
    embeddings = outputs[0][:, 0]
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    if 'llm-embedder' in CustomCFG.model_name:
    embeddings = outputs.last_hidden_state[:, 0]
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    else:
    embeddings = self.average_pool(
    last_hidden_states = outputs.last_hidden_state.to(self.device),
    attention_mask = batch_dict['attention_mask'].to(self.device)
    )
    return embeddings

    def batch_chunk(self):
    ...

    def process_once(self):
    ...

    def batch_processing(self):
    ...