nvidia_reranker example: NVIDIARerank running against a local PDF, with the retrieved and reranked chunks pretty-printed to a log file
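Assumes roughly the following setup (not stated in the gist itself): the langchain, langchain-community, langchain-text-splitters, langchain-nvidia-ai-endpoints, pypdf, and faiss-cpu packages installed, a vlm.pdf file in the working directory, and an NVIDIA API key exposed as the NVIDIA_API_KEY environment variable (the script below falls back to prompting for it).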
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import os
import json
import logging
import numpy as np
from datetime import datetime
from getpass import getpass
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, NVIDIARerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
# Define constants
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/oai_calls/"

# Create logs directory if it doesn't exist
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/{datetime.now().strftime('%B_%d_')}_NVIDIA_rerank.log",
)
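
# The original gist references NVIDIA_API_KEY without defining it; the fallback
# below (environment variable, then an interactive prompt) is an assumption.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY") or getpass("Enter NVIDIA API key: ")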

# def parse_json_to_numpy(json_text):
#     """
#     Convert JSON text to a numpy array of embeddings.
#
#     Parameters:
#         json_text (str): JSON string containing embeddings.
#
#     Returns:
#         np.ndarray: Array of embeddings.
#     """
#     data = json.loads(json_text)
#     embedding_list = [item["embedding"] for item in data["data"]]
#     embedding_array = np.array(embedding_list)
#     return embedding_array


def pretty_print_docs(docs):
    """
    Pretty print the documents.

    Parameters:
        docs (list): List of documents.
    """
    logging.info(
        f"\n{'-' * 100}\n".join(
            [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
        )
    )


def load_document(file_path):
    """
    Load the document using PyPDFLoader.

    Parameters:
        file_path (str): Path to the PDF file.

    Returns:
        document: Loaded document.
    """
    return PyPDFLoader(file_path).load()


def split_documents(document, chunk_size=800, chunk_overlap=200):
    """
    Split the documents using RecursiveCharacterTextSplitter.

    Parameters:
        document: Loaded document.
        chunk_size (int): Size of each chunk.
        chunk_overlap (int): Overlap between chunks.

    Returns:
        list: List of split documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(document)


def create_faiss_db(documents, embeddings):
    """
    Create the FAISS database from the documents and embeddings.

    Parameters:
        documents (list): List of documents.
        embeddings: NVIDIAEmbeddings instance.

    Returns:
        FAISS: FAISS database instance.
    """
    return FAISS.from_documents(documents, embeddings)


def setup_retriever(db, k=45):
    """
    Set up the retriever with search parameters.

    Parameters:
        db: FAISS database instance.
        k (int): Number of documents to retrieve.

    Returns:
        retriever: Retriever instance.
    """
    return db.as_retriever(search_kwargs={"k": k})

def main():
    document = load_document("./vlm.pdf")
    texts = split_documents(document)
    logging.warning(f"Type: {type(texts)}  Total documents: {len(texts)}")

    # Initialize the NVIDIAEmbeddings, create the FAISS database from the
    # documents and embeddings, and set up the retriever with search parameters
    embeddings = NVIDIAEmbeddings(nvidia_api_key=NVIDIA_API_KEY)
    db = create_faiss_db(texts, embeddings)
    retriever = setup_retriever(db)

    logging.warning("############# Starting the base retrieval ###############################")
    query = "What is the kind of model used?"

    # Retrieve documents based on the query
    docs = retriever.invoke(query)
    logging.warning(f"Total retrieved documents: {len(docs)}")
    pretty_print_docs(docs)

    logging.warning("############# Starting the ReRanker ###############################")
    reranker = NVIDIARerank(nvidia_api_key=NVIDIA_API_KEY)

    # Initialize the ContextualCompressionRetriever with the reranker and base retriever
    compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    reranked_chunks = compression_retriever.invoke(query)
    logging.warning(f"Total reranked documents: {len(reranked_chunks)}")
    pretty_print_docs(reranked_chunks)

    logging.warning("########################################################################################")


if __name__ == "__main__":
    main()
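
Two optional tweaks, sketched below under stated assumptions rather than taken from the gist: persist the FAISS index with save_local()/load_local() so the PDF is not re-embedded on every run, and cap the reranked output with NVIDIARerank's top_n. The index path, the top_n value, and the allow_dangerous_deserialization flag (required by recent langchain_community versions when reloading a pickled index) are assumptions here.

import os
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, NVIDIARerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever

embeddings = NVIDIAEmbeddings(nvidia_api_key=os.environ["NVIDIA_API_KEY"])
INDEX_DIR = "./faiss_vlm_index"  # hypothetical location for the saved index

if os.path.isdir(INDEX_DIR):
    # Reload a previously saved index instead of re-embedding the PDF.
    db = FAISS.load_local(INDEX_DIR, embeddings, allow_dangerous_deserialization=True)
else:
    # Same chunking settings as the script above (chunk_size=800, chunk_overlap=200).
    chunks = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200).split_documents(
        PyPDFLoader("./vlm.pdf").load()
    )
    db = FAISS.from_documents(chunks, embeddings)
    db.save_local(INDEX_DIR)

retriever = db.as_retriever(search_kwargs={"k": 45})
# top_n limits how many chunks the reranker returns; 10 is an illustrative choice.
reranker = NVIDIARerank(nvidia_api_key=os.environ["NVIDIA_API_KEY"], top_n=10)
compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
reranked_chunks = compression_retriever.invoke("What is the kind of model used?")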