import torch
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
from chromadb import EmbeddingFunction
from typing import List, Union

# Instruction prefixes for BAAI/llm-embedder: each task pairs a "query" prefix
# (applied at search time) with a "key" prefix (applied at index time).
INSTRUCTIONS = {
    "qa": {
        "query": "Represent this query for retrieving relevant documents: ",
        "key": "Represent this document for retrieval: ",
    },
    "icl": {
        "query": "Convert this example into vector to look for useful examples: ",
        "key": "Convert this example into vector for retrieval: ",
    },
    "chat": {
        "query": "Embed this dialogue to find useful historical dialogues: ",
        "key": "Embed this historical dialogue for retrieval: ",
    },
    "lrlm": {
        "query": "Embed this text chunk for finding useful historical chunks: ",
        "key": "Embed this historical text chunk for retrieval: ",
    },
    "tool": {
        "query": "Transform this user request for fetching helpful tool descriptions: ",
        "key": "Transform this tool description for retrieval: ",
    },
    "convsearch": {
        "query": "Encode this query and context for searching relevant passages: ",
        "key": "Encode this passage for retrieval: ",
    },
}


class CustomCFG:
    model_name: str = '../assets/llm-embedder'
    local_files_only: bool = True
    max_length: int = 512
    padding: bool = True
    truncation: bool = True
    return_tensors: str = 'pt'
    chunk_size: int = 16
    pad_token: str = "PAD "
    model_half: bool = False
    instruction: str = 'convsearch'


class CustomEmbeddingFunction(EmbeddingFunction):
    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.tokenizer = self.load_tokenizer()
        self.retriever = self.load_retriever()
        # BGE-style instruction used when no task-specific prefix applies.
        self.starter: str = 'Represent this sentence for searching relevant passages: '

    def __call__(self, texts: Union[str, List[str]]):
        chunk_text = self.batch_chunk(texts)
        embeddings = self.batch_processing(chunk_text)
        return embeddings.tolist()

    def tokenize(self, texts):
        batch_dict = self.tokenizer(
            texts,
            max_length=CustomCFG.max_length,
            padding=CustomCFG.padding,
            truncation=CustomCFG.truncation,
            return_tensors=CustomCFG.return_tensors,
        ).to(self.device)
        return batch_dict

    def average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # Mean-pool token embeddings, zeroing out padding positions first.
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def load_tokenizer(self):
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=CustomCFG.model_name,
            local_files_only=CustomCFG.local_files_only,
        )

    def load_retriever(self):
        retriever = AutoModel.from_pretrained(
            pretrained_model_name_or_path=CustomCFG.model_name,
            local_files_only=CustomCFG.local_files_only,
            trust_remote_code=True,
        )
        if CustomCFG.model_half:
            torch.set_default_dtype(torch.half)
            return retriever.to(self.device).half().eval()
        return retriever.to(self.device).eval()

    def process(self):
        # Expects `self.input_text` (a list of strings) to be set by the caller.
        if 'llm-embedder' in CustomCFG.model_name:
            instruction = INSTRUCTIONS[CustomCFG.instruction]
            self.input_text = [instruction["key"] + query for query in self.input_text]

        batch_dict = self.tokenize(self.input_text)

        with torch.no_grad():
            outputs = self.retriever(**batch_dict)

        if 'bge' in CustomCFG.model_name:
            # BGE models: take the [CLS] token embedding and L2-normalize it.
            embeddings = outputs[0][:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        elif 'llm-embedder' in CustomCFG.model_name:
            embeddings = outputs.last_hidden_state[:, 0]
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        else:
            embeddings = self.average_pool(
                last_hidden_states=outputs.last_hidden_state.to(self.device),
                attention_mask=batch_dict['attention_mask'].to(self.device),
            )
        return embeddings

    def batch_chunk(self, texts):
        ...

    def process_once(self):
        ...

    def batch_processing(self, chunk_text):
        ...
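

# A minimal usage sketch, not part of the original module: it assumes the
# stubbed batching methods above have been implemented, a local chromadb
# install whose EmbeddingFunction protocol matches the __call__ signature
# used here, and that the model assets exist at CustomCFG.model_name.
# The collection name, documents, ids, and query below are hypothetical.
if __name__ == "__main__":
    import chromadb

    client = chromadb.Client()
    collection = client.create_collection(
        name="demo_passages",
        embedding_function=CustomEmbeddingFunction(),
    )
    # Documents are embedded with the "key" instruction prefix at index time.
    collection.add(
        documents=["first example passage", "second example passage"],
        ids=["doc-1", "doc-2"],
    )
    # Queries are embedded on the fly and matched against stored embeddings.
    results = collection.query(query_texts=["an example query"], n_results=1)
    print(results["ids"])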