Created
October 29, 2024 16:38
-
-
Save manisnesan/11a29bc2bdf681bc927d20da752d6b64 to your computer and use it in GitHub Desktop.
Revisions
-
manisnesan created this gist
Oct 29, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,103 @@ import requests UNIFIED_FIELDS = ["title", "text"] CASE_FIELDS = ["summary", "product", "description"] SOLUTION_FIELDS = [ "title", "solution_environment", "issue", "solution_rootcause", "solution_diagnosticsteps", "solution_resolution" ] def join_fields(doc, fields): return " ".join(doc[field] for field in fields if field in doc) def trim_string(doc): if "summary" in doc: return join_fields(doc, CASE_FIELDS) elif "title" in doc: return join_fields(doc, SOLUTION_FIELDS) elif "title" in doc: return join_fields(doc, UNIFIED_FIELDS) else: raise ValueError("Unknown document type with fields: " + ", ".join(doc.keys())) def rerank(server, port, query, documents, top_k=10): base_url = f"http://{server}:{port}/api/v1/senttransformer/task/rerank" if not documents: print("Reranking cannot be performed for null or empty documents") return [] query_text = trim_string(query) doc_array = [ { "document": { "text": trim_string(doc), "title": doc.get("title"), "url": doc.get("uri") } } for doc in documents[:top_k] ] payload = { "inputs": { "queries": [query_text], "documents": {"documents": doc_array}, "top_n": top_k } } response = requests.post(base_url, json=payload) response.raise_for_status() # Raise an error for bad responses results = response.json().get("results", []) return [ (documents[entry["corpus_id"]]["uri"], entry["score"]) for entry in results[:top_k] ] # Example usage: # Sample data for demonstration server_address = "example.com" # Replace with your server address port = 8080 # Replace with your port number # Sample query and documents query = { "summary": "Need help with installation issues", "product": "Red Hat Enterprise Linux", "description": "User is facing problems during installation." } documents = [ { "title": "Installation Guide for RHEL", "uri": "http://example.com/docs/rhel-installation", "summary": "A comprehensive guide to install RHEL.", "product": "Red Hat Enterprise Linux", "description": "This document provides step-by-step instructions." }, { "title": "Troubleshooting RHEL Installation", "uri": "http://example.com/docs/rhel-troubleshooting", "summary": "Common issues and solutions during RHEL installation.", "product": "Red Hat Enterprise Linux", "description": "This document helps troubleshoot installation problems." }, { "title": "RHEL Support", "uri": "http://example.com/docs/rhel-support", "summary": "Support options for RHEL users.", "product": "Red Hat Enterprise Linux", "description": "Information on how to get support for RHEL." } ] # Call the rerank function try: results = rerank(server_address, port, query, documents, top_k=10) print("Reranked Results:") for uri, score in results: print(f"Document URI: {uri}, Score: {score}") except Exception as e: print(f"An error occurred: {e}")