diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 63bab8cd5..3505615a2 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,6 +1,7 @@ import logging from typing import List +import numpy as np from langchain.embeddings.base import Embeddings from sqlalchemy.exc import IntegrityError @@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings): embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) except Exception as ex: raise self._embeddings.handle_exceptions(ex) - i = 0 + normalized_embedding_results = [] for text in embedding_queue_texts: hash = helper.generate_text_hash(text) try: embedding = Embedding(model_name=self._embeddings.name, hash=hash) - embedding.set_embedding(embedding_results[i]) + vector = embedding_results[i] + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + normalized_embedding_results.append(normalized_embedding) + embedding.set_embedding(normalized_embedding) db.session.add(embedding) db.session.commit() except IntegrityError: @@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings): finally: i += 1 - text_embeddings.extend(embedding_results) + text_embeddings.extend(normalized_embedding_results) return text_embeddings def embed_query(self, text: str) -> List[float]: @@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings): try: embedding_results = self._embeddings.client.embed_query(text) + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() except Exception as ex: raise self._embeddings.handle_exceptions(ex) @@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings): return embedding_results -