mirror of
https://gitee.com/baidu/amis.git
synced 2024-12-03 12:38:53 +08:00
112 lines
3.5 KiB
Python
|
from vector_store import get_client
|
|||
|
from split_markdown import split_markdown
|
|||
|
from embedding import get_embedding
|
|||
|
import gradio as gr
|
|||
|
import os
|
|||
|
import pickle
|
|||
|
from llm.wenxin import Wenxin, ModelName
|
|||
|
from dotenv import load_dotenv
|
|||
|
# Load environment variables from a .env file (presumably the Wenxin API
# credentials used by Wenxin() — verify against llm/wenxin).
load_dotenv()


# Shared handles used by every request: the "amis" vector-store collection
# queried in amis_search, and the Wenxin LLM client.
chroma_client = get_client()
collection = chroma_client.get_collection(name="amis")

wenxin = Wenxin()

# Markdown blocks keyed by block id; get_context looks them up using the part
# of each vector-store chunk id before the first "_".
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# text.pickle must be a trusted artifact produced by this project's own
# indexing step — never load a user-supplied file here.
with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'rb') as f:
    text_blocks_by_id = pickle.load(f)
|
|||
|
|
|||
|
|
|||
|
def get_prompt(context, query):
    """Return the question-answering prompt sent to the LLM.

    The model is instructed to answer only from ``context`` and to reply
    "找不到相关答案" ("no relevant answer found") otherwise. The layout is:
    a leading newline, the instruction, a blank line, the "资料:" (material)
    header, the context, a blank line, the question, and a final "回答:"
    (answer) cue.
    """
    sections = (
        "",  # yields the leading newline of the prompt
        "请只根据下面的资料回答问题,如果无法根据这些资料回答,回答“找不到相关答案”:",
        "",
        "资料:",
        f"{context}",
        "",
        f"问题是:{query}",
        "回答:",
    )
    return "\n".join(sections)
|
|||
|
|
|||
|
|
|||
|
def get_context(search_result, include_code=True, max_length=1024,
                block_max_length=512, text_blocks=None):
    """Assemble the reference-material section of the prompt.

    Each vector-store hit is a chunk id. The part before the first "_" is
    used as the key of its source markdown block. Every distinct block is
    rendered once, in order of first hit, and kept while the total context
    stays within ``max_length`` characters. The blank-line separator counts
    toward that total. A block that does not fit is skipped rather than
    ending the scan, so a later, shorter block may still be included.

    Args:
        search_result: Result of ``collection.query`` for one query
            embedding; only ``search_result['ids'][0]`` is read.
        include_code: Forwarded to ``gen_text``; whether the block text may
            include code (the amis schema).
        max_length: Upper bound, in characters, on the returned string.
        block_max_length: Limit passed to ``gen_text`` for each block.
        text_blocks: Mapping from block id to an object providing
            ``gen_text(limit, include_code)``. Defaults to the module-level
            ``text_blocks_by_id`` loaded from text.pickle.

    Returns:
        The selected block texts, each followed by a blank line, or "" when
        nothing was found or nothing fits.
    """
    blocks = text_blocks_by_id if text_blocks is None else text_blocks
    separator = "\n\n"

    # Several chunks of the same document can be hit. dict.fromkeys drops the
    # duplicates in O(n) while keeping first-seen order.
    # assumes chunk ids look like "<block id>_<suffix>" — TODO confirm
    doc_ids = dict.fromkeys(
        chunk_id.split("_")[0] for chunk_id in search_result['ids'][0])

    parts = []
    total = 0
    for doc_id in doc_ids:
        try:
            markdown_block = blocks[doc_id]
        except KeyError:
            # text.pickle and the vector store can drift apart; skip the
            # stale id instead of failing the whole request.
            print("text block not found", doc_id)
            continue
        piece = markdown_block.gen_text(block_max_length, include_code) + separator
        # Include the separator in the budget so the result never exceeds
        # max_length (the old check only counted the block text).
        if total + len(piece) <= max_length:
            parts.append(piece)
            total += len(piece)
    return "".join(parts)
|
|||
|
|
|||
|
|
|||
|
# --- Input widgets, passed to amis_search in this order ---
query = gr.Textbox(label="问题")
include_code = gr.Checkbox(
    label="提示词中是否要包含 amis schema",
    value=True,
    info="包含的好处是大模型会返回 json,但也会导致内容太长,只能提供少量段落给大模型,导致错过重要资料",
)
n_result = gr.Number(label="向量搜索查询返回个数", value=10, precision=0)


# --- Output widgets ---
# Only bot_result, prompt and vector_search_result are listed as Interface
# outputs in this file; bot_turbo_result and booomz_result are created but
# not referenced elsewhere here.
bot_result = gr.Textbox(label="文心的回答")
bot_turbo_result = gr.Textbox(label="文心 Turbo 的回答")
booomz_result = gr.Textbox(label="开源 BLOOMZ 的回答")
prompt = gr.Textbox(label="提示词")
# Debug table of raw vector-search hits: one [chunk text, source document] row each.
vector_search_result = gr.Dataframe(
    wrap=True,
    col_count=(2, "dynamic"),
    datatype=["str", "str"],
    headers=["相关段落", "所属文档"],
    label="向量相关搜索结果,这个结果只是为了辅助调试,确认是因为没找到相关内容还是大模型没能理解",
)
|
|||
|
|
|||
|
|
|||
|
def amis_search(query, n_result=10, include_code=True):
    """Answer a question about amis with retrieval-augmented generation.

    Gradio handler. It embeds ``query``, fetches the nearest chunks from the
    "amis" collection, builds a prompt from their source blocks, and asks
    the Wenxin ERNIE_BOT model for an answer.

    Args:
        query: The user's question. An empty or whitespace-only value
            (or None) short-circuits with a message.
        n_result: Number of vector-search hits to request. None falls back
            to 10, and values below 1 are raised to 1.
        include_code: Whether the prompt context may include amis schema code.

    Returns:
        A 3-tuple, one value per Interface output: the answer (or a status
        message), the prompt that was sent (or ""), and a list of
        ``[chunk text, source document]`` rows for the debug table.
    """
    # Every return yields exactly three values to match the three output
    # components. The previous early exits returned four.
    if not query or not query.strip():
        return "必须有输入", "", []

    # gr.Number yields None when the field is cleared, and the vector store
    # needs a positive integer count.
    n_results = 10 if n_result is None else max(1, int(n_result))

    search_result = collection.query(
        query_embeddings=get_embedding(query).tolist(),
        n_results=n_results
    )

    context = get_context(search_result, include_code)
    if context == "":
        return "检索不到相关内容", "", []

    prompt_text = get_prompt(context, query)
    answer = wenxin.generate(prompt_text, ModelName.ERNIE_BOT)
    return answer, prompt_text, _search_rows(search_result)


def _search_rows(search_result):
    """Return ``[chunk text, source document]`` rows for the debug table."""
    documents = search_result['documents'][0]
    metadatas = search_result['metadatas'][0]
    rows = []
    for index, doc in enumerate(documents):
        if index < len(metadatas):
            source = metadatas[index]['source'].replace('docs/zh-CN/', '')
        else:
            print("index out of range", doc)
            # Keep each row two columns wide so it matches the table headers.
            source = ""
        rows.append([doc, source])
    return rows
|
|||
|
|
|||
|
|
|||
|
# The web UI: amis_search receives (query, n_result, include_code) and fills
# the answer box, the prompt box and the debug table.
demo = gr.Interface(
    amis_search,
    inputs=[query, n_result, include_code],
    outputs=[bot_result, prompt, vector_search_result],
    title="amis 文档问答机器人",
)


if __name__ == '__main__':
    # Up to 10 queued requests are processed concurrently. Binding to 0.0.0.0
    # exposes the server on all network interfaces; no public share link is created.
    queued_demo = demo.queue(concurrency_count=10)
    queued_demo.launch(server_name="0.0.0.0", share=False)
|