Created
October 26, 2025 23:51
-
-
Save proger/40d40a4ce9ed43decd7e9e4decbe8a98 to your computer and use it in GitHub Desktop.
Revisions
-
proger created this gist
Oct 26, 2025. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,199 @@ """ Print alignment statistics produced by train_mono.py See also: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/diagnostic/analyze_alignments.sh """ import argparse from collections import Counter, defaultdict from dataclasses import dataclass from pathlib import Path from typing import Iterable, Iterator, Sequence SUBSCRIPT_DIGITS = "₀₁₂₃₄₅₆₇₈₉" @dataclass class AggregatedStats: boundary_counts: dict[str, dict[str, Counter[int]]] total_counts: dict[str, int] total_frames: dict[str, int] num_utterances: int def _sequence_to_runs(sequence: Iterable[str]) -> list[tuple[str, int]]: iterator: Iterator[str] = iter(sequence) try: current = next(iterator) except StopIteration: return [] runs: list[tuple[str, int]] = [] length = 1 for symbol in iterator: if symbol == current: length += 1 else: runs.append((current, length)) current = symbol length = 1 runs.append((current, length)) return runs def _collect_stats(sequences: Iterable[Sequence[str]]) -> AggregatedStats: boundary_counts: dict[str, dict[str, Counter[int]]] = { "begin": defaultdict(Counter), "end": defaultdict(Counter), "all": defaultdict(Counter), } total_counts = {"begin": 0, "end": 0, "all": 0} total_frames = {"begin": 0, "end": 0, "all": 0} num_utterances = 0 for sequence in sequences: runs = _sequence_to_runs(sequence) if not runs: continue num_utterances += 1 total_counts["all"] += len(runs) total_frames["all"] += sum(length for _, length in runs) for symbol, length in runs: boundary_counts["all"][symbol][length] += 1 first_symbol, first_len = runs[0] boundary_counts["begin"][first_symbol][first_len] += 1 total_counts["begin"] += 1 total_frames["begin"] += 
first_len last_symbol, last_len = runs[-1] boundary_counts["end"][last_symbol][last_len] += 1 total_counts["end"] += 1 total_frames["end"] += last_len return AggregatedStats(boundary_counts, total_counts, total_frames, num_utterances) def _percentile(lengths: Counter[int], fraction: float) -> int: if not lengths: return 0 cutoff = fraction * sum(lengths.values()) running = 0.0 for length, count in sorted(lengths.items()): running += count if running >= cutoff: return length return 0 def _mean(lengths: Counter[int]) -> float: total_occurrences = sum(lengths.values()) if total_occurrences == 0: return 0.0 total_frames = sum(length * count for length, count in lengths.items()) return total_frames / total_occurrences def _symbol_summary(lengths: Counter[int]) -> tuple[int, int, float, int]: occurrences = sum(lengths.values()) frames = sum(length * count for length, count in lengths.items()) mean = _mean(lengths) median = _percentile(lengths, 0.5) p95 = _percentile(lengths, 0.95) return occurrences, frames, mean, median, p95 def print_alignment_statistics( sequences: Iterable[Sequence[str]], variant_counts: dict[str, Counter[str]] | None = None, frequency_cutoff: float = 0.0, silence_symbol: str = "_", ) -> None: stats = _collect_stats(sequences) print( f"[alignment_stats] analyzed {stats.num_utterances} utterances " f"with {stats.total_frames['all']} aligned frames", flush=True, ) begin_lengths = stats.boundary_counts["begin"][silence_symbol] end_lengths = stats.boundary_counts["end"][silence_symbol] begin_occurrences = sum(begin_lengths.values()) end_occurrences = sum(end_lengths.values()) begin_frequency = 100.0 * begin_occurrences / max(stats.total_counts["begin"], 1) end_frequency = 100.0 * end_occurrences / max(stats.total_counts["end"], 1) begin_mean = _mean(begin_lengths) end_mean = _mean(end_lengths) begin_median = _percentile(begin_lengths, 0.5) end_median = _percentile(end_lengths, 0.5) print( f"[alignment_stats] At utterance begin, '{silence_symbol}' appears 
{begin_frequency:.1f}% " f"of the time; when seen, duration median={begin_median} mean={begin_mean:.1f} frames.", flush=True, ) print( f"[alignment_stats] At utterance end, '{silence_symbol}' appears {end_frequency:.1f}% " f"of the time; when seen, duration median={end_median} mean={end_mean:.1f} frames.", flush=True, ) overall_frames = stats.total_frames["all"] symbol_summaries: list[tuple[str, int, float, int, int]] = [] for symbol, lengths in stats.boundary_counts["all"].items(): occurrences, frames, mean, median, p95 = _symbol_summary(lengths) if overall_frames == 0: occupancy = 0.0 else: occupancy = 100.0 * frames / overall_frames if occupancy < frequency_cutoff: continue symbol_summaries.append((symbol, frames, mean, median, p95)) symbol_summaries.sort(key=lambda item: item[1], reverse=True) for symbol, frames, mean, median, p95 in symbol_summaries: occupancy = 100.0 * frames / max(overall_frames, 1) variant_text = "" if variant_counts and symbol in variant_counts: variants_sorted = sorted( variant_counts[symbol].items(), key=lambda item: item[1], reverse=True ) formatted = ", ".join(f"{variant}×{count}" for variant, count in variants_sorted) if formatted: variant_text = f" variants: {formatted}" print( f"[alignment_stats] {symbol!r} occupies {occupancy:.2f}% of frames; " f"duration median={median} mean={mean:.1f} p95={p95} frames.{variant_text}", flush=True, ) def strip_variant(token: str) -> str: return token.rstrip(SUBSCRIPT_DIGITS) def read_alignments(align_path: Path) -> tuple[list[list[str]], dict[str, Counter[str]]]: sequences: list[list[str]] = [] variant_counts: dict[str, Counter[str]] = defaultdict(Counter) with align_path.open("r", encoding="utf-8") as handle: for line in handle: parts = line.strip().split() if len(parts) <= 1: continue symbols: list[str] = [] for token in parts[1:]: base = strip_variant(token) symbols.append(base) variant_counts[base][token] += 1 sequences.append(symbols) return sequences, variant_counts def parse_args() -> 
argparse.Namespace: parser = argparse.ArgumentParser( description="Analyze alignment text files produced by train_mono.py" ) parser.add_argument("alignments", type=Path, help="Path to alignments.txt") parser.add_argument( "--frequency-cutoff", type=float, default=0.0, help="Minimum percentage of frame occupancy for reporting overall stats", ) return parser.parse_args() if __name__ == "__main__": args = parse_args() sequences, variant_counts = read_alignments(args.alignments) print_alignment_statistics(sequences, variant_counts, args.frequency_cutoff)