
@proger
Created October 26, 2025 23:51
alignment_stats.py
    """
    Print alignment statistics produced by train_mono.py
    See also: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/diagnostic/analyze_alignments.sh
    """
    import argparse
    from collections import Counter, defaultdict
    from dataclasses import dataclass
    from pathlib import Path
    from typing import Iterable, Iterator, Sequence

    SUBSCRIPT_DIGITS = "₀₁₂₃₄₅₆₇₈₉"


@dataclass
class AggregatedStats:
    boundary_counts: dict[str, dict[str, Counter[int]]]
    total_counts: dict[str, int]
    total_frames: dict[str, int]
    num_utterances: int

def _sequence_to_runs(sequence: Iterable[str]) -> list[tuple[str, int]]:
    """Run-length encode a sequence into (symbol, run length) pairs."""
    iterator: Iterator[str] = iter(sequence)
    try:
        current = next(iterator)
    except StopIteration:
        return []
    runs: list[tuple[str, int]] = []
    length = 1
    for symbol in iterator:
        if symbol == current:
            length += 1
        else:
            runs.append((current, length))
            current = symbol
            length = 1
    runs.append((current, length))
    return runs
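

# Illustrative example of the encoding above:
#     _sequence_to_runs(["a", "a", "b", "_"]) == [("a", 2), ("b", 1), ("_", 1)]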


def _collect_stats(sequences: Iterable[Sequence[str]]) -> AggregatedStats:
    """Aggregate run-length statistics overall and at utterance boundaries."""
    boundary_counts: dict[str, dict[str, Counter[int]]] = {
        "begin": defaultdict(Counter),
        "end": defaultdict(Counter),
        "all": defaultdict(Counter),
    }
    total_counts = {"begin": 0, "end": 0, "all": 0}
    total_frames = {"begin": 0, "end": 0, "all": 0}
    num_utterances = 0
    for sequence in sequences:
        runs = _sequence_to_runs(sequence)
        if not runs:
            continue
        num_utterances += 1
        total_counts["all"] += len(runs)
        total_frames["all"] += sum(length for _, length in runs)
        for symbol, length in runs:
            boundary_counts["all"][symbol][length] += 1
        # Only the first and last runs contribute to the boundary buckets.
        first_symbol, first_len = runs[0]
        boundary_counts["begin"][first_symbol][first_len] += 1
        total_counts["begin"] += 1
        total_frames["begin"] += first_len
        last_symbol, last_len = runs[-1]
        boundary_counts["end"][last_symbol][last_len] += 1
        total_counts["end"] += 1
        total_frames["end"] += last_len
    return AggregatedStats(boundary_counts, total_counts, total_frames, num_utterances)


def _percentile(lengths: Counter[int], fraction: float) -> int:
    """Smallest length whose cumulative count reaches the given fraction."""
    if not lengths:
        return 0
    cutoff = fraction * sum(lengths.values())
    running = 0.0
    for length, count in sorted(lengths.items()):
        running += count
        if running >= cutoff:
            return length
    return 0
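

# Illustrative example: for Counter({1: 3, 10: 1}) and fraction=0.5, the cutoff
# is 2.0 and the cumulative count reaches it within the runs of length 1, so
# _percentile returns 1.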


def _mean(lengths: Counter[int]) -> float:
    total_occurrences = sum(lengths.values())
    if total_occurrences == 0:
        return 0.0
    total_frames = sum(length * count for length, count in lengths.items())
    return total_frames / total_occurrences


def _symbol_summary(lengths: Counter[int]) -> tuple[int, int, float, int, int]:
    occurrences = sum(lengths.values())
    frames = sum(length * count for length, count in lengths.items())
    mean = _mean(lengths)
    median = _percentile(lengths, 0.5)
    p95 = _percentile(lengths, 0.95)
    return occurrences, frames, mean, median, p95
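

# Illustrative example: Counter({1: 3, 10: 1}) summarizes to
#     occurrences=4, frames=13, mean=3.25, median=1, p95=10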


def print_alignment_statistics(
    sequences: Iterable[Sequence[str]],
    variant_counts: dict[str, Counter[str]] | None = None,
    frequency_cutoff: float = 0.0,
    silence_symbol: str = "_",
) -> None:
    stats = _collect_stats(sequences)
    print(
        f"[alignment_stats] analyzed {stats.num_utterances} utterances "
        f"with {stats.total_frames['all']} aligned frames",
        flush=True,
    )
    begin_lengths = stats.boundary_counts["begin"][silence_symbol]
    end_lengths = stats.boundary_counts["end"][silence_symbol]
    begin_occurrences = sum(begin_lengths.values())
    end_occurrences = sum(end_lengths.values())
    begin_frequency = 100.0 * begin_occurrences / max(stats.total_counts["begin"], 1)
    end_frequency = 100.0 * end_occurrences / max(stats.total_counts["end"], 1)
    begin_mean = _mean(begin_lengths)
    end_mean = _mean(end_lengths)
    begin_median = _percentile(begin_lengths, 0.5)
    end_median = _percentile(end_lengths, 0.5)
    print(
        f"[alignment_stats] At utterance begin, '{silence_symbol}' appears {begin_frequency:.1f}% "
        f"of the time; when seen, duration median={begin_median} mean={begin_mean:.1f} frames.",
        flush=True,
    )
    print(
        f"[alignment_stats] At utterance end, '{silence_symbol}' appears {end_frequency:.1f}% "
        f"of the time; when seen, duration median={end_median} mean={end_mean:.1f} frames.",
        flush=True,
    )
    overall_frames = stats.total_frames["all"]
    symbol_summaries: list[tuple[str, int, float, int, int]] = []
    for symbol, lengths in stats.boundary_counts["all"].items():
        occurrences, frames, mean, median, p95 = _symbol_summary(lengths)
        if overall_frames == 0:
            occupancy = 0.0
        else:
            occupancy = 100.0 * frames / overall_frames
        if occupancy < frequency_cutoff:
            continue
        symbol_summaries.append((symbol, frames, mean, median, p95))
    symbol_summaries.sort(key=lambda item: item[1], reverse=True)
    for symbol, frames, mean, median, p95 in symbol_summaries:
        occupancy = 100.0 * frames / max(overall_frames, 1)
        variant_text = ""
        if variant_counts and symbol in variant_counts:
            variants_sorted = sorted(
                variant_counts[symbol].items(), key=lambda item: item[1], reverse=True
            )
            formatted = ", ".join(f"{variant}×{count}" for variant, count in variants_sorted)
            if formatted:
                variant_text = f" variants: {formatted}"
        print(
            f"[alignment_stats] {symbol!r} occupies {occupancy:.2f}% of frames; "
            f"duration median={median} mean={mean:.1f} p95={p95} frames.{variant_text}",
            flush=True,
        )
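

# Illustrative output line from the loop above (numbers invented for the example):
#     [alignment_stats] '_' occupies 23.41% of frames; duration median=12 mean=15.3 p95=48 frames. variants: _₁×812, _₂×143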


def strip_variant(token: str) -> str:
    return token.rstrip(SUBSCRIPT_DIGITS)
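

# Illustrative example: trailing subscript digits mark variants of a symbol, so
#     strip_variant("a₁") == "a"  and  strip_variant("b") == "b"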


def read_alignments(align_path: Path) -> tuple[list[list[str]], dict[str, Counter[str]]]:
    sequences: list[list[str]] = []
    variant_counts: dict[str, Counter[str]] = defaultdict(Counter)
    with align_path.open("r", encoding="utf-8") as handle:
        for line in handle:
            parts = line.strip().split()
            if len(parts) <= 1:
                continue
            symbols: list[str] = []
            for token in parts[1:]:
                base = strip_variant(token)
                symbols.append(base)
                variant_counts[base][token] += 1
            sequences.append(symbols)
    return sequences, variant_counts
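

# Assumed input format (one utterance per line): an identifier followed by one
# token per frame, where subscripts distinguish variants of the same symbol, e.g.
#     utt0001 _₁ _₁ a₁ a₂ b₁ _₂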


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Analyze alignment text files produced by train_mono.py"
    )
    parser.add_argument("alignments", type=Path, help="Path to alignments.txt")
    parser.add_argument(
        "--frequency-cutoff",
        type=float,
        default=0.0,
        help="Minimum percentage of frame occupancy for a symbol to be reported",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    sequences, variant_counts = read_alignments(args.alignments)
    print_alignment_statistics(sequences, variant_counts, args.frequency_cutoff)
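

# Example invocation (the path is hypothetical):
#     python alignment_stats.py exp/mono/alignments.txt --frequency-cutoff 0.5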