# amis2/scripts/bot/split_markdown.py
#
# (Repository viewer metadata: 211 lines, 6.3 KiB, Python)

import sys
import re
from enum import Enum
from dataclasses import dataclass, field
from markdown import Markdown
from io import StringIO
def unmark_element(element, stream=None):
    """Serialize an element tree to plain text, dropping all markup.

    For every node the node's own text is written first, then each child
    (recursively), then the node's tail text.

    Args:
        element: ElementTree-style node exposing ``text``, ``tail`` and
            iteration over its children.
        stream: Buffer to write into; a fresh ``StringIO`` is created when
            omitted.

    Returns:
        The complete contents of the buffer after *element* was written.
    """
    buffer = StringIO() if stream is None else stream

    def emit(node):
        # Leading text of the node itself.
        head = node.text
        if head:
            buffer.write(head)
        # Children, depth first, in document order.
        for child in node:
            emit(child)
        # Text that follows the node's closing tag.
        trailing = node.tail
        if trailing:
            buffer.write(trailing)

    emit(element)
    return buffer.getvalue()
# Patch Markdown: register a "plain" output format whose serializer is
# unmark_element, so converting a document yields text without markup.
Markdown.output_formats["plain"] = unmark_element
# Module-level converter instance used by unmark().
__md = Markdown(output_format="plain")
# NOTE(review): presumably disables Markdown's stripping of the top-level
# wrapper tags, which the plain serializer never emits — confirm against the
# Python-Markdown documentation.
__md.stripTopLevelTags = False
def unmark(text):
    """Strip Markdown styling from *text* and return the plain text.

    Based on https://stackoverflow.com/questions/761824/python-how-to-convert-markdown-formatted-text-to-text

    Args:
        text: Markdown-formatted source text.

    Returns:
        The text content of *text* with the Markdown markup removed.
    """
    # The converter instance is shared by every call. reset() clears state
    # left over from the previous conversion (e.g. link reference
    # definitions) so one text cannot influence the result for the next.
    return __md.reset().convert(text)
class ContentType(Enum):
    """Kind of content stored in a BlockContent: prose text or code."""
    # Ordinary Markdown text
    Text = 1
    # Content of a fenced code block
    Code = 2
# Paragraph length threshold (in characters): get_text_blocks splits text
# longer than this, and adds a combined header+text entry when the total text
# of a block is shorter than this.
LONG_CONTENT_LENGTH = 20
@dataclass
class BlockContent:
    """A piece of content inside a document block: either text or code."""
    # Whether this piece is text or code
    type: ContentType
    # The content string itself
    text: str
@dataclass
class MarkdownBlock:
    """A document section: the smallest unit handed to the LLM as context."""
    # Name of the source file
    file_name: str
    # Document title
    title: str = ""
    # Heading of this section (second- or third-level heading); stored as the
    # raw heading line and may be empty/None for text before the first heading
    header: str = ""
    # Section content: a sequence of text paragraphs and code snippets
    content: list[BlockContent] = field(default_factory=list)

    def gen_text(self, max_length: int = 500, include_code=True) -> str:
        """Render the block as text: the header followed by its content.

        Args:
            max_length: Budget for the accumulated length of the content
                pieces. Rendering stops at the first piece that would exceed
                it. The header and the code fence characters are not counted.
            include_code: When False, code pieces are skipped entirely.

        Returns:
            The header (if any) followed by a blank line, then each text
            piece on its own line and each code piece wrapped in a fence.
        """
        parts: list[str] = [self.header + "\n\n"] if self.header else []
        current_length = 0
        for para in self.content:
            is_code = para.type == ContentType.Code
            # Excluded code must be dropped, not emitted as plain text, and
            # must not consume any of the length budget.
            if is_code and not include_code:
                continue
            content = para.text
            # Stop once the budget is exceeded; the roughly 10 extra
            # characters of a code fence are deliberately not counted.
            if current_length + len(content) > max_length:
                break
            if is_code:
                parts.append(f"\n```\n{content}\n```\n")
            else:
                parts.append(content + "\n")
            current_length += len(content)
        return "".join(parts)

    def get_text_blocks(self) -> list[str]:
        """Return the text snippets used to generate embeddings.

        Produces, in order: the header (prefixed with the document title when
        the header is very short), and for every text piece the plain text
        with and without the title/header prefix, plus its individual
        sentences when the text is long. When the block's total text is
        short, header and text are also added as one combined snippet.

        Returns:
            Unique, non-empty, whitespace-stripped snippets in the order they
            were first produced.
        """
        blocks: list[str] = []
        header = self.header.replace("#", "") if self.header else ""
        if header != "":
            # A very short header carries little meaning on its own, so the
            # document title is put in front of it.
            if len(header) < 4:
                blocks.append(self.title + header)
            else:
                blocks.append(header)
        all_text = ""
        for para in self.content:
            if para.type == ContentType.Text:
                # Remove styling and images so they do not affect embeddings.
                text = unmark(para.text)
                all_text += text
                blocks.append(self.title + header + text)
                blocks.append(text)
                # Long paragraphs are additionally split into sentences.
                # str.split("") raises ValueError (empty separator), so split
                # on sentence-ending punctuation and newlines instead.
                # NOTE(review): the originally intended separator is presumed
                # to be the CJK full stop — confirm.
                if len(text) > LONG_CONTENT_LENGTH:
                    blocks.extend(re.split(r"[。!?\n]+", text))
        if len(all_text) < LONG_CONTENT_LENGTH:
            blocks.append(header + all_text)
        # Drop empty strings and duplicates; dict.fromkeys keeps first-seen
        # order, so the result is deterministic (a set is not).
        stripped = (block.strip() for block in blocks)
        return list(dict.fromkeys(block for block in stripped if block))
def split_markdown(markdown_text: str, file_name: str) -> list[MarkdownBlock]:
    """Split a Markdown document into blocks, one per heading section.

    Processing rules:
      * Front matter delimited by ``---`` lines is recognised only at the
        start of the document; its ``title`` entry becomes the title of the
        blocks. Any other line starting with ``---`` is dropped.
      * Lines inside a fenced code block are collected verbatim and stored as
        a single code piece; no other rule is applied to them.
      * Version notes (quoted lines containing "以上版本") and blank lines are
        dropped; a leading blockquote marker is removed from quoted lines.
      * Every heading line starts a new block. A section without content
        produces no block.
      * Content before the first heading is emitted piece by piece, each as
        its own block with header ``None``.

    Args:
        markdown_text: Full Markdown source; any newline convention.
        file_name: Stored on every resulting block.

    Returns:
        The blocks in document order.
    """
    markdown_text = markdown_text.replace("\r\n", "\n").replace("\r", "\n")
    # Document title taken from the front matter
    title = ""
    blocks: list[MarkdownBlock] = []
    # Heading line of the section being collected; None before the first one
    current_header = None
    current_content: list[BlockContent] = []
    # Lines of the currently open code fence, merged into one piece on close
    current_code: list[str] = []
    in_code_block = False
    in_meta = False
    # Front matter may only open before any other non-blank line was seen
    meta_allowed = True

    def flush() -> None:
        """Emit the collected content (if any) as a block and start anew."""
        nonlocal current_content
        if current_content:
            blocks.append(MarkdownBlock(file_name, title,
                                        current_header, current_content))
        current_content = []

    for line in markdown_text.split("\n"):
        # Code is opaque: "# comment", "---", ">" or blank lines inside a
        # fence must not be interpreted as Markdown structure.
        if in_code_block:
            if line.startswith("```"):
                in_code_block = False
                current_content.append(BlockContent(
                    ContentType.Code, "\n".join(current_code)))
                current_code = []
                # Content before the first heading is emitted immediately.
                if current_header is None:
                    flush()
            else:
                current_code.append(line)
            continue

        # Front matter delimiters. A later "---" (horizontal rule) is dropped
        # but no longer switches metadata mode back on.
        if line.startswith("---"):
            if in_meta:
                in_meta = False
                meta_allowed = False
            elif meta_allowed:
                in_meta = True
            continue
        if in_meta:
            # partition() tolerates values that contain ":" themselves.
            key, sep, value = line.partition(":")
            if sep and key.strip() == "title":
                title = value.strip()
            continue
        if line.strip() != "":
            meta_allowed = False

        # Version notes carry no useful information.
        if line.startswith("> ") and "以上版本" in line:
            continue
        # Remove only the leading blockquote marker, not every ">" in the line.
        if line.startswith(">"):
            line = line.lstrip(">")
        if line.strip() == "":
            continue

        # A heading closes the previous section and opens a new one.
        if re.match(r"^#+\s", line):
            flush()
            current_header = line
            continue

        if line.startswith("```"):
            in_code_block = True
            continue

        current_content.append(BlockContent(ContentType.Text, line))
        # Text before the first heading: every line becomes its own block.
        if current_header is None:
            flush()

    # Keep the code of a fence that was never closed.
    if current_code:
        current_content.append(BlockContent(
            ContentType.Code, "\n".join(current_code)))
    flush()
    return blocks
def test(file_name: str):
    """Split the Markdown file at *file_name* and print every block's text.

    Args:
        file_name: Path of the Markdown file to read.
    """
    # The documents contain Chinese text; do not rely on the platform's
    # default encoding.
    with open(file_name, encoding="utf-8") as f:
        content = f.read()
    for block in split_markdown(content, file_name):
        # MarkdownBlock renders its text via gen_text(); it has no getText().
        print(block.gen_text())
if __name__ == '__main__':
    # Manual check: print the blocks of the Markdown file given as argument.
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <markdown-file>")
    test(sys.argv[1])