chore: Wenxin Qianfan (文心千帆) QA bot example (#7308)

This commit is contained in:
吴多益 2023-06-30 11:42:09 +08:00 committed by GitHub
parent a49dd02991
commit 55622c1f25
16 changed files with 607 additions and 0 deletions

View File

@ -10,5 +10,9 @@ insert_final_newline = true
indent_style = space
indent_size = 2
[*.py]
indent_style = space
indent_size = 4
[*.md]
trim_trailing_whitespace = false

View File

@ -0,0 +1 @@
base

7
scripts/bot/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
db
__pycache__
text.pickle
embedding.pickle
.env
m3e-base
flagged

10
scripts/bot/Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM amis_bot_base:v2
WORKDIR /app
EXPOSE 7860
COPY . .
CMD [ "python", "./gui.py" ]

3
scripts/bot/README.md Normal file
View File

@ -0,0 +1,3 @@
# LLM-based QA bot example
Requires Wenxin Qianfan AK/SK credentials.
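For local runs the keys go into a .env file (it is listed in .gitignore and loaded by gui.py via python-dotenv). A minimal sketch with placeholder values; the variable names are the ones actually read by llm/wenxin.py and embedding.py:

# scripts/bot/.env -- never commit real credentials
WENXIN_AK=your-qianfan-api-key          # placeholder
WENXIN_SK=your-qianfan-secret-key       # placeholder
WENXIN_USER_ID=your-user-id             # placeholder
# optional: point embedding.py at a local copy of the model
EMBEDDING_MODEL=./m3e-base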

14
scripts/bot/app.py Normal file
View File

@ -0,0 +1,14 @@
from fastapi import FastAPI
app = FastAPI()
@app.get("/")
def read_root():
return "use /query/{query}"
@app.get("/amis-doc-query/{query}")
def read_item(query: str):
return {"item_id": "d"}
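The endpoint above is still a stub. A hedged sketch of how it could reuse the same retrieval-plus-LLM flow that gui.py implements (the imports mirror gui.py; the prompt here is simplified for illustration, not the shipped get_prompt):

from fastapi import FastAPI
from dotenv import load_dotenv

from embedding import get_embedding
from vector_store import get_client
from llm.wenxin import Wenxin, ModelName

load_dotenv()
app = FastAPI()
collection = get_client().get_collection(name="amis")
wenxin = Wenxin()


@app.get("/amis-doc-query/{query}")
def amis_doc_query(query: str):
    # retrieve nearby document chunks, then let the LLM answer from them
    result = collection.query(
        query_embeddings=get_embedding(query).tolist(), n_results=10)
    context = "\n\n".join(result["documents"][0])
    answer = wenxin.generate(
        f"资料:\n{context}\n问题是:{query}\n回答:", ModelName.ERNIE_BOT)
    return {"query": query, "answer": answer}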

View File

@ -0,0 +1,8 @@
FROM python:3.11.4-bookworm
WORKDIR /app
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
COPY m3e-base m3e-base
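This base image provides the amis_bot_base:v2 tag that the GUI Dockerfile shown earlier builds on. A hedged build-and-run sequence from scripts/bot (this diff does not show the base image's file name, so Dockerfile.base below is a stand-in):

docker build -t amis_bot_base:v2 -f Dockerfile.base .   # Dockerfile.base is hypothetical
docker build -t amis-bot .                              # uses scripts/bot/Dockerfile
docker run --env-file .env -p 7860:7860 amis-bot        # gradio listens on 7860 (EXPOSE above)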

View File

@ -0,0 +1,7 @@
sentence_transformers
fastapi
markdown
python-dotenv
gradio
requests
chromadb

13
scripts/bot/embedding.py Normal file
View File

@ -0,0 +1,13 @@
import os

from sentence_transformers import SentenceTransformer

# the model can be overridden with the EMBEDDING_MODEL env var
model_name = os.getenv('EMBEDDING_MODEL', 'moka-ai/m3e-base')
model = SentenceTransformer(model_name)


def get_embedding(text):
    return model.encode(text)
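A quick usage sketch; util.cos_sim is the stock sentence_transformers helper for comparing sentence vectors, and the sample sentences are made up:

from sentence_transformers import util

from embedding import get_embedding

a = get_embedding("amis 是什么")
b = get_embedding("amis 是一个低代码前端框架")
# cosine similarity of the two vectors, in [-1, 1]
print(util.cos_sim(a, b))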

View File

@ -0,0 +1,72 @@
import sys
import os
import glob
import uuid
import pickle

from embedding import get_embedding
from split_markdown import split_markdown
from vector_store import get_client

chroma_client = get_client()
# reset on every run to avoid duplicate entries
chroma_client.reset()
collection = chroma_client.create_collection(name="amis")

doc_dir = sys.argv[1]

# keep every text block; used later to build the LLM context
text_blocks_by_id = {}

# embedding cache; encoding is fast locally today, but would be slow if it
# ever becomes a network request, so cache to disk
embedding_cache = {}
embedding_cache_file = os.path.join(
    os.path.dirname(__file__), 'embedding.pickle')
if os.path.exists(embedding_cache_file):
    with open(embedding_cache_file, 'rb') as f:
        embedding_cache = pickle.load(f)


def get_embedding_with_cache(text):
    if text in embedding_cache:
        return embedding_cache[text]
    else:
        embedding = get_embedding(text).tolist()
        embedding_cache[text] = embedding
        return embedding


# os.path.join so a doc_dir without a trailing slash still matches
for filename in glob.iglob(os.path.join(doc_dir, '**/*.md'), recursive=True):
    with open(filename) as f:
        content = f.read()
    md_blocks = split_markdown(content, filename)
    embeddings = []
    documents = []
    metadatas = []
    ids = []
    for block in md_blocks:
        block_id = uuid.uuid4().hex
        text_blocks_by_id[block_id] = block
        text_blocks = block.get_text_blocks()
        index = 1
        for text_block in text_blocks:
            embeddings.append(get_embedding_with_cache(text_block))
            documents.append(text_block)
            ids.append(block_id + "_" + str(index))
            metadatas.append({"source": block.file_name})
            index += 1
    collection.add(
        embeddings=embeddings,
        documents=documents,
        metadatas=metadatas,
        ids=ids
    )

with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'wb') as f:
    pickle.dump(text_blocks_by_id, f, pickle.HIGHEST_PROTOCOL)

with open(embedding_cache_file, 'wb') as f:
    pickle.dump(embedding_cache, f, pickle.HIGHEST_PROTOCOL)
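The indexing script takes the documentation root as its only argument and is run once before starting the GUI. The diff does not show this file's name, so import_docs.py below is a stand-in, and the path assumes the docs/zh-CN layout referenced in gui.py:

python import_docs.py ../../docs/zh-CN/   # file name and path are assumptions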

111
scripts/bot/gui.py Normal file
View File

@ -0,0 +1,111 @@
from vector_store import get_client
from split_markdown import split_markdown
from embedding import get_embedding
import gradio as gr
import os
import pickle
from llm.wenxin import Wenxin, ModelName
from dotenv import load_dotenv

load_dotenv()

chroma_client = get_client()
collection = chroma_client.get_collection(name="amis")

wenxin = Wenxin()

# text blocks indexed by id, produced by the indexing script
text_blocks_by_id = {}
with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'rb') as f:
    text_blocks_by_id = pickle.load(f)


def get_prompt(context, query):
    return f"""
请只根据下面的资料回答问题,如果无法根据这些资料回答,回答“找不到相关答案”
资料:
{context}
问题是:{query}
回答:"""


def get_context(search_result, include_code=True, max_length=1024):
    """concatenate retrieved blocks until max_length is reached"""
    context = ""
    doc_ids = []
    # ids look like "<block_id>_<n>"; deduplicate by block id
    for doc_id in search_result['ids'][0]:
        doc_id = doc_id.split("_")[0]
        if doc_id not in doc_ids:
            doc_ids.append(doc_id)
    for doc_id in doc_ids:
        markdown_block = text_blocks_by_id[doc_id]
        block_text = markdown_block.gen_text(512, include_code)
        if (len(context) + len(block_text)) < max_length:
            context += block_text + "\n\n"
    return context


query = gr.Textbox(label="问题")
include_code = gr.Checkbox(value=True, label="提示词中是否要包含 amis schema",
                           info="包含的好处是大模型会返回 json,但也会导致内容太长,只能提供少量段落给大模型,导致错过重要资料")
n_result = gr.Number(
    value=10, precision=0, label="向量搜索查询返回个数")
bot_result = gr.Textbox(label="文心的回答")
bot_turbo_result = gr.Textbox(label="文心 Turbo 的回答")
bloomz_result = gr.Textbox(label="开源 BLOOMZ 的回答")
prompt = gr.Textbox(label="提示词")
vector_search_result = gr.Dataframe(
    label="向量相关搜索结果,这个结果只是为了辅助调试,确认是因为没找到相关内容还是大模型没能理解",
    headers=["相关段落", "所属文档"],
    datatype=["str", "str"],
    col_count=(2, "dynamic"),
    wrap=True
)


def amis_search(query, n_result=10, include_code=True):
    if query.strip() == "":
        return "必须有输入", "", []
    search_result = collection.query(
        query_embeddings=get_embedding(query).tolist(),
        n_results=n_result
    )
    context = get_context(search_result, include_code)
    if context == "":
        return "检索不到相关内容", "", []
    prompt = get_prompt(context, query)
    bot_result = wenxin.generate(prompt, ModelName.ERNIE_BOT)
    # bloomz_result = wenxin.generate(prompt, ModelName.BLOOMZ_7B)
    markdown_blocks = []
    index = 0
    for doc in search_result['documents'][0]:
        markdown_block = [doc]
        if index < len(search_result['metadatas'][0]):
            source = search_result['metadatas'][0][index]['source'].replace(
                'docs/zh-CN/', '')
            markdown_block.append(source)
        else:
            print("index out of range", doc)
        markdown_blocks.append(markdown_block)
        index += 1
    return bot_result, prompt, markdown_blocks


demo = gr.Interface(amis_search, title="amis 文档问答机器人", inputs=[
                    query, n_result, include_code], outputs=[bot_result, prompt, vector_search_result])

if __name__ == '__main__':
    demo.queue(concurrency_count=10).launch(share=False, server_name="0.0.0.0")
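Outside Docker, a plausible local run looks like this (assumes the index and .env already exist; 7860 is gradio's default port, matching the EXPOSE in the Dockerfile):

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
python gui.py   # then open http://localhost:7860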

7
scripts/bot/llm/base.py Normal file
View File

@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class LLM(ABC):
    @abstractmethod
    def generate(self, prompt: str) -> str:
        pass

71
scripts/bot/llm/wenxin.py Normal file
View File

@ -0,0 +1,71 @@
from .base import LLM
import requests
from functools import lru_cache
import time
import os
import json
from enum import Enum
from requests.exceptions import HTTPError

base_url = 'https://aip.baidubce.com'


class ModelName(Enum):
    ERNIE_BOT = 1
    ERNIE_BOT_TURBO = 2
    BLOOMZ_7B = 3


def get_ttl_hash(seconds=3600):
    """cache for one hour"""
    return round(time.time() / seconds)


@lru_cache(maxsize=1)
def get_token(ttl_hash=None):
    """
    fetch an access_token using the AK/SK pair
    """
    del ttl_hash  # only used to expire the lru_cache entry
    ak = os.getenv('WENXIN_AK')
    sk = os.getenv('WENXIN_SK')
    url = f'{base_url}/oauth/2.0/token?grant_type=client_credentials&client_id={ak}&client_secret={sk}'
    response = requests.get(url)
    response.raise_for_status()
    return response.json()['access_token']


def query(query: str, token: str, model_name: ModelName = ModelName.ERNIE_BOT_TURBO):
    """
    API reference:
    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
    """
    # hard-coded to a single user for now
    user_id = os.getenv('WENXIN_USER_ID')
    model_path = "eb-instant"
    if model_name == ModelName.BLOOMZ_7B:
        model_path = "bloomz_7b1"
    elif model_name == ModelName.ERNIE_BOT:
        model_path = "completions"
    url = f'{base_url}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_path}?access_token={token}&user_id={user_id}'
    messages = [{'role': 'user', 'content': query}]
    payload = {
        'messages': messages
    }
    headers = {
        'content-type': 'application/json',
    }
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    response.raise_for_status()
    return response.json()["result"]


class Wenxin(LLM):
    """Wenxin Qianfan LLM client"""

    def generate(self, prompt: str, model_name: ModelName = ModelName.ERNIE_BOT_TURBO) -> str:
        try:
            return query(prompt, get_token(ttl_hash=get_ttl_hash()), model_name)
        except HTTPError as e:
            return f'HTTPError: {e}'
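A minimal usage sketch, assuming the AK/SK variables are set (for example via the .env file shown earlier):

from dotenv import load_dotenv

from llm.wenxin import Wenxin, ModelName

load_dotenv()  # provides WENXIN_AK / WENXIN_SK / WENXIN_USER_ID

wenxin = Wenxin()
# ERNIE_BOT_TURBO is the default; ModelName.ERNIE_BOT selects the larger model
print(wenxin.generate("你好,介绍一下你自己", ModelName.ERNIE_BOT))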

View File

@ -0,0 +1,54 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/nwind/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n",
      "No sentence-transformers model found with name /Volumes/WD/ai/models/m3e-base. Creating a new one with MEAN pooling.\n"
     ]
    }
   ],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "from embedding import get_embedding\n",
    "from split_markdown import split_markdown\n",
    "from vector_store import get_client\n",
    "\n",
    "chroma_client = get_client()\n",
    "collection = chroma_client.get_collection(name=\"amis\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

View File

@ -0,0 +1,210 @@
import sys
import re
from enum import Enum
from dataclasses import dataclass, field
from io import StringIO

from markdown import Markdown


def unmark_element(element, stream=None):
    if stream is None:
        stream = StringIO()
    if element.text:
        stream.write(element.text)
    for sub in element:
        unmark_element(sub, stream)
    if element.tail:
        stream.write(element.tail)
    return stream.getvalue()


# patching Markdown
Markdown.output_formats["plain"] = unmark_element
__md = Markdown(output_format="plain")
__md.stripTopLevelTags = False


def unmark(text):
    """strip markdown styling, see https://stackoverflow.com/questions/761824/python-how-to-convert-markdown-formatted-text-to-text"""
    return __md.convert(text)


class ContentType(Enum):
    Text = 1
    Code = 2


# paragraphs longer than this get split further
LONG_CONTENT_LENGTH = 20


@dataclass
class BlockContent:
    """text or code inside a block"""
    type: ContentType
    text: str


@dataclass
class MarkdownBlock:
    """a document block; the smallest unit fed to the LLM as context"""
    # source file name
    file_name: str
    # document title
    title: str = ""
    # second- or third-level heading
    header: str = ""
    # content, a mix of text and code segments
    content: list[BlockContent] = field(default_factory=list)

    def gen_text(self, max_length: int = 500, include_code=True) -> str:
        """render the block as text"""
        current_length = 0
        output = self.header + "\n\n" if self.header else ""
        for para in self.content:
            content = para.text
            # stop once the limit is hit; this ignores the ~10 extra
            # characters the ``` fences add around code
            if current_length + len(content) > max_length:
                break
            if para.type == ContentType.Code and include_code:
                output += f"\n```\n{content}\n```\n"
            else:
                output += content + "\n"
            current_length += len(content)
        return output

    def get_text_blocks(self) -> list[str]:
        """the list of text passages used to generate embeddings"""
        blocks: list[str] = []
        header = self.header.replace("#", "") if self.header else ""
        if header != "":
            if len(header) < 4:
                blocks.append(self.title + header)
            else:
                blocks.append(header)
        all_text = ""
        for para in self.content:
            if para.type == ContentType.Text:
                # strip styles and images so they don't skew the embedding
                text = unmark(para.text)
                all_text += text
                blocks.append(self.title + header + text)
                blocks.append(text)
                # split overly long paragraphs into sentences
                if len(text) > LONG_CONTENT_LENGTH:
                    for line in text.split("。"):
                        blocks.append(line)
        if len(all_text) < LONG_CONTENT_LENGTH:
            blocks.append(header + all_text)
        # drop duplicates and empty strings
        output_blocks = set()
        for block in blocks:
            block = block.strip()
            if block != "" and block not in output_blocks:
                output_blocks.add(block)
        return list(output_blocks)


def split_markdown(markdown_text: str, file_name: str) -> list[MarkdownBlock]:
    """
    split a markdown document into blocks
    """
    markdown_text = markdown_text.replace("\r\n", "\n").replace("\r", "\n")
    # document title
    title = ""
    lines = markdown_text.split("\n")
    # markdown blocks
    blocks: list[MarkdownBlock] = []
    # current second-level heading
    current_header = None
    current_content: list[BlockContent] = []
    # code lines get merged into a single segment, so collect them first
    current_code: list[str] = []
    # whether we are inside a code block
    in_code_block = False
    # whether we are inside the front matter
    in_meta = False
    for line in lines:
        # handle document front matter
        if line.startswith("---"):
            in_meta = not in_meta
            continue
        if in_meta and ":" in line:
            key, value = line.split(":", 1)
            if key == "title":
                title = value.strip()
            continue
        # version notes, not useful
        if line.startswith("> ") and "以上版本" in line:
            continue
        if line.startswith(">"):
            line = line.replace(">", "")
        if line.strip() == "":
            continue
        header_match = re.match(r"^#+\s", line)
        # found a heading
        if header_match:
            # a previous heading means the block before it is complete
            if current_header is not None:
                # only keep it if there is content or code
                if len(current_content) > 0:
                    blocks.append(MarkdownBlock(file_name, title,
                                                current_header, current_content))
                current_content = []
                current_code = []
            # start parsing a new block
            current_header = line
        else:
            # text before any heading
            if current_header is None:
                current_content.append(BlockContent(ContentType.Text, line))
                blocks.append(MarkdownBlock(file_name, title,
                                            current_header, current_content))
                current_content = []
            else:
                # inside a block: code fence or text
                if line.startswith("```"):
                    in_code_block = not in_code_block
                    if not in_code_block:
                        current_content.append(BlockContent(
                            ContentType.Code, "\n".join(current_code)))
                        current_code = []
                else:
                    if in_code_block:
                        current_code.append(line)
                    else:
                        current_content.append(
                            BlockContent(ContentType.Text, line))
    if len(current_content) > 0 or len(current_code) > 0:
        blocks.append(MarkdownBlock(file_name, title,
                                    current_header, current_content))
    return blocks


def test(file_name: str):
    with open(file_name) as f:
        content = f.read()
    blocks = split_markdown(content, file_name)
    for block in blocks:
        print(block.gen_text())


if __name__ == '__main__':
    test(sys.argv[1])
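A quick sketch of how split_markdown is consumed (the markdown snippet and file path are made up for illustration):

from split_markdown import split_markdown

doc = """---
title: 示例
---

## 配置项

这是一段说明文字。
"""
blocks = split_markdown(doc, "docs/zh-CN/example.md")
for block in blocks:
    print(block.header)             # "## 配置项"
    print(block.get_text_blocks())  # deduplicated passages for embedding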

View File

@ -0,0 +1,15 @@
import os

import chromadb
from chromadb.config import Settings

client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet",
                                  persist_directory=os.path.join(
                                      os.path.dirname(__file__), "db")
                                  ))


def get_client():
    return client


def search(query: str) -> str:
    # placeholder, not used yet; queries go through the collection in gui.py
    return "d"