diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 903953486..879c9df69 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,8 +1,7 @@ -from langchain.schema import Document - from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import DatasetQuery, DocumentSegment from models.model import DatasetRetrieverResource diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f11170d88..c8a2e0944 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,7 +9,6 @@ from typing import Optional, cast from flask import Flask, current_app from flask_login import current_user -from langchain.text_splitter import TextSplitter from sqlalchemy.orm.exc import ObjectDeletedError from core.docstore.dataset_docstore import DatasetDocumentStore @@ -24,6 +23,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from core.splitter.text_splitter import TextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 284acaf1b..71297588a 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -1,5 +1,4 @@ - -from langchain.schema import Document +from core.rag.models.document import Document class ReorderRunner: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 5fc60ba80..f28436ffd 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -5,9 +5,9 @@ from typing import Any, Optional import requests from flask import current_app from flask_login import current_user -from langchain.schema import Document from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 2ad9718ec..fcb06e5c8 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,12 +2,11 @@ from abc import ABC, abstractmethod from typing import Optional -from langchain.text_splitter import TextSplitter - from core.model_manager import ModelInstance from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.models.document import Document from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from core.splitter.text_splitter import TextSplitter from models.dataset import Dataset, DatasetProcessRule diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 024a18bad..221318c2c 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -1,4 +1,6 @@ -from typing import Optional +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Optional from pydantic import BaseModel, Field @@ -14,3 +16,64 @@ class Document(BaseModel): metadata: Optional[dict] = Field(default_factory=dict) +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + @abstractmethod + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index a675dfc56..7000f4e0a 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,8 +1,7 @@ from typing import Optional -from langchain.schema import Document - from core.model_manager import ModelInstance +from core.rag.models.document import Document class RerankRunner: diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/splitter/fixed_text_splitter.py index 285a7ba14..a3ce72e6c 100644 --- a/api/core/splitter/fixed_text_splitter.py +++ b/api/core/splitter/fixed_text_splitter.py @@ -3,7 +3,10 @@ from __future__ import annotations from typing import Any, Optional, cast -from langchain.text_splitter import ( +from core.model_manager import ModelInstance +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.splitter.text_splitter import ( TS, AbstractSet, Collection, @@ -14,10 +17,6 @@ from langchain.text_splitter import ( Union, ) -from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer - class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ diff --git a/api/core/splitter/text_splitter.py b/api/core/splitter/text_splitter.py new file mode 100644 index 000000000..e3d43c065 --- /dev/null +++ b/api/core/splitter/text_splitter.py @@ -0,0 +1,903 @@ +from __future__ import annotations + +import copy +import logging +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection, Iterable, Sequence, Set +from dataclasses import dataclass +from enum import Enum +from typing import ( + Any, + Literal, + Optional, + TypedDict, + TypeVar, + Union, +) + +from core.rag.models.document import BaseDocumentTransformer, Document + +logger = logging.getLogger(__name__) + +TS = TypeVar("TS", bound="TextSplitter") + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> list[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + ) -> None: + """Create a new TextSplitter. + + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in metadata + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + + @abstractmethod + def split_text(self, text: str) -> list[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: list[str], metadatas: Optional[list[dict]] = None + ) -> list[Document]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = -1 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + index = text.find(chunk, index + 1) + metadata["start_index"] = index + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents + + def split_documents(self, documents: Iterable[Document]) -> list[Document]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + text = separator.join(docs) + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. + separator_len = self._length_function(separator) + + docs = [] + current_doc: list[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate max_tokens_for_prompt. " + "Please install it with `pip install tiktoken`." + ) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model_name": model_name, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a sequence of documents by splitting them.""" + raise NotImplementedError + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = _split_text_with_regex(text, self._separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False + ): + """Create a new MarkdownHeaderTextSplitter. + + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + + def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks: list[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> list[Document]: + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). + lines = text.split("\n") + # Final output + lines_with_metadata: list[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: list[str] = [] + current_metadata: dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: List[Dict[str, Union[int, str]]] = [] + header_stack: list[HeaderType] = [] + initial_metadata: dict[str, str] = {} + + for line in lines: + stripped_line = line.strip() + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep):].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + {"content": "\n".join(current_content), "metadata": current_metadata} + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in lines_with_metadata + ] + + +# should be in newer Python versions (3.10+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + chunk_overlap: int + tokens_per_chunk: int + decode: Callable[[list[int]], str] + encode: Callable[[str], list[int]] + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: list[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." + ) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special + self._disallowed_special = disallowed_special + + def split_text(self, text: str) -> list[str]: + def _encode(_text: str) -> list[int]: + return self._tokenizer.encode( + _text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=_encode, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class Language(str, Enum): + """Enum of the programming languages.""" + + CPP = "cpp" + GO = "go" + JAVA = "java" + JS = "js" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + HTML = "html" + SOL = "sol" + + +class RecursiveCharacterTextSplitter(TextSplitter): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one + that works. + """ + + def __init__( + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or ["\n\n", "\n", " ", ""] + + def _split_text(self, text: str, separators: list[str]) -> list[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + if _s == "": + separator = _s + break + if re.search(_s, text): + separator = _s + new_separators = separators[i + 1:] + break + + splits = _split_text_with_regex(text, separator, self._keep_separator) + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> list[str]: + return self._split_text(text, self._separators) + + @classmethod + def from_language( + cls, language: Language, **kwargs: Any + ) -> RecursiveCharacterTextSplitter: + separators = cls.get_separators_for_language(language) + return cls(separators=separators, **kwargs) + + @staticmethod + def get_separators_for_language(language: Language) -> list[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n=+\n", + "\n-+\n", + "\n\*+\n", + # Split along directive markers + "\n\n.. *\n\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings (starting with level 2) + "\n#{1,6} ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n", + # Horizontal lines + "\n\*\*\*+\n", + "\n---+\n", + "\n___+\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\\chapter{", + "\n\\\section{", + "\n\\\subsection{", + "\n\\\subsubsection{", + # Now split by environments + "\n\\\begin{enumerate}", + "\n\\\begin{itemize}", + "\n\\\begin{description}", + "\n\\\begin{list}", + "\n\\\begin{quote}", + "\n\\\begin{quotation}", + "\n\\\begin{verse}", + "\n\\\begin{verbatim}", + # Now split by math environments + "\n\\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + elif language == Language.HTML: + return [ + # First, try to split along HTML tags + "