Created
September 8, 2024 01:22
-
-
Save ColeMurray/1e6bc35f0c0bb46fd87a995f77d741dd to your computer and use it in GitHub Desktop.
Revisions
-
ColeMurray created this gist
Sep 8, 2024. There are no files selected for viewing
"""Small FAISS-backed vector store exposed through a FastAPI app.

Each (namespace, identifier) pair maps to one ``faiss.IndexFlatL2`` index that
is persisted under ``base_path`` as ``<namespace>_<identifier>.index`` together
with a pickled ``.metadata`` file holding the raw vectors and their ids.
"""

import os
import pickle
from typing import Dict, List, Optional

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class QueryRequest(BaseModel):
    """Request body for the ``/query`` endpoint."""

    query: List[float]
    namespace: str
    identifier: str
    num_results: Optional[int] = 10


class VectorDatabase:
    """In-memory registry of FAISS indexes that are mirrored to disk.

    The three dictionaries are all keyed by ``f"{namespace}_{identifier}"``:
    ``databases`` holds the FAISS index, ``vectors`` the arrays that were
    added, and ``ids`` the caller-supplied id for each added vector (position
    in the list equals the row number inside the FAISS index).
    """

    def __init__(self, base_path: str = "./vector_dbs"):
        """Create the storage directory if needed and load every saved index."""
        self.base_path = base_path
        self.databases: Dict[str, faiss.IndexFlatL2] = {}
        self.vectors: Dict[str, List[np.ndarray]] = {}
        self.ids: Dict[str, List[int]] = {}
        os.makedirs(base_path, exist_ok=True)
        self.load_all_dbs()

    def db_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the FAISS index file for this database."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.index")

    def metadata_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the pickled metadata file for this database."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.metadata")

    def load_all_dbs(self):
        """Load every ``*.index`` file found in ``base_path``.

        The file stem is split on its first underscore. If the namespace itself
        contained an underscore the split differs from the original pair, but
        the rebuilt key and paths (``f"{namespace}_{identifier}"``) are the
        same string, so the database is still found under the same key.
        """
        suffix = ".index"
        for filename in os.listdir(self.base_path):
            if not filename.endswith(suffix):
                continue
            namespace, sep, identifier = filename[: -len(suffix)].partition("_")
            if not sep:
                # A stray file without "_" cannot be mapped to a
                # (namespace, identifier) pair; skip it instead of crashing
                # at import time when the module-level instance is built.
                print(f"Skipping index file with unexpected name: {filename}")
                continue
            self.load_db(namespace, identifier)

    def load_db(self, namespace: str, identifier: str):
        """Load one index (and its metadata, when present) from disk."""
        db_key = f"{namespace}_{identifier}"
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)

        print(f"Checking index path: {index_path}")
        print(f"Checking metadata path: {metadata_path}")

        if not os.path.exists(index_path):
            print(f"Index file does not exist: {index_path}")
            return

        self.databases[db_key] = faiss.read_index(index_path)
        if os.path.exists(metadata_path):
            # NOTE(security): pickle.load executes arbitrary code from the
            # file. Only safe while the storage directory is writable solely
            # by this service.
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)
            self.vectors[db_key] = metadata['vectors']
            self.ids[db_key] = metadata['ids']
            print(f"Metadata file contents: {metadata}")
        else:
            print("Metadata file does not exist")
            self.vectors[db_key] = []
            self.ids[db_key] = []

        print(f"Loaded database: {db_key}")
        print(f"Number of vectors: {self.databases[db_key].ntotal}")
        print(f"Number of IDs: {len(self.ids[db_key])}")
        print(f"IDs: {self.ids[db_key]}")

    def save_db(self, namespace: str, identifier: str):
        """Write the index and its metadata to disk; no-op for unknown keys."""
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            return
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)
        faiss.write_index(self.databases[db_key], index_path)
        with open(metadata_path, 'wb') as f:
            pickle.dump({'vectors': self.vectors[db_key], 'ids': self.ids[db_key]}, f)
        print(f"Saved database: {db_key}")

    def create_or_get_db(self, namespace: str, identifier: str, dim: int):
        """Return the index for this key, creating an empty one if missing.

        ``dim`` is only used when a new index is created; an existing index
        keeps its own dimensionality.
        """
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            self.databases[db_key] = faiss.IndexFlatL2(dim)
            self.vectors[db_key] = []
            self.ids[db_key] = []
        return self.databases[db_key]

    @staticmethod
    def _validate_name(label: str, value: str) -> None:
        """Reject names that would escape ``base_path`` when used in a filename."""
        separators = [os.sep] + ([os.altsep] if os.altsep else [])
        if not value or any(sep in value for sep in separators):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid {label}: must be non-empty and contain no path separators",
            )

    def add_vector(self, namespace: str, identifier: str, vector: List[float], id: int):
        """Append ``vector`` with the given ``id`` and persist the database.

        Raises:
            HTTPException: 400 when a name contains a path separator, the
                vector is empty, or its length differs from the dimension of
                an already existing index.
        """
        # Names come from the HTTP request and end up in file paths.
        self._validate_name("namespace", namespace)
        self._validate_name("identifier", identifier)
        if not vector:
            raise HTTPException(status_code=400, detail="Vector must not be empty")

        db_key = f"{namespace}_{identifier}"
        db = self.create_or_get_db(namespace, identifier, len(vector))
        if db.d != len(vector):
            # FAISS would otherwise fail deep inside add() with an opaque error.
            raise HTTPException(
                status_code=400,
                detail=f"Vector dimension {len(vector)} does not match index dimension {db.d}",
            )

        np_vector = np.array([vector], dtype=np.float32)
        db.add(np_vector)
        self.vectors[db_key].append(np_vector)
        self.ids[db_key].append(id)
        self.save_db(namespace, identifier)

    def query(self, namespace: str, identifier: str, query_vector: List[float], k: int):
        """Return up to ``k`` nearest neighbours of ``query_vector``.

        Each result is ``{"id": <stored id or None>, "score": <L2 distance>}``.
        ``id`` is ``None`` when the index holds a row for which no id was
        recorded (e.g. the metadata file was missing at load time).

        Raises:
            HTTPException: 404 when the database is unknown; 400 when ``k`` is
                not positive or the query length differs from the index
                dimension.
        """
        db_key = f"{namespace}_{identifier}"
        if db_key not in self.databases:
            raise HTTPException(status_code=404, detail="Database not found")
        db = self.databases[db_key]

        if k is None:
            # QueryRequest.num_results is Optional; fall back to its default.
            k = 10
        if k <= 0:
            raise HTTPException(status_code=400, detail="num_results must be a positive integer")
        if len(query_vector) != db.d:
            raise HTTPException(
                status_code=400,
                detail=f"Query dimension {len(query_vector)} does not match index dimension {db.d}",
            )

        np_query = np.array([query_vector], dtype=np.float32)
        distances, indices = db.search(np_query, k)

        known_ids = self.ids[db_key]
        results = []
        for idx, score in zip(indices[0], distances[0]):
            idx = int(idx)
            if idx < 0:
                # FAISS pads with label -1 when fewer than k vectors exist.
                # Indexing a list with -1 would wrongly return the last id.
                continue
            matched_id = known_ids[idx] if idx < len(known_ids) else None
            results.append({"id": matched_id, "score": float(score)})
        return results


vector_db = VectorDatabase()


@app.post("/query")
def query_vector_db(request: QueryRequest):
    """Run a nearest-neighbour search against the requested database."""
    results = vector_db.query(request.namespace, request.identifier, request.query, request.num_results)
    return {"results": results}


@app.post("/add_vector")
def add_vector(namespace: str, identifier: str, vector: List[float], id: int):
    """Add one vector to the requested database, creating it if necessary."""
    vector_db.add_vector(namespace, identifier, vector, id)
    return {"message": "Vector added successfully"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)