hierarchical_text_splitter.py (gist by @wyfdev, created April 22, 2025)

    """
    Hierarchical text splitter using regex delimiters.
    Delimiters are specified as a list of (pattern, name) tuples, ordered from highest to lowest level.
    The splitter recursively divides text into chunks using these delimiters.
    Splitting logic:
    1. Prefer splitting at the highest delimiter level, ensuring resulting chunks do not exceed chunk_size.
    2. Merge adjacent chunks at the same level if their combined length is within chunk_size.
    3. If a chunk remains too large, recursively split it using the next lower delimiter level.
    4. When splitting, aim for balanced chunk sizes.
    Delimiters are regular expressions, e.g.:
    [
    (r'^#.*$', 'title'),
    (r'^##.*$', 'chapter'),
    (r'^###.*$', 'section')
    ]
    Each tuple contains a regex pattern and a delimiter name (used for metadata; can be None).
    Text is split strictly at delimiters.
    A default ultimate delimiter r"(?<=\n)(?=[^\n])" (start of non-empty new line) is appended as a fallback.
    If text cannot be split by any delimiter, it remains as a single chunk, even if it exceeds chunk_size.
    If a chunk exceeds max_length, CanNotSplitUnderMaxLength is raised.
    """

from re import finditer as re_finditer, Match as ReMatch, MULTILINE as RE_MULTILINE
from dataclasses import dataclass, field
from typing import Dict, Iterator, NoReturn, Optional


class CanNotSplitUnderMaxLength(Exception):
    """Raised when a chunk cannot be split under the specified max_length."""

    def __init__(self, chunk_length: int, max_length: int):
        self.chunk_length = chunk_length
        self.max_length = max_length
        super().__init__(
            f"Cannot split text into chunks smaller than max_length={max_length}. "
            f"Chunk minimized with length={chunk_length}."
        )


@dataclass
class DocMetadata:
    """
    Metadata for a document chunk, including start and end indices and the
    hierarchical location.
    """

    start_index: int = 0
    end_index: int = 0
    location: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        """Validate metadata after initialization."""
        if self.start_index < 0:
            raise ValueError("start_index must be non-negative")
        if self.end_index < self.start_index:
            raise ValueError("end_index must be >= start_index")

    def get_length(self) -> int:
        """Return the length of the corresponding text."""
        return self.end_index - self.start_index

    def add_location_info(self, key: str, value: str) -> None:
        """Add an entry to the location dictionary."""
        self.location[key] = value
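
# For example, DocMetadata(start_index=10, end_index=25) describes text[10:25]
# (get_length() == 15), while DocMetadata(start_index=5, end_index=2) raises
# ValueError in __post_init__.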


class Document:
    """
    Represents a chunk of text and its associated metadata.
    """

    def __init__(self, text: str, metadata: DocMetadata):
        """
        Initialize a Document with text and associated metadata.

        Args:
            text (str): The document text.
            metadata (DocMetadata): Metadata for the document.
        """
        self.text = text
        self.metadata = metadata

        # Validate that the text length matches the metadata span
        if len(text) != metadata.get_length() and metadata.end_index != 0:
            raise ValueError(
                f"Text length ({len(text)}) does not match metadata length "
                f"({metadata.get_length()})"
            )

    @property
    def length(self) -> int:
        """Return the length of the document text."""
        return len(self.text)

    def __eq__(self, other):
        if not isinstance(other, Document):
            return False
        return self.text == other.text and self.metadata == other.metadata

    def __repr__(self):
        """
        Return a string representation of the Document, showing a preview of
        the text and the metadata.
        """
        if len(self.text) <= 25:
            preview = repr(self.text)
        else:
            preview = repr(self.text[:10] + "..." + self.text[-10:])
        return f"Document(text={preview}, metadata={self.metadata})"

    def __str__(self):
        """Return a string representation focusing on content."""
        return f"Document({len(self.text)} chars, location={self.metadata.location})"


class Splitter:
    """
    Hierarchical text splitter that splits text into chunks using a list of
    regex delimiters.
    """

    delimiter: list[tuple[str, str | None]]
    default_ultimate_delimiter = r"(?<=\n)(?=[^\n])"
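    # The default ultimate delimiter is a zero-width match at the start of
    # every non-empty line: in "a\nb\n\nc" it matches before "b" and before
    # "c", so newlines stay attached to the preceding chunk when splitting.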

    # Re-exported for convenient access as Splitter.CanNotSplitUnderMaxLength
    CanNotSplitUnderMaxLength = CanNotSplitUnderMaxLength

    def __init__(
        self,
        delimiter: list[tuple[str, str | None]],
        chunk_size: int = 1000,
        ultimate_delimiter: str | None = default_ultimate_delimiter,
        max_length: Optional[int] = None,
    ):
        """
        Initialize the Splitter.

        Args:
            delimiter (list[tuple[str, str | None]]): List of (regex pattern,
                name) tuples for splitting.
            chunk_size (int): Preferred maximum chunk size.
            ultimate_delimiter (str | None): Fallback delimiter if no others
                match.
            max_length (Optional[int]): Hard maximum chunk length.
        """
        # Copy to avoid modifying the caller's list
        self.delimiter = delimiter.copy()
        if ultimate_delimiter is not None:
            self.delimiter.append((ultimate_delimiter, None))
        self.chunk_size = chunk_size
        self.max_length = max_length

    def split_text(self, text: str) -> list[Document]:
        """
        Split the input text into chunks using hierarchical delimiters.

        Args:
            text (str): The text to split.

        Returns:
            list[Document]: Document objects representing the chunks.
        """
        # Start recursive splitting from the highest delimiter level (0).
        # _split_recursive is a generator, so materialize it into a list.
        return list(self._split_recursive(text, 0, 0))

    def _split_recursive(
        self,
        text: str,
        level: int,
        base_index: int,
        location: Optional[Dict[str, str]] = None,
    ) -> Iterator[Document]:
        """
        Recursively split text using delimiters at the given level.

        Args:
            text (str): The text to split.
            level (int): Current delimiter level.
            base_index (int): Base index for chunk start positions.
            location (Dict[str, str], optional): Metadata location tracking.

        Yields:
            Document: Document objects for the split chunks.

        Raises:
            CanNotSplitUnderMaxLength: If a chunk cannot be split under
                max_length.
        """
        if location is None:
            location = {}

        # If we've exhausted all delimiters, emit the text as a single chunk
        if level >= len(self.delimiter):
            if self.max_length is not None and len(text) > self.max_length:
                # By default this raises CanNotSplitUnderMaxLength
                yield from self.handle_exceed_max_length(
                    text, start_index=base_index, location=location.copy()
                )
            else:
                yield Document(
                    text,
                    DocMetadata(
                        start_index=base_index,
                        end_index=base_index + len(text),
                        location=location.copy(),
                    ),
                )
            return

        delimiter_pattern, delimiter_name = self.delimiter[level]

        # Use regex to find delimiter positions; delimiters stay in the text
        matches = list(re_finditer(delimiter_pattern, text, RE_MULTILINE))

        split_points = [0] + [m.start() for m in matches] + [len(text)]

        # Build a lookup dictionary for faster match retrieval
        delimiter_dict = {m.start(): m for m in matches}
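
        # For example, with the pattern r"^##.*$" and the text
        # "intro\n## A\nbody\n## B\n", the matches start at 6 and 16, so
        # split_points == [0, 6, 16, 21] and the candidate chunks are
        # "intro\n", "## A\nbody\n", and "## B\n".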

        # Try to merge parts to fit chunk_size and avoid descending to lower
        # levels if possible
        i = 0
        while i < len(split_points) - 1:
            chunk_start = split_points[i]
            chunk_end = split_points[i + 1]
            chunk = text[chunk_start:chunk_end]
            next_i = i + 1

            # Merge with following chunks while the combined length stays
            # within chunk_size
            while (
                next_i < len(split_points) - 1
                and len(text[chunk_start:split_points[next_i + 1]]) <= self.chunk_size
            ):
                chunk_end = split_points[next_i + 1]
                chunk = text[chunk_start:chunk_end]
                next_i += 1

            # Copy parent location data and add current-level information
            chunk_location = location.copy()
            start_match: Optional[ReMatch[str]] = delimiter_dict.get(chunk_start)
            if start_match and delimiter_name:
                chunk_location[delimiter_name] = start_match.group()

            # Process the chunk
            if not chunk:
                # Empty chunk, skip to the next one
                pass
            elif len(chunk) > self.chunk_size and level < len(self.delimiter) - 1:
                # Chunk exceeds the preferred size: split at the next lower
                # level, passing down the structural location information
                yield from self._split_recursive(
                    chunk, level + 1, base_index + chunk_start, chunk_location
                )
            else:
                # Check the max_length constraint (hard limit)
                if self.max_length and len(chunk) > self.max_length:
                    # Try to handle the oversized chunk
                    yield from self.handle_exceed_max_length(
                        text=chunk,
                        start_index=base_index + chunk_start,
                        location=chunk_location.copy(),
                    )
                else:
                    yield Document(
                        chunk,
                        DocMetadata(
                            start_index=base_index + chunk_start,
                            end_index=base_index + chunk_end,
                            location=chunk_location,
                        ),
                    )
            i = next_i

    def handle_exceed_max_length(
        self, text: str, start_index: int, location: dict
    ) -> list[Document] | NoReturn:
        """
        Handle text that exceeds max_length.

        Overriding this method allows custom handling of oversized text,
        for example to relax the restriction:

            if len(text) < self.max_length + 100:
                return [
                    Document(
                        text,
                        DocMetadata(
                            start_index=start_index,
                            end_index=start_index + len(text),
                            location=location,
                        ),
                    )
                ]
            else:
                raise CanNotSplitUnderMaxLength(
                    chunk_length=len(text), max_length=self.max_length + 100
                )
        """
        # By default, raise an exception when text exceeds max_length.
        # Subclasses can override this method to implement custom handling.
        if self.max_length is None:
            # With max_length unset, no limit is enforced; this branch should
            # not be reached as long as the caller checks max_length first.
            return [
                Document(
                    text,
                    DocMetadata(
                        start_index=start_index,
                        end_index=start_index + len(text),
                        location=location,
                    ),
                )
            ]
        else:
            # Text exceeds max_length and we can't split further
            raise CanNotSplitUnderMaxLength(
                chunk_length=len(text), max_length=self.max_length
            )
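

if __name__ == "__main__":
    # Minimal usage sketch: split a small markdown-like string by chapter and
    # section headings. The delimiter names "chapter"/"section" and the
    # sample text are illustrative choices, not fixed by the splitter.
    sample = (
        "## Chapter 1\n"
        "### Section 1.1\n"
        "Some body text for the first section goes here.\n"
        "### Section 1.2\n"
        "More body text here.\n"
        "## Chapter 2\n"
        "Closing remarks.\n"
    )
    splitter = Splitter(
        # (?!#) keeps the chapter pattern from also matching section headings
        delimiter=[(r"^##(?!#).*$", "chapter"), (r"^###.*$", "section")],
        chunk_size=60,
    )
    # Section 1.1 plus its body exceeds chunk_size, so that chunk falls
    # through to the line-start fallback delimiter and is split further.
    for doc in splitter.split_text(sample):
        print(doc)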