Skip to content

Instantly share code, notes, and snippets.

@rishabh135
Last active August 2, 2024 00:08
Show Gist options
  • Save rishabh135/a968d3fd8ed7ab946e0597265122993e to your computer and use it in GitHub Desktop.
Save rishabh135/a968d3fd8ed7ab946e0597265122993e to your computer and use it in GitHub Desktop.
Example of using the NVIDIA reranker with a local PDF, pretty-printing the results to a log file.
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import os
import requests
import numpy as np
import faiss
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from getpass import getpass
import logging
from datetime import datetime, timedelta
import json
import logging, tqdm
import asyncio
import time
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
# Define constants
# Root working directory for this script; note the trailing "/" means the
# f-strings below produce a harmless double slash ("...oai_calls//logs/").
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/oai_calls/"
# Create logs directory if it doesn't exist
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)
# Set up logging
# All log output goes to a date-stamped file (no console handler).
# NOTE(review): strftime('%B_%d_') already ends with "_", so the filename
# contains a double underscore, e.g. "August_02__NVIDIA_rerank.log".
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
filename=f"{GLOBAL_PATH}/logs/{datetime.now().strftime('%B_%d_')}_NVIDIA_rerank.log",
)
# def parse_json_to_numpy(json_text):
# """
# Convert JSON text to a numpy array of embeddings.
# Parameters:
# json_text (str): JSON string containing embeddings.
# Returns:
# np.ndarray: Array of embeddings.
# """
# data = json.loads(json_text)
# embedding_list = [item["embedding"] for item in data["data"]]
# embedding_array = np.array(embedding_list)
# return embedding_array
def pretty_print_docs(docs):
"""
Pretty print the documents.
Parameters:
docs (list): List of documents.
"""
logging.info(
f"\n{'-' * 100}\n".join(
[f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
)
)
def load_document(file_path):
    """Read a PDF from disk with PyPDFLoader.

    Parameters:
        file_path (str): Path to the PDF file.

    Returns:
        The loaded document pages.
    """
    loader = PyPDFLoader(file_path)
    return loader.load()
def split_documents(document, chunk_size=800, chunk_overlap=200):
    """Chunk a loaded document with RecursiveCharacterTextSplitter.

    Parameters:
        document: Loaded document.
        chunk_size (int): Maximum characters per chunk.
        chunk_overlap (int): Characters shared between adjacent chunks.

    Returns:
        list: The resulting document chunks.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(document)
def create_faiss_db(documents, embeddings):
    """Build a FAISS vector store from documents and an embedding model.

    Parameters:
        documents (list): Document chunks to index.
        embeddings: NVIDIAEmbeddings instance.

    Returns:
        FAISS: The populated vector store.
    """
    vector_store = FAISS.from_documents(documents, embeddings)
    return vector_store
def setup_retriever(db, k=45):
    """Wrap a vector store as a retriever that returns the top-k matches.

    Parameters:
        db: FAISS database instance.
        k (int): Number of documents to retrieve per query.

    Returns:
        retriever: Retriever instance configured with ``k``.
    """
    search_kwargs = {"k": k}
    return db.as_retriever(search_kwargs=search_kwargs)
def main():
    """Run the end-to-end demo: load a PDF, retrieve chunks, then rerank them.

    Pipeline: load ./vlm.pdf -> split into chunks -> embed + index in FAISS
    -> retrieve top-k for a query -> rerank with NVIDIARerank -> log results.
    """
    # Bug fix: NVIDIA_API_KEY was referenced below but never defined anywhere
    # in the file, guaranteeing a NameError. Resolve it from the environment,
    # falling back to an interactive prompt (getpass is already imported).
    nvidia_api_key = os.environ.get("NVIDIA_API_KEY") or getpass("NVIDIA API key: ")

    document = load_document("./vlm.pdf")
    texts = split_documents(document)
    logging.warning(f"Type: {type(texts)} {type(texts)} Total Documents: {len(texts)} ")

    # Initialize the NVIDIAEmbeddings, build the FAISS index, and wrap it
    # as a retriever with the default search parameters (k=45).
    embeddings = NVIDIAEmbeddings(nvidia_api_key=nvidia_api_key)
    db = create_faiss_db(texts, embeddings)
    retriever = setup_retriever(db)

    logging.warning(f"############# Starting the base retrieval ###############################")
    query = "What is the kind of model used?"
    # Retrieve documents based on the query
    docs = retriever.invoke(query)
    logging.warning(f"Total retrieved documents: {len(docs)} ")
    pretty_print_docs(docs)

    logging.warning(f"############# Starting the ReRanker ###############################")
    reranker = NVIDIARerank(nvidia_api_key=nvidia_api_key)
    # Initializing the ContextualCompressionRetriever with the reranker and retriever;
    # it reranks and compresses the base retriever's results for the same query.
    compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    reranked_chunks = compression_retriever.invoke(query)
    logging.warning(f"Total reranked documents: {len(reranked_chunks)} ")
    pretty_print_docs(reranked_chunks)
    logging.warning(f"########################################################################################")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment