amis/scripts/bot/gen_embedding.py

73 lines
2.1 KiB
Python
Raw Normal View History

import sys
import os
import glob
import uuid
import pickle
from embedding import get_embedding
from split_markdown import split_markdown
from vector_store import get_client
# Validate the command-line argument before touching the vector store:
# reset() below is destructive, so a missing argument must not wipe the
# existing data and only then fail with an IndexError.
if len(sys.argv) < 2:
    sys.exit("usage: %s <doc_dir>" % os.path.basename(sys.argv[0]))
# Directory that is scanned for Markdown documents.
doc_dir = sys.argv[1]

chroma_client = get_client()
# Reset on every run so repeated executions do not create duplicates.
chroma_client.reset()
collection = chroma_client.create_collection(name="amis")

# All Markdown blocks keyed by a generated id; kept so the text can be
# used later for LLM generation.
text_blocks_by_id = {}

# Embedding cache keyed by text. Computing an embedding is fast today,
# but it would become slow if it is later replaced by a network request.
embedding_cache = {}
embedding_cache_file = os.path.join(
    os.path.dirname(__file__), 'embedding.pickle')
if os.path.exists(embedding_cache_file):
    # NOTE(security): pickle.load can execute arbitrary code. This file is
    # expected to be one written by a previous run of this script; never
    # point it at data from an untrusted source.
    with open(embedding_cache_file, 'rb') as f:
        embedding_cache = pickle.load(f)
def get_embedding_with_cache(text):
    """Return the embedding for *text*, reusing a cached value if present.

    On a cache miss the embedding is computed with ``get_embedding``,
    converted with ``.tolist()``, stored in the module-level
    ``embedding_cache`` under *text*, and returned.
    """
    try:
        return embedding_cache[text]
    except KeyError:
        pass
    vector = get_embedding(text).tolist()
    embedding_cache[text] = vector
    return vector
# Index every Markdown file below doc_dir into the collection.
# os.path.join keeps "**" as its own path component whether or not doc_dir
# ends with a separator; plain string concatenation turned "docs" into the
# pattern "docs**/*.md", which does not recurse into subdirectories.
markdown_pattern = os.path.join(doc_dir, '**', '*.md')
for filename in glob.iglob(markdown_pattern, recursive=True):
    # Read explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding (assumes the docs are UTF-8 encoded).
    with open(filename, encoding='utf-8') as f:
        content = f.read()
    md_blocks = split_markdown(content, filename)

    # Parallel lists, one entry per text block, sent in a single
    # collection.add() call per file.
    embeddings = []
    documents = []
    metadatas = []
    ids = []
    for block in md_blocks:
        # Each Markdown block gets a random id; its text blocks are
        # addressed as "<block_id>_<n>" with n starting at 1.
        block_id = uuid.uuid4().hex
        text_blocks_by_id[block_id] = block
        for index, text_block in enumerate(block.get_text_blocks(), start=1):
            embeddings.append(get_embedding_with_cache(text_block))
            documents.append(text_block)
            ids.append(block_id + "_" + str(index))
            metadatas.append({"source": block.file_name})

    # A file may yield no text blocks; skip the call rather than sending an
    # empty batch (presumably rejected by some chromadb versions - confirm).
    if ids:
        collection.add(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )
# Persist the collected Markdown blocks and the embedding cache next to this
# script, in that order.
text_store_file = os.path.join(os.path.dirname(__file__), 'text.pickle')
for target_path, payload in (
    (text_store_file, text_blocks_by_id),
    (embedding_cache_file, embedding_cache),
):
    with open(target_path, 'wb') as out:
        pickle.dump(payload, out, pickle.HIGHEST_PROTOCOL)