mirror of
https://gitee.com/baidu/amis.git
synced 2024-11-30 02:58:05 +08:00
73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
import sys
|
|
import os
|
|
import glob
|
|
import uuid
|
|
import pickle
|
|
from embedding import get_embedding
|
|
from split_markdown import split_markdown
|
|
from vector_store import get_client
|
|
|
|
# Validate the command line BEFORE touching the vector store: the reset below
# wipes all existing data, so a missing or mistyped docs path must not destroy
# the current index and then crash.
if len(sys.argv) < 2:
    sys.exit("usage: {} <doc_dir>".format(sys.argv[0]))

# Root directory that is searched recursively for Markdown (*.md) files.
doc_dir = sys.argv[1]

if not os.path.isdir(doc_dir):
    sys.exit("error: doc_dir is not a directory: {}".format(doc_dir))

chroma_client = get_client()

# Cleared on every run so that re-indexing never produces duplicate entries.
chroma_client.reset()

collection = chroma_client.create_collection(name="amis")

# All text blocks keyed by their generated id, kept for later use in LLM
# generation (this script pickles them to text.pickle once indexing is done).
text_blocks_by_id = {}

# Embedding cache (text -> embedding). Computing embeddings is fast at the
# moment, but it would become slow if get_embedding were switched to network
# requests later on.
embedding_cache = {}

embedding_cache_file = os.path.join(
    os.path.dirname(__file__), 'embedding.pickle')

# NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
# because the cache file is produced by this script in its own directory;
# never point it at an untrusted file.
if os.path.exists(embedding_cache_file):
    try:
        with open(embedding_cache_file, 'rb') as f:
            embedding_cache = pickle.load(f)
    except (EOFError, pickle.UnpicklingError) as err:
        # A truncated/corrupt cache must not block re-indexing: it is only a
        # cache, so start from scratch and let it be rebuilt.
        print("warning: ignoring unreadable embedding cache {}: {}".format(
            embedding_cache_file, err), file=sys.stderr)
        embedding_cache = {}


def get_embedding_with_cache(text):
    """Return the embedding for *text*, computing it at most once per text.

    Results are memoized in the module-level ``embedding_cache`` dict. On a
    cache miss, ``get_embedding(text)`` is called and its ``.tolist()`` result
    is stored under *text* before being returned.

    Args:
        text: The text segment to embed; also used as the cache key.

    Returns:
        The cached value, i.e. ``get_embedding(text).tolist()`` as computed
        the first time this text was seen.
    """
    if text not in embedding_cache:
        embedding_cache[text] = get_embedding(text).tolist()
    return embedding_cache[text]


# Index every Markdown file under doc_dir, recursively. Each file is split into
# blocks by split_markdown; every block gets a random hex id, and each of its
# text segments is embedded and stored with id "<block_id>_<n>" (n from 1).
#
# os.path.join guarantees that "**" is a path component of its own, which glob
# requires for recursive matching, even when doc_dir has no trailing
# separator. glob.escape keeps glob metacharacters in the directory name literal.
md_pattern = os.path.join(glob.escape(doc_dir), '**', '*.md')
for filename in glob.iglob(md_pattern, recursive=True):
    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(filename, encoding='utf-8') as f:
        content = f.read()
    md_blocks = split_markdown(content, filename)
    embeddings = []
    documents = []
    metadatas = []
    ids = []
    for block in md_blocks:
        block_id = uuid.uuid4().hex
        text_blocks_by_id[block_id] = block
        for index, text_block in enumerate(block.get_text_blocks(), start=1):
            embeddings.append(get_embedding_with_cache(text_block))
            documents.append(text_block)
            ids.append(block_id + "_" + str(index))
            metadatas.append({"source": block.file_name})
    # A file may yield no text segments. Skip the call then, since chromadb
    # may reject an empty id list and there is nothing to add anyway.
    if ids:
        collection.add(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )


def _dump_pickle_atomically(obj, path):
    """Pickle *obj* to *path* without leaving a half-written file behind.

    The data is first written to a sibling ``.tmp`` file and then moved over
    *path* with ``os.replace``. If the process dies mid-write, the previous
    file stays intact instead of being truncated. This matters for
    embedding.pickle, which is loaded at start-up.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


# Text blocks keyed by id, saved for later LLM generation.
_dump_pickle_atomically(
    text_blocks_by_id,
    os.path.join(os.path.dirname(__file__), 'text.pickle'))

# Persist the embedding cache so later runs can reuse already-computed embeddings.
_dump_pickle_atomically(embedding_cache, embedding_cache_file)