Skip to content

Instantly share code, notes, and snippets.

@wyfdev
Created April 22, 2025 04:14
Show Gist options
  • Select an option

  • Save wyfdev/7e5882555b24319d7ebfcd09bf3d5347 to your computer and use it in GitHub Desktop.

Select an option

Save wyfdev/7e5882555b24319d7ebfcd09bf3d5347 to your computer and use it in GitHub Desktop.
"""
Hierarchical text splitter using regex delimiters.
Delimiters are specified as a list of (pattern, name) tuples, ordered from highest to lowest level.
The splitter recursively divides text into chunks using these delimiters.
Splitting logic:
1. Prefer splitting at the highest delimiter level, ensuring resulting chunks do not exceed chunk_size.
2. Merge adjacent chunks at the same level if their combined length is within chunk_size.
3. If a chunk remains too large, recursively split it using the next lower delimiter level.
4. When splitting, aim for balanced chunk sizes.
Delimiters are regular expressions, e.g.:
[
(r'^#.*$', 'title'),
(r'^##.*$', 'chapter'),
(r'^###.*$', 'section')
]
Each tuple contains a regex pattern and a delimiter name (used for metadata; can be None).
Text is split strictly at delimiters.
A default ultimate delimiter r"(?<=\n)(?=[^\n])" (start of non-empty new line) is appended as a fallback.
If text cannot be split by any delimiter, it remains as a single chunk, even if it exceeds chunk_size.
If a chunk exceeds max_length, CanNotSplitUnderMaxLength is raised.
"""
from re import finditer as re_finditer, Match as ReMatch, MULTILINE as RE_MULTILINE
from dataclasses import dataclass, field
from typing import Dict, Iterator, NoReturn, Optional
class CanNotSplitUnderMaxLength(Exception):
    """Raised when a chunk cannot be split under the specified max_length."""

    def __init__(self, chunk_length: int, max_length: int):
        # Keep the offending sizes available for programmatic handling.
        self.chunk_length = chunk_length
        self.max_length = max_length
        message = (
            f"Cannot split text into chunks smaller than max_length={max_length}. "
            f"Chunk minimized with length={chunk_length}."
        )
        super().__init__(message)
@dataclass
class DocMetadata:
    """
    Metadata attached to a document chunk: its span in the original text
    plus a dictionary of hierarchical location labels.
    """

    start_index: int = 0
    end_index: int = 0
    location: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        """Reject index pairs that cannot describe a valid span."""
        if self.start_index < 0:
            raise ValueError("start_index must be non-negative")
        if self.end_index < self.start_index:
            raise ValueError("end_index must be >= start_index")

    def get_length(self) -> int:
        """Return the length of the corresponding text."""
        return self.end_index - self.start_index

    def add_location_info(self, key: str, value: str) -> None:
        """Record a hierarchical location label under the given key."""
        self.location[key] = value
class Document:
    """
    A chunk of text paired with its DocMetadata.
    """

    def __init__(self, text: str, metadata: DocMetadata):
        """
        Create a Document.

        Args:
            text (str): The chunk's text.
            metadata (DocMetadata): Positional and hierarchical metadata.

        Raises:
            ValueError: If the metadata span length disagrees with the text
                length (the check is skipped when end_index is 0).
        """
        self.text = text
        self.metadata = metadata
        # Sanity check: a non-zero span must cover exactly len(text) chars.
        mismatch = len(text) != metadata.get_length()
        if mismatch and metadata.end_index != 0:
            raise ValueError(
                f"Text length ({len(text)}) does not match metadata length "
                f"({metadata.get_length()})"
            )

    @property
    def length(self) -> int:
        """Length of the document text in characters."""
        return len(self.text)

    def __eq__(self, other):
        """Documents compare equal when both text and metadata match."""
        return (
            isinstance(other, Document)
            and self.text == other.text
            and self.metadata == other.metadata
        )

    def __repr__(self):
        """Debug representation with a shortened preview of the text."""
        text = self.text
        preview = repr(text if len(text) <= 25 else text[:10] + "..." + text[-10:])
        return f"Document(text={preview}, metadata={self.metadata})"

    def __str__(self):
        """Human-readable summary focusing on content size and location."""
        return f"Document({len(self.text)} chars, location={self.metadata.location})"
class Splitter:
    """
    Hierarchical text splitter that splits text into chunks using a list of
    regex delimiters.

    Delimiters are (pattern, name) tuples ordered from highest to lowest
    level. Text is split strictly at delimiter matches (the delimiter stays
    attached to the chunk that starts at it), adjacent parts at the same
    level are merged while the result fits within ``chunk_size``, and parts
    that are still too large are recursively split at the next lower level.
    """

    delimiter: list[tuple[str, str | None]]
    # Fallback split point: the boundary right after a newline and before a
    # non-empty line, so arbitrary plain text can still be split somewhere.
    default_ultimate_delimiter = r"(?<=\n)(?=[^\n])"
    # Re-exported so callers can write `except Splitter.CanNotSplitUnderMaxLength`.
    CanNotSplitUnderMaxLength = CanNotSplitUnderMaxLength

    def __init__(
        self,
        delimiter: list[tuple[str, str | None]],
        chunk_size: int = 1000,
        ultimate_delimiter: str | None = default_ultimate_delimiter,
        max_length: Optional[int] = None,
    ):
        """
        Initialize the Splitter.

        Args:
            delimiter (list[tuple[str, str | None]]): List of (regex pattern,
                name) tuples for splitting, ordered from highest to lowest
                level; names may be None.
            chunk_size (int): Preferred (soft) maximum chunk size.
            ultimate_delimiter (str | None): Fallback delimiter appended as
                the lowest level; pass None to disable it.
            max_length (Optional[int]): Hard maximum chunk length; None
                disables the hard limit.
        """
        # Copy so appending the fallback does not mutate the caller's list.
        self.delimiter = delimiter.copy()
        if ultimate_delimiter is not None:
            self.delimiter.append((ultimate_delimiter, None))
        self.chunk_size = chunk_size
        self.max_length = max_length

    def split_text(self, text: str) -> list[Document]:
        """
        Split the input text into chunks using hierarchical delimiters.

        Args:
            text (str): The text to split.

        Returns:
            list[Document]: Document objects representing the chunks.

        Raises:
            CanNotSplitUnderMaxLength: If max_length is set and some chunk
                cannot be split below it.
        """
        # BUG FIX: _split_recursive is a generator, so materialize it here.
        # The original returned the raw generator while documenting a list,
        # and max_length errors would only surface during iteration.
        return list(self._split_recursive(text, 0, 0))

    def _split_recursive(
        self,
        text: str,
        level: int,
        base_index: int,
        location: Optional[Dict[str, str]] = None,
    ) -> Iterator[Document]:
        """
        Recursively split text using delimiters at the given level.

        Args:
            text (str): The text to split.
            level (int): Current delimiter level (index into self.delimiter).
            base_index (int): Offset of `text` within the original input,
                used to compute absolute chunk positions.
            location (Dict[str, str], optional): Hierarchical location
                metadata inherited from parent levels.

        Yields:
            Document: One document per resulting chunk.

        Raises:
            CanNotSplitUnderMaxLength: If a chunk cannot be split under
                max_length.
        """
        if location is None:
            location = {}

        # All delimiter levels exhausted: emit the text as a single chunk.
        if level >= len(self.delimiter):
            if self.max_length is not None and len(text) > self.max_length:
                # BUG FIX: the original used `return <list>` inside this
                # generator, which silently discarded the documents (a return
                # value in a generator only populates StopIteration.value).
                yield from self.handle_exceed_max_length(
                    text, start_index=base_index, location=location.copy()
                )
            else:
                # BUG FIX: same `return [...]`-in-generator defect as above.
                yield Document(
                    text,
                    DocMetadata(
                        start_index=base_index,
                        end_index=base_index + len(text),
                        location=location.copy(),
                    ),
                )
            return

        delimiter_pattern, delimiter_name = self.delimiter[level]
        # Locate split points; delimiters are kept in the text (splitting is
        # positional, nothing is removed).
        matches = list(re_finditer(delimiter_pattern, text, RE_MULTILINE))
        split_points = [0] + [m.start() for m in matches] + [len(text)]
        # Fast lookup from a chunk's start offset to its opening match.
        delimiter_dict = {m.start(): m for m in matches}

        i = 0
        while i < len(split_points) - 1:
            chunk_start = split_points[i]
            chunk_end = split_points[i + 1]
            chunk = text[chunk_start:chunk_end]
            next_i = i + 1
            # Greedily absorb following parts while the merged chunk still
            # fits within chunk_size, to avoid descending to lower levels
            # unnecessarily. (Index arithmetic instead of slicing: the
            # original built a throwaway slice just to measure its length.)
            while (
                next_i < len(split_points) - 1
                and split_points[next_i + 1] - chunk_start <= self.chunk_size
            ):
                chunk_end = split_points[next_i + 1]
                chunk = text[chunk_start:chunk_end]
                next_i += 1

            # Inherit the parent location and record this level's delimiter
            # text (e.g. the heading line) under the delimiter's name.
            chunk_location = location.copy()
            start_match: Optional[ReMatch[str]] = delimiter_dict.get(chunk_start)
            if start_match and delimiter_name:
                chunk_location[delimiter_name] = start_match.group()

            if not chunk:
                # Empty chunk (e.g. a delimiter at position 0): skip it.
                pass
            elif len(chunk) > self.chunk_size and level < len(self.delimiter) - 1:
                # Chunk exceeds the preferred size and lower levels remain:
                # recurse, passing down the structural location information.
                yield from self._split_recursive(
                    chunk, level + 1, base_index + chunk_start, chunk_location
                )
            elif self.max_length is not None and len(chunk) > self.max_length:
                # CONSISTENCY FIX: use `is not None` (as in the base case
                # above) instead of truthiness, so max_length=0 is honored.
                yield from self.handle_exceed_max_length(
                    text=chunk,
                    start_index=base_index + chunk_start,
                    location=chunk_location.copy(),
                )
            else:
                yield Document(
                    chunk,
                    DocMetadata(
                        start_index=base_index + chunk_start,
                        end_index=base_index + chunk_end,
                        location=chunk_location,
                    ),
                )
            i = next_i

    def handle_exceed_max_length(
        self, text: str, start_index: int, location: dict
    ) -> list[Document] | NoReturn:
        """
        Handle a chunk that could not be split below max_length.

        Override this method to customize handling of oversized text, for
        example to relax the restriction::

            if len(text) < self.max_length + 100:
                return [
                    Document(
                        text,
                        DocMetadata(
                            start_index=start_index,
                            end_index=start_index + len(text),
                            location=location,
                        ),
                    )
                ]
            else:
                raise CanNotSplitUnderMaxLength(
                    chunk_length=len(text), max_length=self.max_length + 100
                )

        Args:
            text (str): The oversized chunk.
            start_index (int): Absolute start offset of the chunk.
            location (dict): Hierarchical location metadata for the chunk.

        Returns:
            list[Document]: Documents to emit in place of the chunk.

        Raises:
            CanNotSplitUnderMaxLength: By default, when max_length is set.
        """
        if self.max_length is None:
            # No limit configured; defensively emit the chunk unchanged
            # (callers normally check max_length before calling this, so
            # this branch should not be reached in practice).
            return [
                Document(
                    text,
                    DocMetadata(
                        start_index=start_index,
                        end_index=start_index + len(text),
                        location=location,
                    ),
                )
            ]
        # Text exceeds the hard limit and cannot be split further.
        raise CanNotSplitUnderMaxLength(
            chunk_length=len(text), max_length=self.max_length
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment