Created
August 2, 2024 01:30
-
-
Save rishabh135/782e7a1611dccd4d41157a716b585469 to your computer and use it in GitHub Desktop.
Revisions
-
rishabh135 created this gist
Aug 2, 2024 .There are no files selected for viewing
"""Retrieve chunks from a PDF with Jina embeddings + FAISS, then rerank with JinaRerank.

Pipeline: load ./vlm.pdf -> split into overlapping chunks -> embed into a FAISS
index -> similarity-retrieve for a fixed query -> rerank the hits with Jina's
reranker via a ContextualCompressionRetriever. All output goes to a log file.
"""

import logging
import os
import time
from datetime import datetime

from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import JinaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Root working directory for logs.
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/oai_calls/"

# Create the logs directory if it doesn't exist.
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

# All pipeline output is written to a dated log file rather than stdout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/{datetime.now().strftime('%B_%d_')}_jina_rerank_global_foundries.log",
)

# Read the API key from the environment instead of hard-coding it.
# The original gist both (a) embedded a live bearer token in a comment
# (never commit secrets) and (b) used JINA_API_KEY without defining it,
# which raised NameError at runtime.
JINA_API_KEY = os.environ.get("JINA_API_KEY", "")


def load_document(file_path):
    """Load a PDF into a list of per-page Documents.

    Parameters:
        file_path (str): Path to the PDF file.

    Returns:
        list: Loaded document pages.
    """
    return PyPDFLoader(file_path).load()


def split_documents(document, chunk_size=800, chunk_overlap=200):
    """Split loaded documents into overlapping character chunks.

    Parameters:
        document: Loaded document (list of Documents).
        chunk_size (int): Size of each chunk in characters.
        chunk_overlap (int): Overlap between consecutive chunks.

    Returns:
        list: List of split documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(document)


def pretty_print_docs(docs):
    """Log each document's page content, separated by a 100-char dashed rule."""
    logging.warning(
        f"\n{'-' * 100}\n".join(
            [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
        )
    )


def main():
    """Run the full retrieve-then-rerank pipeline for a fixed query."""
    document = load_document("./vlm.pdf")
    split_docs = split_documents(document)
    logging.warning(f"Type: {type(split_docs)} Total Documents: {len(split_docs)} ")

    embedding = JinaEmbeddings(
        model_name="jina-embeddings-v2-base-en", jina_api_key=JINA_API_KEY
    )
    # Over-retrieve (k=25) so the reranker has a wide candidate pool.
    retriever = FAISS.from_documents(split_docs, embedding).as_retriever(
        search_type="similarity", search_kwargs={"k": 25}
    )

    query = "What kind of GPU models have been used?"
    docs = retriever.invoke(query)
    logging.warning("Starting base retriever")
    pretty_print_docs(docs)
    logging.warning(f" { '****'*10} ")

    # Compress the 25 candidates down to the 3 best via Jina's reranker.
    compressor = JinaRerank(
        top_n=3, model="jina-reranker-v1-base-en", jina_api_key=JINA_API_KEY
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    compressed_docs = compression_retriever.invoke(query)
    logging.warning(f" { '****'*10} ")
    logging.warning("Starting reranking")
    pretty_print_docs(compressed_docs)


if __name__ == "__main__":
    # Guarded so importing this module does not trigger file/network I/O.
    start_time = time.perf_counter()
    main()
    end = time.perf_counter()
    logging.warning(f" Ran the tasks in around {round(end-start_time, 4)} ")