Skip to content

Instantly share code, notes, and snippets.

@rishabh135
Created August 2, 2024 01:30
Show Gist options
  • Save rishabh135/782e7a1611dccd4d41157a716b585469 to your computer and use it in GitHub Desktop.

Revisions

  1. rishabh135 created this gist Aug 2, 2024.
    212 changes: 212 additions & 0 deletions gff_jina_reranker.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,212 @@
    import os
    import requests
    import numpy as np
    import logging
    from datetime import datetime, timedelta
    import json
    import logging
    import asyncio
    import time
    import concurrent.futures
    import logging

    from langchain_community.document_loaders import PyPDFLoader
    from langchain_community.document_loaders import TextLoader
    from langchain_community.embeddings import JinaEmbeddings
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    from langchain_community.vectorstores import FAISS
    import os
    import requests
    import numpy as np

    import logging
    from datetime import datetime, timedelta
    import json
    import logging
    import asyncio
    import time
    from tqdm import tqdm



    from langchain.retrievers import ContextualCompressionRetriever
    from langchain_community.document_compressors import JinaRerank



    # from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
    # from langchain_nvidia_ai_endpoints import NVIDIARerank
    # from langchain.retrievers.contextual_compression import ContextualCompressionRetriever


    # for making logging async
    # executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    # def info(self, msg, *args):
    # executor.submit(logging.info, msg, *args)





    # Define constants
    GLOBAL_PATH = "/scratch/gilbreth/gupt1075/oai_calls/"

    # Create logs directory if it doesn't exist
    os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

    # Set up logging
    logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/{datetime.now().strftime('%B_%d_')}_jina_rerank_global_foundries.log",
    )






    def load_document(file_path):
    """
    Load the document using PyPDFLoader.
    Parameters:
    file_path (str): Path to the PDF file.
    Returns:
    document: Loaded document.
    """
    return PyPDFLoader(file_path).load()



    def split_documents(document, chunk_size=800, chunk_overlap=200):
    """
    Split the documents using RecursiveCharacterTextSplitter.
    Parameters:
    document: Loaded document.
    chunk_size (int): Size of each chunk.
    chunk_overlap (int): Overlap between chunks.
    Returns:
    list: List of split documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    split_doc = text_splitter.split_documents(document)

    return split_doc


    # def create_faiss_db(documents, embeddings):
    # """
    # Create the FAISS database from the documents and embeddings.

    # Parameters:
    # documents (list): List of documents.
    # embeddings: NVIDIAEmbeddings instance.

    # Returns:
    # FAISS: FAISS database instance.
    # """
    # return FAISS.from_documents(documents, embeddings)

    # def setup_retriever(db, k=45):
    # """
    # Set up the retriever with search parameters.

    # Parameters:
    # db: FAISS database instance.
    # k (int): Number of documents to retrieve.

    # Returns:
    # retriever: Retriever instance.
    # """
    # return db.as_retriever(search_kwargs={"k": k})






# Define Jina API endpoint and headers
# JINA_API_ENDPOINT = 'https://api.jina.ai/v1/rerank'
# JINA_API_HEADERS = {
#     'Content-Type': 'application/json',
#     'Authorization': 'Bearer <JINA_API_KEY>'  # SECURITY: a real API key was committed here in a public gist — it must be revoked/rotated and supplied via environment variable instead.
# }


    def pretty_print_docs(docs):
    logging.warning(
    f"\n{'-' * 100}\n".join(
    [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
    )
    )


    # Define function to pretty print results
    # def pretty_print_results(outp):
    # logging.warning(f" inside pretty print {outp} ")
    # if outp is None:
    # logging.info("No output available.")
    # return

    # results = outp.get('data')
    # if results is None:
    # logging.info("No results available in the output.")
    # return

    # logging.info("Results:")
    # for i, result in enumerate(results, start=1):
    # logging.info(f"Rank: {i}")
    # logging.info(f"Index: {result['embedding']}")
    # logging.info("")


    # Define function to run tasks
    def main():

    document = load_document("./vlm.pdf")
    split_docs = split_documents(document)

    logging.warning(f"Type: {type(split_docs)} Total Documents: {len(split_docs)} ")

    embedding = JinaEmbeddings ( model_name="jina-embeddings-v2-base-en", jina_api_key= JINA_API_KEY)
    retriever = FAISS.from_documents(split_docs, embedding).as_retriever( search_type="similarity", search_kwargs={"k": 25})

    query = "What kind of GPU models have been used?"
    docs = retriever.invoke(query )

    logging.warning("Starting Base retreiver " )
    pretty_print_docs(docs)

    logging.warning(f" { '****'*10} ")





    compressor = JinaRerank(top_n=3, model= "jina-reranker-v1-base-en" , jina_api_key=JINA_API_KEY)
    compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)

    compressed_docs = compression_retriever.invoke(query)

    logging.warning(f" { '****'*10} ")
    logging.warning("Starting Reranking " )
    pretty_print_docs(compressed_docs)
    # sentences_from_text = texts



    # Run tasks
    start_time = time.perf_counter()
    main()
    end = time.perf_counter()
    logging.warning(f" Ran the tasks in around {round(end-start_time, 4)} ")