# medium-ml-service / embedding_store.py
# anuragsingh26 — initial deploy (commit 29dfccc)
import faiss
import numpy as np
import os
import pickle
class EmbeddingStore:
    """Persistent nearest-neighbour store mapping FAISS rows to blog ids.

    Vectors are L2-normalized before insertion and before querying, and the
    index is an inner-product ``IndexFlatIP``; the returned scores are
    therefore cosine similarities. Row ``i`` of the FAISS index corresponds
    to ``self.blog_ids[i]``, so the two must always stay the same length.
    """

    def __init__(self, dim: int, index_path="faiss.index", meta_path="faiss_meta.pkl"):
        """Load the index/metadata pair from disk, or start an empty index.

        Args:
            dim: Dimensionality of the embedding vectors.
            index_path: File the FAISS index is read from / written to.
            meta_path: Pickle file holding the row -> blog id list.

        Raises:
            ValueError: If an on-disk index does not match ``dim`` or its
                row count disagrees with the stored blog id list.
        """
        self.dim = dim
        self.index_path = index_path
        self.meta_path = meta_path
        if os.path.exists(index_path) and os.path.exists(meta_path):
            print("loading FAISS index from disk...")
            self.index = faiss.read_index(index_path)
            with open(meta_path, "rb") as f:
                # NOTE(review): pickle.load can execute arbitrary code; this is
                # only safe because the file is written by save() below. Never
                # point meta_path at a file from an untrusted source.
                self.blog_ids = pickle.load(f)
            if self.index.d != dim:
                raise ValueError(
                    f"index at {index_path!r} has dimension {self.index.d}, expected {dim}"
                )
            if self.index.ntotal != len(self.blog_ids):
                # A mismatch means every row -> blog id lookup would be
                # shifted and search would return wrong ids; fail loudly.
                raise ValueError(
                    f"index has {self.index.ntotal} vectors but metadata has "
                    f"{len(self.blog_ids)} blog ids; files are out of sync"
                )
        else:
            print("Creating new FAISS index...")
            self.index = faiss.IndexFlatIP(dim)
            self.blog_ids = []

    def _normalize(self, vector: np.ndarray) -> np.ndarray:
        """Return *vector* as a flat float32 array scaled to unit L2 norm.

        Raises:
            ValueError: If the vector is all zeros or contains non-finite
                values (dividing would produce NaNs that corrupt the index).
        """
        arr = np.asarray(vector, dtype=np.float32).reshape(-1)
        norm = float(np.linalg.norm(arr))
        if norm == 0.0 or not np.isfinite(norm):
            raise ValueError("cannot normalize a zero or non-finite vector")
        return arr / np.float32(norm)

    def _prepare(self, vector: np.ndarray) -> np.ndarray:
        """Normalize *vector* and shape it as a single (1, dim) FAISS row.

        Raises:
            ValueError: If the vector cannot be normalized or its length
                differs from ``self.dim``.
        """
        unit = self._normalize(vector)
        if unit.shape[0] != self.dim:
            raise ValueError(
                f"expected a vector of dimension {self.dim}, got {unit.shape[0]}"
            )
        # Division produced a fresh contiguous array, so this reshape is a
        # C-contiguous (1, dim) float32 matrix as FAISS expects.
        return unit.reshape(1, -1)

    def add(self, blog_id: int, vector: np.ndarray, *, persist: bool = True) -> None:
        """Append one embedding for *blog_id* to the index.

        Args:
            blog_id: Identifier returned by search() for this vector.
            vector: Embedding of length ``dim``; normalized before insertion.
            persist: When True (the default, matching previous behaviour) the
                index is written to disk immediately. Pass False during bulk
                ingest and call save() once at the end to avoid rewriting the
                whole index on every insert.
        """
        row = self._prepare(vector)
        self.index.add(row)
        self.blog_ids.append(blog_id)
        if persist:
            self.save()

    def search(self, vector: np.ndarray, top_k: int) -> list:
        """Return up to *top_k* most similar blogs to *vector*.

        Returns:
            A list of ``{"blogId": ..., "score": float}`` dicts in the order
            FAISS returns them. Fewer than *top_k* entries are returned when
            the index holds fewer vectors; an empty list when ``top_k <= 0``.
        """
        if top_k <= 0:
            return []
        row = self._prepare(vector)
        scores, indices = self.index.search(row, top_k)
        # FAISS pads missing neighbours with index -1; skip those slots.
        return [
            {"blogId": self.blog_ids[idx], "score": float(score)}
            for score, idx in zip(scores[0], indices[0])
            if idx != -1
        ]

    def save(self) -> None:
        """Persist the index and blog id list.

        Each file is written to a temporary sibling first and then moved into
        place with os.replace, so a crash mid-write never leaves a truncated
        index or metadata file behind.
        """
        index_tmp = f"{self.index_path}.tmp"
        meta_tmp = f"{self.meta_path}.tmp"
        faiss.write_index(self.index, index_tmp)
        with open(meta_tmp, "wb") as f:
            pickle.dump(self.blog_ids, f)
        os.replace(index_tmp, self.index_path)
        os.replace(meta_tmp, self.meta_path)