chore: 去掉 bot 示例中的 pickle 依赖 (#9640)

This commit is contained in:
吴多益 2024-02-22 14:45:20 +08:00 committed by GitHub
parent 7e03085faf
commit dd34bf9883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 12 deletions

View File

@ -1,7 +1,7 @@
db
__pycache__
text.pickle
embedding.pickle
text.json
embedding.json
.env
m3e-base
flagged

View File

@ -2,7 +2,7 @@ import sys
import os
import glob
import uuid
import pickle
import json
from embedding import get_embedding
from split_markdown import split_markdown
from vector_store import get_client
@ -21,11 +21,11 @@ text_blocks_by_id = {}
embedding_cache = {}
embedding_cache_file = os.path.join(
os.path.dirname(__file__), 'embedding.pickle')
os.path.dirname(__file__), 'embedding.json')
if os.path.exists(embedding_cache_file):
with open(embedding_cache_file, 'rb') as f:
embedding_cache = pickle.load(f)
embedding_cache = json.load(f)
def get_embedding_with_cache(text):
@ -65,8 +65,8 @@ for filename in glob.iglob(doc_dir + '**/*.md', recursive=True):
)
with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'wb') as f:
pickle.dump(text_blocks_by_id, f, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'w') as f:
json.dump(text_blocks_by_id, f)
with open(embedding_cache_file, 'wb') as f:
pickle.dump(embedding_cache, f, pickle.HIGHEST_PROTOCOL)
with open(embedding_cache_file, 'w') as f:
json.dump(embedding_cache, f)

View File

@ -3,7 +3,7 @@ from split_markdown import split_markdown
from embedding import get_embedding
import gradio as gr
import os
import pickle
import json
from llm.wenxin import Wenxin, ModelName
from dotenv import load_dotenv
load_dotenv()
@ -15,8 +15,8 @@ collection = chroma_client.get_collection(name="amis")
wenxin = Wenxin()
text_blocks_by_id = {}
with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'rb') as f:
text_blocks_by_id = pickle.load(f)
with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'rb') as f:
text_blocks_by_id = json.load(f)
def get_prompt(context, query):