from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS

import os
import logging
from datetime import datetime
from getpass import getpass

from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, NVIDIARerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever

# Define constants
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/oai_calls/"

# Read the NVIDIA API key from the environment, prompting for it if it is not set
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY") or getpass("Enter your NVIDIA API key: ")

# Create logs directory if it doesn't exist
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/{datetime.now().strftime('%B_%d_')}_NVIDIA_rerank.log",
)


def pretty_print_docs(docs):
    """
    Pretty print the documents.

    Parameters:
    docs (list): List of documents.
    """
    logging.info(
        f"\n{'-' * 100}\n".join(
            [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
        )
    )


def load_document(file_path):
    """
    Load the document using PyPDFLoader.

    Parameters:
    file_path (str): Path to the PDF file.

    Returns:
    document: Loaded document.
    """
    return PyPDFLoader(file_path).load()


def split_documents(document, chunk_size=800, chunk_overlap=200):
    """
    Split the documents using RecursiveCharacterTextSplitter.

    Parameters:
    document: Loaded document.
    chunk_size (int): Size of each chunk.
    chunk_overlap (int): Overlap between chunks.

    Returns:
    list: List of split documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(document)


def create_faiss_db(documents, embeddings):
    """
    Create the FAISS database from the documents and embeddings.

    Parameters:
    documents (list): List of documents.
    embeddings: NVIDIAEmbeddings instance.

    Returns:
    FAISS: FAISS database instance.
    """
    return FAISS.from_documents(documents, embeddings)


def setup_retriever(db, k=45):
    """
    Set up the retriever with search parameters.

    Parameters:
    db: FAISS database instance.
    k (int): Number of documents to retrieve.

    Returns:
    retriever: Retriever instance.
    """
    return db.as_retriever(search_kwargs={"k": k})


def main():
    # Load the PDF and split it into overlapping chunks
    document = load_document("./vlm.pdf")
    texts = split_documents(document)
    logging.warning(f"Type: {type(texts)} Total Documents: {len(texts)}")

    # Initialize the NVIDIA embeddings, build the FAISS database, and set up the retriever
    embeddings = NVIDIAEmbeddings(nvidia_api_key=NVIDIA_API_KEY)
    db = create_faiss_db(texts, embeddings)
    retriever = setup_retriever(db)

    logging.warning("############# Starting the base retrieval ###############################")
    query = "What is the kind of model used?"
    # Retrieve documents based on the query
    docs = retriever.invoke(query)
    logging.warning(f"Total retrieved documents: {len(docs)}")
    pretty_print_docs(docs)

    logging.warning("############# Starting the ReRanker ###############################")
    reranker = NVIDIARerank(nvidia_api_key=NVIDIA_API_KEY)

    # Initialize the ContextualCompressionRetriever with the reranker and the base retriever
    compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    reranked_chunks = compression_retriever.invoke(query)
    logging.warning(f"Total reranked documents: {len(reranked_chunks)}")
    pretty_print_docs(reranked_chunks)
    logging.warning("########################################################################################")


if __name__ == "__main__":
    main()
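# Optional: persisting the FAISS index avoids re-embedding the PDF on every run.
# A minimal sketch, assuming a "faiss_index" folder under GLOBAL_PATH (the folder name is
# not part of the original script); it uses the standard LangChain FAISS save/load helpers:
#
#     db.save_local(f"{GLOBAL_PATH}/faiss_index")
#     db = FAISS.load_local(
#         f"{GLOBAL_PATH}/faiss_index",
#         embeddings,
#         allow_dangerous_deserialization=True,
#     )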