Skip to content

Instantly share code, notes, and snippets.

@manisnesan
Created October 29, 2024 16:38
Show Gist options
  • Select an option

  • Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.

Select an option

Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.

Revisions

  1. manisnesan created this gist Oct 29, 2024.
    103 changes: 103 additions & 0 deletions caikit-reranker.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,103 @@
    import requests

    # Field names, in join order, used to build the text of each document
    # flavour (see trim_string, which picks one of these lists per document).
    UNIFIED_FIELDS = [
        "title",
        "text",
    ]
    CASE_FIELDS = [
        "summary",
        "product",
        "description",
    ]
    SOLUTION_FIELDS = [
        "title",
        "solution_environment",
        "issue",
        "solution_rootcause",
        "solution_diagnosticsteps",
        "solution_resolution",
    ]

    def join_fields(doc, fields):
        """Join the values of the selected fields of ``doc`` with single spaces.

        Args:
            doc: Mapping of field name to value.
            fields: Field names to pick, in output order.

        Returns:
            The selected values separated by single spaces. Fields that are
            absent from ``doc`` or whose value is ``None`` are skipped; any
            other non-string value is converted with ``str``. Returns ``""``
            when nothing is selected.
        """
        present = (doc[field] for field in fields if field in doc)
        # A null field (e.g. JSON null) or a numeric value would make
        # str.join raise TypeError, so drop None and coerce the rest.
        return " ".join(str(value) for value in present if value is not None)

    def trim_string(doc):
        """Build the single text string that represents ``doc``.

        The field list is chosen from the keys present in ``doc``:

        * ``"summary"`` present: ``CASE_FIELDS``.
        * ``"title"`` present together with at least one solution-only field
          (an entry of ``SOLUTION_FIELDS`` that is not in ``UNIFIED_FIELDS``):
          ``SOLUTION_FIELDS``.
        * ``"title"`` present otherwise: ``UNIFIED_FIELDS``.

        Args:
            doc: Mapping of field name to value.

        Returns:
            The chosen fields' values joined by ``join_fields``.

        Raises:
            ValueError: If ``doc`` has neither a ``"summary"`` nor a
                ``"title"`` key.
        """
        if "summary" in doc:
            return join_fields(doc, CASE_FIELDS)

        if "title" in doc:
            # Solution and unified documents both carry "title", so testing
            # for "title" twice can never reach the unified case. Tell them
            # apart by the solution-only fields so that a unified document
            # keeps its "text".
            solution_only = [
                field for field in SOLUTION_FIELDS if field not in UNIFIED_FIELDS
            ]
            if any(field in doc for field in solution_only):
                return join_fields(doc, SOLUTION_FIELDS)
            return join_fields(doc, UNIFIED_FIELDS)

        raise ValueError(
            "Unknown document type with fields: " + ", ".join(map(str, doc.keys()))
        )

    def rerank(server, port, query, documents, top_k=10, *, timeout=30.0):
        """Rerank ``documents`` against ``query`` through a remote HTTP service.

        The first ``top_k`` documents are converted to text with
        ``trim_string`` and POSTed, together with the query text, to
        ``http://{server}:{port}/api/v1/senttransformer/task/rerank``.

        Args:
            server: Host name or address of the rerank service.
            port: Port of the rerank service.
            query: Document-like dict; its text is built with ``trim_string``.
            documents: List of document dicts. Each one's ``"title"`` and
                ``"uri"`` (if present) are sent along with its text.
            top_k: Maximum number of documents sent and of results returned;
                also sent to the service as ``top_n``.
            timeout: Seconds to wait for the service, passed to
                ``requests.post``. Without a timeout the call could block
                indefinitely on an unresponsive server.

        Returns:
            A list of ``(uri, score)`` tuples, one per entry of the response's
            ``"results"`` list (at most ``top_k``), in response order. Returns
            ``[]`` when ``documents`` is empty or ``None``, or when the
            response carries no results.

        Raises:
            ValueError: From ``trim_string`` for an unrecognised document.
            requests.HTTPError: If the service answers with an error status.
            requests.Timeout: If the service does not answer within
                ``timeout`` seconds.
        """
        base_url = f"http://{server}:{port}/api/v1/senttransformer/task/rerank"

        if not documents:
            print("Reranking cannot be performed for null or empty documents")
            return []

        # Only these candidates are submitted. NOTE(review): "corpus_id" in
        # the response is presumably the position within this submitted list
        # -- confirm against the service contract.
        candidates = documents[:top_k]

        query_text = trim_string(query)
        doc_array = [
            {
                "document": {
                    "text": trim_string(doc),
                    "title": doc.get("title"),
                    "url": doc.get("uri"),
                }
            }
            for doc in candidates
        ]

        payload = {
            "inputs": {
                "queries": [query_text],
                "documents": {"documents": doc_array},
                "top_n": top_k,
            }
        }

        response = requests.post(base_url, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an error for bad responses

        # Treat a missing or null "results" value as "no results".
        results = response.json().get("results") or []
        return [
            # .get mirrors how "uri" is read when building the payload, so a
            # document without a URI yields None instead of a KeyError.
            (candidates[entry["corpus_id"]].get("uri"), entry["score"])
            for entry in results[:top_k]
        ]

    # Example usage.
    # Sample data for demonstration; kept at module level so the names stay
    # importable.
    server_address = "example.com"  # Replace with your server address
    port = 8080  # Replace with your port number

    # Sample query and documents
    query = {
        "summary": "Need help with installation issues",
        "product": "Red Hat Enterprise Linux",
        "description": "User is facing problems during installation.",
    }

    documents = [
        {
            "title": "Installation Guide for RHEL",
            "uri": "http://example.com/docs/rhel-installation",
            "summary": "A comprehensive guide to install RHEL.",
            "product": "Red Hat Enterprise Linux",
            "description": "This document provides step-by-step instructions.",
        },
        {
            "title": "Troubleshooting RHEL Installation",
            "uri": "http://example.com/docs/rhel-troubleshooting",
            "summary": "Common issues and solutions during RHEL installation.",
            "product": "Red Hat Enterprise Linux",
            "description": "This document helps troubleshoot installation problems.",
        },
        {
            "title": "RHEL Support",
            "uri": "http://example.com/docs/rhel-support",
            "summary": "Support options for RHEL users.",
            "product": "Red Hat Enterprise Linux",
            "description": "Information on how to get support for RHEL.",
        },
    ]

    # Call the rerank function only when run as a script, so that importing
    # this module for its functions does not trigger a network request.
    if __name__ == "__main__":
        try:
            results = rerank(server_address, port, query, documents, top_k=10)
        except Exception as e:  # demo boundary: report any failure and exit
            print(f"An error occurred: {e}")
        else:
            print("Reranked Results:")
            for uri, score in results:
                print(f"Document URI: {uri}, Score: {score}")