From dd34bf988360bf38b6d423211dcf2dd30fb5e211 Mon Sep 17 00:00:00 2001
From: 吴多益
Date: Thu, 22 Feb 2024 14:45:20 +0800
Subject: [PATCH] chore: remove the pickle dependency from the bot example (#9640)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/bot/.gitignore       |  4 ++--
 scripts/bot/gen_embedding.py | 14 +++++++-------
 scripts/bot/gui.py           |  6 +++---
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/scripts/bot/.gitignore b/scripts/bot/.gitignore
index 9734ab800..b29a9675b 100644
--- a/scripts/bot/.gitignore
+++ b/scripts/bot/.gitignore
@@ -1,7 +1,7 @@
 db
 __pycache__
-text.pickle
-embedding.pickle
+text.json
+embedding.json
 .env
 m3e-base
 flagged
diff --git a/scripts/bot/gen_embedding.py b/scripts/bot/gen_embedding.py
index abb02dd40..717b9019f 100644
--- a/scripts/bot/gen_embedding.py
+++ b/scripts/bot/gen_embedding.py
@@ -2,7 +2,7 @@ import sys
 import os
 import glob
 import uuid
-import pickle
+import json
 from embedding import get_embedding
 from split_markdown import split_markdown
 from vector_store import get_client
@@ -21,11 +21,11 @@ text_blocks_by_id = {}
 
 embedding_cache = {}
 embedding_cache_file = os.path.join(
-    os.path.dirname(__file__), 'embedding.pickle')
+    os.path.dirname(__file__), 'embedding.json')
 
 if os.path.exists(embedding_cache_file):
     with open(embedding_cache_file, 'rb') as f:
-        embedding_cache = pickle.load(f)
+        embedding_cache = json.load(f)
 
 
 def get_embedding_with_cache(text):
@@ -65,8 +65,8 @@ for filename in glob.iglob(doc_dir + '**/*.md', recursive=True):
     )
 
 
-with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'wb') as f:
-    pickle.dump(text_blocks_by_id, f, pickle.HIGHEST_PROTOCOL)
+with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'w') as f:
+    json.dump(text_blocks_by_id, f)
 
-with open(embedding_cache_file, 'wb') as f:
-    pickle.dump(embedding_cache, f, pickle.HIGHEST_PROTOCOL)
+with open(embedding_cache_file, 'w') as f:
+    json.dump(embedding_cache, f)
diff --git a/scripts/bot/gui.py b/scripts/bot/gui.py
index 8d1618eec..7cbc43819 100644
--- a/scripts/bot/gui.py
+++ b/scripts/bot/gui.py
@@ -3,7 +3,7 @@ from split_markdown import split_markdown
 from embedding import get_embedding
 import gradio as gr
 import os
-import pickle
+import json
 from llm.wenxin import Wenxin, ModelName
 from dotenv import load_dotenv
 load_dotenv()
@@ -15,8 +15,8 @@ collection = chroma_client.get_collection(name="amis")
 
 wenxin = Wenxin()
 text_blocks_by_id = {}
-with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'rb') as f:
-    text_blocks_by_id = pickle.load(f)
+with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'rb') as f:
+    text_blocks_by_id = json.load(f)
 
 
 def get_prompt(context, query):
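
For reference, a minimal standalone sketch of the JSON-backed cache pattern this patch switches to; it is not part of the patch itself. The embed parameter stands in for the real get_embedding call and cache_file is an assumed path. Unlike pickle, json only round-trips plain dicts, lists, strings, and numbers, which is all the cache stores, and loading the file cannot execute arbitrary code.

import json
import os

# Assumed cache location, mirroring the path used in gen_embedding.py.
cache_file = os.path.join(os.path.dirname(__file__), 'embedding.json')

# Load the existing cache if present; json.load accepts a text-mode file object.
embedding_cache = {}
if os.path.exists(cache_file):
    with open(cache_file, 'r', encoding='utf-8') as f:
        embedding_cache = json.load(f)


def get_embedding_with_cache(text, embed):
    # embed is a stand-in for the real embedding call (e.g. get_embedding).
    if text not in embedding_cache:
        embedding_cache[text] = embed(text)
    return embedding_cache[text]


# Persist the cache as plain JSON instead of a pickle blob.
with open(cache_file, 'w', encoding='utf-8') as f:
    json.dump(embedding_cache, f, ensure_ascii=False)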