Created
April 22, 2025 04:14
-
-
Save wyfdev/7e5882555b24319d7ebfcd09bf3d5347 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Hierarchical text splitter using regex delimiters. | |
| Delimiters are specified as a list of (pattern, name) tuples, ordered from highest to lowest level. | |
| The splitter recursively divides text into chunks using these delimiters. | |
| Splitting logic: | |
| 1. Prefer splitting at the highest delimiter level, ensuring resulting chunks do not exceed chunk_size. | |
| 2. Merge adjacent chunks at the same level if their combined length is within chunk_size. | |
| 3. If a chunk remains too large, recursively split it using the next lower delimiter level. | |
| 4. When splitting, aim for balanced chunk sizes. | |
| Delimiters are regular expressions, e.g.: | |
| [ | |
| (r'^#.*$', 'title'), | |
| (r'^##.*$', 'chapter'), | |
| (r'^###.*$', 'section') | |
| ] | |
| Each tuple contains a regex pattern and a delimiter name (used for metadata; can be None). | |
| Text is split strictly at delimiters. | |
| A default ultimate delimiter r"(?<=\n)(?=[^\n])" (start of non-empty new line) is appended as a fallback. | |
| If text cannot be split by any delimiter, it remains as a single chunk, even if it exceeds chunk_size. | |
| If a chunk exceeds max_length, CanNotSplitUnderMaxLength is raised. | |
| """ | |
| from re import finditer as re_finditer, Match as ReMatch, MULTILINE as RE_MULTILINE | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Dict, NoReturn | |
class CanNotSplitUnderMaxLength(Exception):
    """Raised when a chunk cannot be split under the specified max_length."""

    def __init__(self, chunk_length: int, max_length: int):
        # Keep both values on the instance so callers can inspect them
        # programmatically instead of parsing the message.
        self.chunk_length = chunk_length
        self.max_length = max_length
        message = (
            f"Cannot split text into chunks smaller than max_length={max_length}. "
            f"Chunk minimized with length={chunk_length}."
        )
        super().__init__(message)
@dataclass
class DocMetadata:
    """
    Metadata for a document chunk: its character span in the source text
    plus its hierarchical location (delimiter name -> matched heading).
    """

    start_index: int = 0  # inclusive start offset in the source text
    end_index: int = 0  # exclusive end offset in the source text
    location: Dict[str, str] = field(default_factory=dict)  # delimiter name -> heading

    def __post_init__(self):
        """Reject invalid spans immediately after construction."""
        if self.start_index < 0:
            raise ValueError("start_index must be non-negative")
        if self.end_index < self.start_index:
            raise ValueError("end_index must be >= start_index")

    def get_length(self) -> int:
        """Return the length of the corresponding text."""
        return self.end_index - self.start_index

    def add_location_info(self, key: str, value: str) -> None:
        """Record one piece of hierarchical location information."""
        self.location[key] = value
class Document:
    """
    Represents a chunk of text and its associated metadata.
    """

    def __init__(self, text: str, metadata: DocMetadata):
        """
        Create a Document from text plus its metadata.

        Args:
            text (str): The document text.
            metadata (DocMetadata): Metadata for the document.

        Raises:
            ValueError: If the text length disagrees with the metadata span.
        """
        self.text = text
        self.metadata = metadata
        # Sanity check: text length must agree with the metadata span.
        # A zero end_index means the span was left unset, so skip the check.
        if metadata.end_index != 0 and len(text) != metadata.get_length():
            raise ValueError(
                f"Text length ({len(text)}) does not match metadata length "
                f"({metadata.get_length()})"
            )

    @property
    def length(self) -> int:
        """Return the length of the document text."""
        return len(self.text)

    def __eq__(self, other):
        """Documents are equal when both text and metadata match."""
        return (
            isinstance(other, Document)
            and self.text == other.text
            and self.metadata == other.metadata
        )

    def __repr__(self):
        """
        Debug representation: a short preview of the text plus metadata.
        """
        text = self.text
        if len(text) <= 25:
            preview = repr(text)
        else:
            # Long text: show the first and last 10 characters.
            preview = repr(text[:10] + "..." + text[-10:])
        return f"Document(text={preview}, metadata={self.metadata})"

    def __str__(self):
        """Content-focused representation (size plus location only)."""
        return f"Document({len(self.text)} chars, location={self.metadata.location})"
class Splitter:
    """
    Hierarchical text splitter that splits text into chunks using a list of
    regex delimiters.

    Delimiters are tried from level 0 (highest) downward; a chunk that still
    exceeds ``chunk_size`` is recursively split at the next lower level.
    Delimiters are kept as part of the chunk that follows them.
    """

    delimiter: list[tuple[str, str | None]]
    # Fallback pattern: matches the start of any non-empty new line.
    default_ultimate_delimiter = r"(?<=\n)(?=[^\n])"
    # Re-exported so callers can write `except Splitter.CanNotSplitUnderMaxLength`.
    CanNotSplitUnderMaxLength = CanNotSplitUnderMaxLength

    def __init__(
        self,
        delimiter: list[tuple[str, str | None]],
        chunk_size: int = 1000,
        ultimate_delimiter: str | None = default_ultimate_delimiter,
        max_length: Optional[int] = None,
    ):
        """
        Initialize the Splitter.

        Args:
            delimiter (list[tuple[str, str | None]]): List of (regex pattern,
                name) tuples for splitting, ordered highest level first.
            chunk_size (int): Preferred maximum chunk size.
            ultimate_delimiter (str | None): Fallback delimiter appended after
                the user-supplied ones; pass None to disable the fallback.
            max_length (Optional[int]): Hard maximum chunk length; None
                disables the limit.
        """
        # Copy so appending the fallback does not mutate the caller's list.
        self.delimiter = delimiter.copy()
        if ultimate_delimiter is not None:
            self.delimiter.append((ultimate_delimiter, None))
        self.chunk_size = chunk_size
        self.max_length = max_length

    def split_text(self, text: str) -> list[Document]:
        """
        Split the input text into chunks using hierarchical delimiters.

        Args:
            text (str): The text to split.

        Returns:
            list[Document]: List of Document objects representing the chunks.

        Raises:
            CanNotSplitUnderMaxLength: If a chunk cannot be split under
                max_length (when max_length is set).
        """
        # _split_recursive is a generator; materialize it so callers really
        # get the list the docstring promises (previously the raw generator
        # leaked out).
        return list(self._split_recursive(text, 0, 0))

    def _split_recursive(
        self,
        text: str,
        level: int,
        base_index: int,
        location: Optional[Dict[str, str]] = None,
    ):
        """
        Recursively split text using delimiters at the given level.

        This is a generator yielding Document objects in source order.

        Args:
            text (str): The text to split.
            level (int): Current delimiter level (index into self.delimiter).
            base_index (int): Absolute offset of `text` in the original input.
            location (Dict[str, str], optional): Inherited location metadata.

        Yields:
            Document: Document objects for the split chunks.

        Raises:
            CanNotSplitUnderMaxLength: If a chunk cannot be split under
                max_length.
        """
        if location is None:
            location = {}
        # Delimiters exhausted: emit the text as a single chunk.
        # NOTE: this must `yield`, not `return a value` — inside a generator a
        # plain `return value` only sets the StopIteration value and the chunk
        # would be silently dropped (this was a latent bug).
        if level >= len(self.delimiter):
            if self.max_length is not None and len(text) > self.max_length:
                # By default this raises; subclasses may return replacement
                # documents instead (see handle_exceed_max_length).
                yield from self.handle_exceed_max_length(
                    text, start_index=base_index, location=location.copy()
                )
            else:
                yield Document(
                    text,
                    DocMetadata(
                        start_index=base_index,
                        end_index=base_index + len(text),
                        location=location.copy(),
                    ),
                )
            return

        delimiter_pattern, delimiter_name = self.delimiter[level]
        # Find delimiter positions; the delimiter text itself stays at the
        # start of the chunk that follows it.
        matches = list(re_finditer(delimiter_pattern, text, RE_MULTILINE))
        split_points = [0] + [m.start() for m in matches] + [len(text)]
        # Lookup table: chunk start offset -> match that opened the chunk.
        delimiter_dict = {m.start(): m for m in matches}

        # Merge adjacent parts to fit chunk_size, avoiding lower levels
        # whenever possible.
        i = 0
        while i < len(split_points) - 1:
            chunk_start = split_points[i]
            chunk_end = split_points[i + 1]
            next_i = i + 1
            # Greedily absorb following parts while the merged span still
            # fits chunk_size. Pure index arithmetic — no need to slice the
            # text just to measure a candidate merge.
            while (
                next_i < len(split_points) - 1
                and split_points[next_i + 1] - chunk_start <= self.chunk_size
            ):
                chunk_end = split_points[next_i + 1]
                next_i += 1
            chunk = text[chunk_start:chunk_end]

            # Inherit the parent location, then record this level's heading
            # (only when the chunk actually starts at a named delimiter).
            chunk_location = location.copy()
            start_match: Optional[ReMatch[str]] = delimiter_dict.get(chunk_start)
            if start_match and delimiter_name:
                chunk_location[delimiter_name] = start_match.group()

            if not chunk:
                # Empty chunk (e.g. a delimiter at position 0); skip it.
                pass
            elif len(chunk) > self.chunk_size and level < len(self.delimiter) - 1:
                # Chunk exceeds the preferred size: split at the next lower
                # level, passing down the structural location information.
                yield from self._split_recursive(
                    chunk, level + 1, base_index + chunk_start, chunk_location
                )
            else:
                # Enforce the hard limit if one is configured. `is not None`
                # (rather than truthiness) so max_length=0 is honoured,
                # matching the check in the delimiter-exhausted branch above.
                if self.max_length is not None and len(chunk) > self.max_length:
                    yield from self.handle_exceed_max_length(
                        text=chunk,
                        start_index=base_index + chunk_start,
                        location=chunk_location.copy(),
                    )
                else:
                    yield Document(
                        chunk,
                        DocMetadata(
                            start_index=base_index + chunk_start,
                            end_index=base_index + chunk_end,
                            location=chunk_location,
                        ),
                    )
            i = next_i

    def handle_exceed_max_length(
        self, text: str, start_index: int, location: dict
    ) -> list[Document]:
        """
        Handle a chunk whose length exceeds max_length.

        By default this raises when max_length is set. Overriding this method
        allows custom handling of oversized text, for example to relax the
        restriction::

            if len(text) < self.max_length + 100:
                return [
                    Document(
                        text,
                        DocMetadata(
                            start_index=start_index,
                            end_index=start_index + len(text),
                            location=location,
                        ),
                    )
                ]
            raise CanNotSplitUnderMaxLength(
                chunk_length=len(text), max_length=self.max_length + 100
            )

        Args:
            text (str): The oversized chunk.
            start_index (int): Absolute start index of the chunk.
            location (dict): Hierarchical location metadata for the chunk.

        Returns:
            list[Document]: Replacement documents for the oversized chunk.

        Raises:
            CanNotSplitUnderMaxLength: By default, when max_length is set.
        """
        if self.max_length is None:
            # No limit configured: wrap as a single document. Callers should
            # not normally reach this path, but it keeps overrides simple.
            return [
                Document(
                    text,
                    DocMetadata(
                        start_index=start_index,
                        end_index=start_index + len(text),
                        location=location,
                    ),
                )
            ]
        # Text exceeds max_length and we can't split it any further.
        raise CanNotSplitUnderMaxLength(
            chunk_length=len(text), max_length=self.max_length
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment