Created
April 22, 2025 04:14
-
-
Save wyfdev/7e5882555b24319d7ebfcd09bf3d5347 to your computer and use it in GitHub Desktop.
Revisions
-
wyfdev created this gist
Apr 22, 2025.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,310 @@ """ Hierarchical text splitter using regex delimiters. Delimiters are specified as a list of (pattern, name) tuples, ordered from highest to lowest level. The splitter recursively divides text into chunks using these delimiters. Splitting logic: 1. Prefer splitting at the highest delimiter level, ensuring resulting chunks do not exceed chunk_size. 2. Merge adjacent chunks at the same level if their combined length is within chunk_size. 3. If a chunk remains too large, recursively split it using the next lower delimiter level. 4. When splitting, aim for balanced chunk sizes. Delimiters are regular expressions, e.g.: [ (r'^#.*$', 'title'), (r'^##.*$', 'chapter'), (r'^###.*$', 'section') ] Each tuple contains a regex pattern and a delimiter name (used for metadata; can be None). Text is split strictly at delimiters. A default ultimate delimiter r"(?<=\n)(?=[^\n])" (start of non-empty new line) is appended as a fallback. If text cannot be split by any delimiter, it remains as a single chunk, even if it exceeds chunk_size. If a chunk exceeds max_length, CanNotSplitUnderMaxLength is raised. """ from re import finditer as re_finditer, Match as ReMatch, MULTILINE as RE_MULTILINE from dataclasses import dataclass, field from typing import Optional, Dict, NoReturn class CanNotSplitUnderMaxLength(Exception): """Raised when a chunk cannot be split under the specified max_length.""" def __init__(self, chunk_length: int, max_length: int): self.chunk_length = chunk_length self.max_length = max_length super().__init__( f"Cannot split text into chunks smaller than max_length={max_length}. " f"Chunk minimized with length={chunk_length}." 
) @dataclass class DocMetadata: """ Metadata for a document chunk, including start and end indices and hierarchical location. """ start_index: int = 0 end_index: int = 0 location: Dict[str, str] = field(default_factory=dict) def __post_init__(self): """Validate metadata after initialization.""" if self.start_index < 0: raise ValueError("start_index must be non-negative") if self.end_index < self.start_index: raise ValueError("end_index must be >= start_index") def get_length(self) -> int: """Return the length of the corresponding text.""" return self.end_index - self.start_index def add_location_info(self, key: str, value: str) -> None: """Add information to the location dictionary.""" self.location[key] = value class Document: """ Represents a chunk of text and its associated metadata. """ def __init__(self, text: str, metadata: DocMetadata): """ Initialize a Document with text and associated metadata. Args: text (str): The document text. metadata (DocMetadata): Metadata for the document. """ self.text = text self.metadata = metadata # Validate that text length matches metadata if len(text) != metadata.get_length() and metadata.end_index != 0: raise ValueError( f"Text length ({len(text)}) does not match metadata length " f"({metadata.get_length()})" ) @property def length(self) -> int: """Return the length of the document text.""" return len(self.text) def __eq__(self, other): if not isinstance(other, Document): return False return self.text == other.text and self.metadata == other.metadata def __repr__(self): """ Return a string representation of the Document, showing a preview of the text and metadata. """ if len(self.text) <= 25: preview = repr(self.text) else: preview = repr(self.text[:10] + "..." 
+ self.text[-10:]) return f"Document(text={preview}, metadata={self.metadata})" def __str__(self): """Return a string representation focusing on content.""" return f"Document({len(self.text)} chars, location={self.metadata.location})" class Splitter: """ Hierarchical text splitter that splits text into chunks using a list of regex delimiters. """ delimiter: list[tuple[str, str | None]] default_ultimate_delimiter = r"(?<=\n)(?=[^\n])" CanNotSplitUnderMaxLength = CanNotSplitUnderMaxLength def __init__( self, delimiter: list[tuple[str, str | None]], chunk_size: int = 1000, ultimate_delimiter: str | None = default_ultimate_delimiter, max_length: Optional[int] = None, ): """ Initialize the Splitter. Args: delimiter (list[tuple[str, str]]): List of (regex pattern, name) tuples for splitting. chunk_size (int): Preferred maximum chunk size. ultimate_delimiter (str): Fallback delimiter if no others are found. max_length (Optional[int]): Hard maximum chunk length. """ self.delimiter = ( delimiter.copy() ) # Make a copy to avoid modifying the input list if ultimate_delimiter is not None: self.delimiter.append((ultimate_delimiter, None)) self.chunk_size = chunk_size self.max_length = max_length def split_text(self, text: str): """ Split the input text into chunks using hierarchical delimiters. Args: text (str): The text to split. Returns: list[Document]: List of Document objects representing the chunks. """ # Start recursive splitting from the highest delimiter level (0) return self._split_recursive(text, 0, 0) def _split_recursive(self, text: str, level: int, base_index: int, location: Optional[Dict[str, str]] = None) -> list[Document]: """ Recursively split text using delimiters at the given level. Args: text (str): The text to split. level (int): Current delimiter level. base_index (int): Base index for chunk start positions. location (Dict[str, str], optional): Metadata location tracking. Returns: list[Document]: List of Document objects for the split chunks. 
Raises: CanNotSplitUnderMaxLength: If a chunk cannot be split under max_length. """ if location is None: location = {} # If we've exhausted all delimiters, return the text as a single chunk if level >= len(self.delimiter): # Check max_length constraint if self.max_length is not None and len(text) > self.max_length: # This might raise an exception by default if max_length is set return self.handle_exceed_max_length(text, start_index=base_index, location=location.copy()) else: return [ Document( text, DocMetadata( start_index=base_index, end_index=base_index + len(text), location=location.copy(), ), ) ] delimiter_pattern, delimiter_name = self.delimiter[level] # Use regex to find delimiter positions but do not remove them from text matches = list(re_finditer(delimiter_pattern, text, RE_MULTILINE)) split_points = [0] + [m.start() for m in matches] + [len(text)] # Build a lookup dictionary for faster match retrieval delimiter_dict = {m.start(): m for m in matches} # Try to merge parts to fit chunk_size and avoid going to lower levels if possible i = 0 while i < len(split_points) - 1: chunk_start = split_points[i] chunk_end = split_points[i + 1] chunk = text[chunk_start:chunk_end] next_i = i + 1 # Try to merge with following chunks if their combined length stays under chunk_size while ( next_i < len(split_points) - 1 and len(text[chunk_start:split_points[next_i + 1]]) <= self.chunk_size ): chunk_end = split_points[next_i + 1] chunk = text[chunk_start:chunk_end] next_i += 1 # Copy parent location data and add current level information chunk_location = location.copy() start_match: Optional[ReMatch[str]] = delimiter_dict.get(chunk_start) if start_match and delimiter_name: chunk_location[delimiter_name] = start_match.group() # Process the chunk if not chunk: # Empty chunk, skip to next pass elif len(chunk) > self.chunk_size and level < len(self.delimiter) - 1: # Chunk exceeds preferred size, split at lower level # Pass down the structural location information yield from 
self._split_recursive( chunk, level + 1, base_index + chunk_start, chunk_location ) else: # Check max_length constraint (hard limit) if self.max_length and len(chunk) > self.max_length: # Try to handle the oversized chunk yield from self.handle_exceed_max_length( text=chunk, start_index=base_index + chunk_start, location=chunk_location.copy(), ) else: yield Document( chunk, DocMetadata( start_index=base_index + chunk_start, end_index=base_index + chunk_end, location=chunk_location, ), ) i = next_i def handle_exceed_max_length( self, text: str, start_index: int, location: dict ) -> list[Document] | NoReturn: """ Check if text exceeds max_length and handle accordingly. Overwriting this method allows for custom handling of oversized text. for example to relax restrictions: if len(text) < self.max_length + 100: return [ Document( text, DocMetadata( start_index=start_index, end_index=start_index + len(text), location=location, ), ) ] else: raise CanNotSplitUnderMaxLength( chunk_length=len(text), max_length=self.max_length + 100 ) """ # By default, we raise an exception when text exceeds max_length # Subclasses can override this method to implement custom handling if self.max_length is None: # If max_length is None, no limit is enforced, # shell not ne here as long as caller handles it return [ Document( text, DocMetadata( start_index=start_index, end_index=start_index + len(text), location=location, ), ) ] else: # Text exceeds max_length and we can't split further raise CanNotSplitUnderMaxLength( chunk_length=len(text), max_length=self.max_length )