"""FastAPI service exposing per-(namespace, identifier) FAISS vector stores.

Each store is an exact L2 index (``faiss.IndexFlatL2``) persisted under
``base_path`` as ``<namespace>_<identifier>.index`` plus a pickled
``<namespace>_<identifier>.metadata`` file holding the raw vectors and the
caller-supplied ids, aligned with the index by insertion order.
"""

import logging
import os
import pickle
import threading
from typing import Dict, List, Optional

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI()

# Characters that would let a namespace/identifier escape ``base_path`` once
# interpolated into a file name.
_FORBIDDEN_NAME_CHARS = frozenset(
    c for c in ("/", "\\", "\x00", os.sep, os.altsep) if c
)


class QueryRequest(BaseModel):
    """Body of ``POST /query``."""

    query: List[float]  # query vector; length must equal the store's dimension
    namespace: str
    identifier: str
    num_results: Optional[int] = 10  # number of neighbours; None means 10


class VectorDatabase:
    """FAISS indexes keyed by ``"<namespace>_<identifier>"``, persisted to disk.

    NOTE: the key is a plain ``_`` join, so ("a_b", "c") and ("a", "b_c")
    address the same store and the same files.
    """

    DEFAULT_NUM_RESULTS = 10

    def __init__(self, base_path: str = "./vector_dbs"):
        self.base_path = base_path
        self.databases: Dict[str, faiss.IndexFlatL2] = {}
        # Raw vectors, kept so they can be written to the metadata file.
        self.vectors: Dict[str, List[np.ndarray]] = {}
        # ids[key][i] is the caller id of the i-th vector added to databases[key].
        self.ids: Dict[str, List[int]] = {}
        # FastAPI runs plain ``def`` endpoints in a thread pool; the index, the
        # id list and the files must change as one unit or positions and ids
        # fall out of step.
        self._lock = threading.RLock()
        os.makedirs(base_path, exist_ok=True)
        self.load_all_dbs()

    @staticmethod
    def _db_key(namespace: str, identifier: str) -> str:
        """Return the store key, rejecting names that could escape ``base_path``.

        Raises:
            HTTPException: 400 if either name contains a path separator or NUL.
        """
        for label, value in (("namespace", namespace), ("identifier", identifier)):
            if any(c in _FORBIDDEN_NAME_CHARS for c in value):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid {label}: must not contain path separators",
                )
        return f"{namespace}_{identifier}"

    def db_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the FAISS index file for a store."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.index")

    def metadata_path(self, namespace: str, identifier: str) -> str:
        """Return the path of the pickled vectors/ids file for a store."""
        return os.path.join(self.base_path, f"{namespace}_{identifier}.metadata")

    def load_all_dbs(self):
        """Load every ``*.index`` file found directly under ``base_path``."""
        suffix = ".index"
        for filename in os.listdir(self.base_path):
            if not filename.endswith(suffix):
                continue
            stem = filename[: -len(suffix)]
            if "_" not in stem:
                # db_path always inserts "_", so this file was not written here.
                logger.warning("Skipping unrecognised index file: %s", filename)
                continue
            namespace, identifier = stem.split("_", 1)
            self.load_db(namespace, identifier)

    def load_db(self, namespace: str, identifier: str):
        """Load one store from disk if its index file exists.

        A missing metadata file leaves the store with empty vector/id lists;
        query hits on such a store report ``id=None``. The store is registered
        only after both files were read, so a failed read never leaves a
        half-loaded entry behind.
        """
        db_key = f"{namespace}_{identifier}"
        index_path = self.db_path(namespace, identifier)
        metadata_path = self.metadata_path(namespace, identifier)
        if not os.path.exists(index_path):
            logger.warning("Index file does not exist: %s", index_path)
            return

        index = faiss.read_index(index_path)
        if os.path.exists(metadata_path):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; this is only safe while base_path is writable solely by
            # this service.
            with open(metadata_path, "rb") as f:
                metadata = pickle.load(f)
            vectors = metadata["vectors"]
            ids = metadata["ids"]
        else:
            logger.warning("Metadata file does not exist: %s", metadata_path)
            vectors, ids = [], []

        if len(ids) != index.ntotal:
            logger.warning(
                "Database %s has %d vectors but %d ids",
                db_key, index.ntotal, len(ids),
            )
        with self._lock:
            self.databases[db_key] = index
            self.vectors[db_key] = vectors
            self.ids[db_key] = ids
        logger.info("Loaded database %s (%d vectors)", db_key, index.ntotal)

    def save_db(self, namespace: str, identifier: str):
        """Persist a store's index and metadata; no-op for unknown stores.

        Each file is written to a ``.tmp`` sibling and moved into place with
        ``os.replace`` so a crash mid-write cannot leave a truncated file.
        """
        db_key = f"{namespace}_{identifier}"
        with self._lock:
            if db_key not in self.databases:
                return
            index_path = self.db_path(namespace, identifier)
            metadata_path = self.metadata_path(namespace, identifier)

            tmp_index = index_path + ".tmp"  # not matched by load_all_dbs
            faiss.write_index(self.databases[db_key], tmp_index)
            os.replace(tmp_index, index_path)

            tmp_metadata = metadata_path + ".tmp"
            with open(tmp_metadata, "wb") as f:
                pickle.dump(
                    {"vectors": self.vectors[db_key], "ids": self.ids[db_key]}, f
                )
            os.replace(tmp_metadata, metadata_path)
        logger.info("Saved database: %s", db_key)

    def create_or_get_db(self, namespace: str, identifier: str, dim: int):
        """Return the store's index, creating an empty ``dim``-wide one if absent.

        Raises:
            HTTPException: 400 if a name contains a path separator.
        """
        db_key = self._db_key(namespace, identifier)
        with self._lock:
            if db_key not in self.databases:
                self.databases[db_key] = faiss.IndexFlatL2(dim)
                self.vectors[db_key] = []
                self.ids[db_key] = []
            return self.databases[db_key]

    def add_vector(self, namespace: str, identifier: str, vector: List[float], id: int):
        """Append ``vector`` with caller id ``id`` to a store and persist it.

        The store is created on first use with ``len(vector)`` dimensions.

        Raises:
            HTTPException: 400 for an invalid name, an empty vector, or a
                vector whose length differs from the store's dimension.
        """
        db_key = self._db_key(namespace, identifier)
        if not vector:
            raise HTTPException(status_code=400, detail="Vector must not be empty")
        np_vector = np.array([vector], dtype=np.float32)
        with self._lock:
            db = self.create_or_get_db(namespace, identifier, len(vector))
            if db.d != len(vector):
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Vector dimension {len(vector)} does not match "
                        f"database dimension {db.d}"
                    ),
                )
            db.add(np_vector)
            self.vectors[db_key].append(np_vector)
            self.ids[db_key].append(id)
            self.save_db(namespace, identifier)

    def query(self, namespace: str, identifier: str, query_vector: List[float], k: int):
        """Return up to ``k`` nearest neighbours as ``{"id", "score"}`` dicts.

        ``score`` is the distance reported by the L2 index (smaller is
        closer). Fewer than ``k`` results are returned when the store holds
        fewer vectors. A hit whose position has no recorded id gets
        ``id=None``.

        Raises:
            HTTPException: 400 for an invalid name, non-positive ``k`` or a
                dimension mismatch; 404 if the store does not exist.
        """
        db_key = self._db_key(namespace, identifier)
        if k is None:
            k = self.DEFAULT_NUM_RESULTS
        if k < 1:
            raise HTTPException(
                status_code=400, detail="num_results must be a positive integer"
            )
        with self._lock:
            if db_key not in self.databases:
                raise HTTPException(status_code=404, detail="Database not found")
            db = self.databases[db_key]
            if len(query_vector) != db.d:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Query dimension {len(query_vector)} does not match "
                        f"database dimension {db.d}"
                    ),
                )
            ids = self.ids[db_key]
            np_query = np.array([query_vector], dtype=np.float32)
            distances, indices = db.search(np_query, k)

        results = []
        for idx, score in zip(indices[0], distances[0]):
            if idx < 0:
                # FAISS pads with -1 when k exceeds the stored vector count;
                # without this check ids[-1] would be reported as a match.
                continue
            matched_id = ids[idx] if idx < len(ids) else None
            results.append({"id": matched_id, "score": float(score)})
        return results


vector_db = VectorDatabase()


@app.post("/query")
def query_vector_db(request: QueryRequest):
    """Return the nearest neighbours of ``request.query`` in the named store."""
    results = vector_db.query(
        request.namespace, request.identifier, request.query, request.num_results
    )
    return {"results": results}


@app.post("/add_vector")
def add_vector(namespace: str, identifier: str, vector: List[float], id: int):
    """Add one vector (JSON array request body) under caller id ``id``."""
    vector_db.add_vector(namespace, identifier, vector, id)
    return {"message": "Vector added successfully"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)