Skip to content

Instantly share code, notes, and snippets.

@ColeMurray
Created September 8, 2024 01:22
Show Gist options
  • Select an option

  • Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.

Select an option

Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.

Revisions

  1. ColeMurray created this gist Sep 8, 2024.
    121 changes: 121 additions & 0 deletions open-hands-demo-faiss-server.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,121 @@

import os
import pickle
from typing import Dict, List, Optional

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# FastAPI application exposing the /query and /add_vector endpoints below.
app = FastAPI()

class QueryRequest(BaseModel):
    """Request body for POST /query."""

    query: List[float]  # embedding to search with; length must match the index dim
    namespace: str  # first half of the on-disk db key
    identifier: str  # second half of the on-disk db key
    num_results: Optional[int] = 10  # top-k; may arrive as null from JSON clients

    import pickle

    class VectorDatabase:
    def __init__(self, base_path: str = "./vector_dbs"):
    self.base_path = base_path
    self.databases: Dict[str, faiss.IndexFlatL2] = {}
    self.vectors: Dict[str, List[np.ndarray]] = {}
    self.ids: Dict[str, List[int]] = {}
    os.makedirs(base_path, exist_ok=True)
    self.load_all_dbs()

    def db_path(self, namespace: str, identifier: str) -> str:
    return os.path.join(self.base_path, f"{namespace}_{identifier}.index")

    def metadata_path(self, namespace: str, identifier: str) -> str:
    return os.path.join(self.base_path, f"{namespace}_{identifier}.metadata")

    def load_all_dbs(self):
    for filename in os.listdir(self.base_path):
    if filename.endswith(".index"):
    namespace, identifier = filename[:-6].split("_", 1)
    self.load_db(namespace, identifier)

    def load_db(self, namespace: str, identifier: str):
    db_key = f"{namespace}_{identifier}"
    index_path = self.db_path(namespace, identifier)
    metadata_path = self.metadata_path(namespace, identifier)
    print(f"Checking index path: {index_path}")
    print(f"Checking metadata path: {metadata_path}")
    if os.path.exists(index_path):
    self.databases[db_key] = faiss.read_index(index_path)
    if os.path.exists(metadata_path):
    with open(metadata_path, 'rb') as f:
    metadata = pickle.load(f)
    self.vectors[db_key] = metadata['vectors']
    self.ids[db_key] = metadata['ids']
    print(f"Metadata file contents: {metadata}")
    else:
    print("Metadata file does not exist")
    self.vectors[db_key] = []
    self.ids[db_key] = []
    print(f"Loaded database: {db_key}")
    print(f"Number of vectors: {self.databases[db_key].ntotal}")
    print(f"Number of IDs: {len(self.ids[db_key])}")
    print(f"IDs: {self.ids[db_key]}")
    else:
    print(f"Index file does not exist: {index_path}")

    def save_db(self, namespace: str, identifier: str):
    db_key = f"{namespace}_{identifier}"
    if db_key in self.databases:
    index_path = self.db_path(namespace, identifier)
    metadata_path = self.metadata_path(namespace, identifier)
    faiss.write_index(self.databases[db_key], index_path)
    with open(metadata_path, 'wb') as f:
    pickle.dump({'vectors': self.vectors[db_key], 'ids': self.ids[db_key]}, f)
    print(f"Saved database: {db_key}")

    def create_or_get_db(self, namespace: str, identifier: str, dim: int):
    db_key = f"{namespace}_{identifier}"
    if db_key not in self.databases:
    self.databases[db_key] = faiss.IndexFlatL2(dim)
    self.vectors[db_key] = []
    self.ids[db_key] = []
    return self.databases[db_key]

    def add_vector(self, namespace: str, identifier: str, vector: List[float], id: int):
    db_key = f"{namespace}_{identifier}"
    db = self.create_or_get_db(namespace, identifier, len(vector))
    np_vector = np.array([vector], dtype=np.float32)
    db.add(np_vector)
    self.vectors[db_key].append(np_vector)
    self.ids[db_key].append(id)
    self.save_db(namespace, identifier)

    def query(self, namespace: str, identifier: str, query_vector: List[float], k: int):
    db_key = f"{namespace}_{identifier}"
    if db_key not in self.databases:
    raise HTTPException(status_code=404, detail="Database not found")
    db = self.databases[db_key]
    np_query = np.array([query_vector], dtype=np.float32)
    distances, indices = db.search(np_query, k)
    results = []
    for idx, score in zip(indices[0], distances[0]):
    if idx < len(self.ids[db_key]):
    results.append({"id": self.ids[db_key][idx], "score": float(score)})
    else:
    results.append({"id": None, "score": float(score)})
    return results

# Module-level singleton; creating it loads every persisted index from disk
# at import time (side effect shared by all endpoints below).
vector_db = VectorDatabase()

    @app.post("/query")
    def query_vector_db(request: QueryRequest):
    results = vector_db.query(request.namespace, request.identifier, request.query, request.num_results)
    return {"results": results}

    @app.post("/add_vector")
    def add_vector(namespace: str, identifier: str, vector: List[float], id: int):
    vector_db.add_vector(namespace, identifier, vector, id)
    return {"message": "Vector added successfully"}

if __name__ == "__main__":
    # Dev entry point: serve on all interfaces, port 8000. Import kept local
    # so uvicorn is only required when running this file directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)