""" Train a GMM and use it to Viterbi-align the unicode character labels of the training set to the audio. Let wav/ contain a directory of 16-bit PCM wav files, text contain a list of filename-transcription pairs. Usage is: $ cat > text << EOF 25153.wav — Я це доскона́ло зна́ю, а ось мене́ ду́же диву́є, як ви мо́жете про це зна́ти! EOF $ python -m train_mono wav/ text exp/ See also: https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh """ import argparse import heapq import math import unicodedata import wave from dataclasses import dataclass from pathlib import Path from time import perf_counter import numpy as np from numba import njit @dataclass(frozen=True) class Utterance: utt_id: str path: Path tokens: list[str] features: np.ndarray def parse_args() -> argparse.Namespace: class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): pass parser = argparse.ArgumentParser(description=__doc__, formatter_class=Formatter) parser.add_argument("data_dir", type=Path, help="A directory with audio files") parser.add_argument("text_path", type=Path, help="A transcript file formatted like: 'file.wav a brown fox jumped over the lazy dog' (first column is a file name, rest are words)") parser.add_argument("out_dir", type=Path, help="An experiment directory where checkpoints and alignments will be stored") parser.add_argument("--num-iters", type=int, default=40, help="Number of passes over the entire dataset") parser.add_argument("--max-iter-inc", type=int, default=30, help="How many iterations to grow the model size for") parser.add_argument("--totgauss", type=int, default=1000, help="Total target number of Gaussian components split between all states") parser.add_argument("--power", type=float, default=0.25, help="Gaussian states grow proportionally to occupancy ** power") parser.add_argument("--mixup-perturb", type=float, default=0.01, help="LBG-style Gaussian mean perturbation factor when growing the model") parser.add_argument("--mixup-min-count", type=float, default=3.0, help="Minimum occupancy a state needs to have to split") parser.add_argument("--realign-iters", type=str, default="1 2 3 4 5 6 7 8 9 10 12 14 16 18 20 23 26 29 32 35 38", help="Iteration indices where realignment should happen (from egs/wsj)") parser.add_argument("--init", type=Path, help="Path to snapshot .npy file to initialize model parameters") parser.add_argument("--frame-length", type=float, default=0.025, help="Frame length in seconds") parser.add_argument("--frame-shift", type=float, default=0.010, help="Frame shift in seconds") parser.add_argument("--states-per-symbol", type=int, default=3, help="Number of HMM states per symbol") parser.add_argument("--num-ceps", type=int, default=13, help="Number of cepstral coefficients") parser.add_argument("--num-mels", type=int, default=24, help="Number of mel bands") parser.add_argument("--preemphasis", type=float, default=0.97, help="Pre-emphasis factor for waveform") parser.add_argument("--var-floor", type=float, default=1e-3, help="Variance floor used during model updates") return parser.parse_args() def read_wave(path: Path) -> tuple[np.ndarray, int]: with wave.open(str(path), "rb") as wf: num_channels = wf.getnchannels() sample_width = wf.getsampwidth() sample_rate = wf.getframerate() num_frames = wf.getnframes() assert sample_width == 2, "expect 16-bit PCM" raw = wf.readframes(num_frames) data = np.frombuffer(raw, dtype=np.int16).astype(np.float32) if num_channels == 2: data = data.reshape(-1, 2).mean(axis=1) data /= 32768.0 return data, sample_rate def frame_signal(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray: frame_len = int(round(args.frame_length * sample_rate)) frame_step = int(round(args.frame_shift * sample_rate)) num_frames = 1 + max(0, (len(signal) - frame_len) // frame_step) pad_len = (num_frames - 1) * frame_step + frame_len if pad_len > len(signal): signal = np.pad(signal, (0, pad_len - len(signal))) indices = np.arange(frame_len)[None, :] + frame_step * np.arange(num_frames)[:, None] frames = signal[indices] return frames def mel_scale(freq: float | np.ndarray) -> np.ndarray: freq_arr = np.asarray(freq, dtype=np.float64) return 2595.0 * np.log10(1.0 + freq_arr / 700.0) def inverse_mel(mel: float | np.ndarray) -> np.ndarray: mel_arr = np.asarray(mel, dtype=np.float64) return 700.0 * (10 ** (mel_arr / 2595.0) - 1) def create_mel_filterbank(sample_rate: int, nfft: int, args: argparse.Namespace) -> np.ndarray: low_mel = float(mel_scale(0.0)) high_mel = float(mel_scale(sample_rate / 2.0)) mel_points = np.linspace(low_mel, high_mel, args.num_mels + 2) hz_points = inverse_mel(mel_points) bin_points = np.floor((nfft + 1) * hz_points / sample_rate).astype(int) filters = np.zeros((args.num_mels, nfft // 2 + 1), dtype=np.float32) for m in range(1, args.num_mels + 1): left = int(bin_points[m - 1]) center = int(bin_points[m]) right = int(bin_points[m + 1]) left = max(left, 0) center = max(center, left + 1) center = min(center, right) right = min(right, nfft // 2) if right <= left: continue for k in range(left, center): filters[m - 1, k] = (k - left) / max(center - left, 1) for k in range(center, right): filters[m - 1, k] = (right - k) / max(right - center, 1) return filters def create_dct_matrix(num_mels: int, num_ceps: int) -> np.ndarray: n = np.arange(num_mels)[None, :] k = np.arange(num_ceps)[:, None] mat = np.sqrt(2.0 / num_mels) * np.cos((np.pi / num_mels) * (n + 0.5) * k) mat[0] *= np.sqrt(0.5) return mat def compute_mfcc(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray: emphasized = np.append(signal[0], signal[1:] - args.preemphasis * signal[:-1]) frames = frame_signal(emphasized, sample_rate, args) frames *= np.hamming(frames.shape[1])[None, :] nfft = 1 while nfft < frames.shape[1]: nfft *= 2 spectrum = np.fft.rfft(frames, n=nfft) power = (np.abs(spectrum) ** 2) / nfft fbanks = create_mel_filterbank(sample_rate, nfft, args) mel_energy = np.dot(power, fbanks.T) mel_energy = np.maximum(mel_energy, 1e-10) log_mel = np.log(mel_energy) dct = create_dct_matrix(log_mel.shape[1], args.num_ceps) mfcc = log_mel @ dct.T return mfcc.astype(np.float32) def compute_deltas(feats: np.ndarray, N: int = 2) -> np.ndarray: denom = 2 * sum(d * d for d in range(1, N + 1)) padded = np.pad(feats, ((N, N), (0, 0)), mode="edge") delta = np.zeros_like(feats) for t in range(feats.shape[0]): acc = np.zeros(feats.shape[1], dtype=np.float32) for d in range(1, N + 1): acc += d * (padded[t + N + d] - padded[t + N - d]) delta[t] = acc / denom return delta def logsumexp(values: np.ndarray) -> np.ndarray: max_vals = np.max(values, axis=1, keepdims=True) stabilized = np.exp(values - max_vals) return (max_vals[:, 0] + np.log(np.sum(stabilized, axis=1))).astype(np.float64) def apply_cmvn(feats: np.ndarray) -> np.ndarray: mean = feats.mean(axis=0, keepdims=True) return feats - mean def subscript_number(number: int | float | str) -> str: translation = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉") return f"{number}".translate(translation) def extract_features(path: Path, args: argparse.Namespace) -> np.ndarray: signal, sample_rate = read_wave(path) mfcc = compute_mfcc(signal, sample_rate, args) mfcc = apply_cmvn(mfcc) delta = compute_deltas(mfcc) deltadelta = compute_deltas(delta) return np.concatenate([mfcc, delta, deltadelta], axis=1) def normalize_text(text: str) -> str: cleaned = [] for ch in unicodedata.normalize("NFC", text.casefold()): cat = uni_category(ch) if cat.startswith("P") or ch in {"«", "»", "—", "–", "…", "’", "“", "”"}: continue cleaned.append(" " if ch.isspace() else ch) return "".join(cleaned) def uni_category(ch: str) -> str: return unicodedata.category(ch) def split_graphemes(text: str) -> list[str]: clusters = [] current = "" for ch in text: if not current: current = ch elif unicodedata.combining(ch): current += ch else: clusters.append(current) current = ch if current: clusters.append(current) return clusters def load_transcriptions(text_path: Path) -> dict[str, list[str]]: labels = {} with text_path.open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue key, content = line.split(maxsplit=1) normalized = normalize_text(content.strip()) words = [w for w in normalized.split() if w] tokens: list[str] = [] tokens.append("_") for word in words: graphemes = split_graphemes(word) tokens.extend(graphemes) tokens.append("_") if not words: tokens.append("_") labels[key] = tokens return labels def collect_utterances(args: argparse.Namespace, labels: dict[str, list[str]]) -> list[Utterance]: utterances = [] for wav_path in sorted(args.data_dir.glob("*.wav")): utt_id = wav_path.name if utt_id not in labels: continue feats = extract_features(wav_path, args) if feats.shape[0] == 0: continue utterances.append(Utterance(utt_id=utt_id, path=wav_path, tokens=labels[utt_id], features=feats)) return utterances def load_snapshot(path: Path) -> dict: data = np.load(path, allow_pickle=True) if isinstance(data, dict): return data if data.shape == (): return data.item() return data[0] class DiagGMM: def __init__(self, dim: int, args: argparse.Namespace, num_components: int = 1): self.dim = dim self.args = args self.weights = np.full(num_components, 1.0 / num_components, dtype=np.float32) self.means = np.zeros((num_components, dim), dtype=np.float32) self.vars = np.ones((num_components, dim), dtype=np.float32) self._refresh() def _refresh(self) -> None: self.log_weights = np.log(self.weights + 1e-12).astype(np.float32) log_two_pi = np.log(2 * np.pi) self.log_norm = (-0.5 * (np.log(self.vars).sum(axis=1) + self.dim * log_two_pi)).astype(np.float32) self.inv_vars = 1.0 / self.vars self.mu_inv_vars = self.means * self.inv_vars self.mu2_inv_vars_sum = np.sum(self.means * self.mu_inv_vars, axis=1) def log_likelihood(self, feats: np.ndarray) -> np.ndarray: term1 = feats ** 2 @ self.inv_vars.T term2 = feats @ self.mu_inv_vars.T exponent = 0.5 * (term1 - 2.0 * term2 + self.mu2_inv_vars_sum[None, :]) return self.log_weights + self.log_norm - exponent def posterior_and_loglike(self, feats: np.ndarray) -> tuple[np.ndarray, np.ndarray]: if feats.shape[0] == 0: return np.zeros((0, self.weights.shape[0]), dtype=np.float32), np.zeros(0, dtype=np.float32) log_like = self.log_likelihood(feats) max_log = np.max(log_like, axis=1, keepdims=True) probs = np.exp(log_like - max_log) probs_sum = probs.sum(axis=1, keepdims=True) post = (probs / probs_sum).astype(np.float32) frame_log_like = (max_log[:, 0] + np.log(probs_sum[:, 0])).astype(np.float32) return post, frame_log_like def accumulate(self, feats: np.ndarray, accumulator: "GMMAccumulator") -> np.ndarray: post, frame_log_like = self.posterior_and_loglike(feats) accumulator.add(feats, post) return frame_log_like def update(self, accumulator: "GMMAccumulator", min_count: float) -> float: counts = accumulator.counts total = counts.sum() if total < min_count: return 0.0 means = accumulator.linear / counts[:, None] second = accumulator.quadratic / counts[:, None] vars_ = second - means ** 2 vars_ = np.maximum(vars_, self.args.var_floor) weights = counts / total self.weights = weights self.means = means self.vars = vars_ self._refresh() return total def split_to(self, target_components: int, perturb_factor: float) -> None: while self.weights.shape[0] < target_components: idx = int(np.argmax(self.weights)) mean = self.means[idx] var = self.vars[idx] weight = self.weights[idx] * 0.5 perturb = perturb_factor * np.sqrt(var) new_mean_pos = mean + perturb new_mean_neg = mean - perturb self.weights[idx] = weight self.means[idx] = new_mean_pos self.vars[idx] = var self.weights = np.concatenate([self.weights, [weight]]) self.means = np.vstack([self.means, new_mean_neg]) self.vars = np.vstack([self.vars, var]) self._refresh() class GMMAccumulator: def __init__(self, dim: int, num_components: int): self.counts = np.zeros(num_components, dtype=np.float64) self.linear = np.zeros((num_components, dim), dtype=np.float64) self.quadratic = np.zeros((num_components, dim), dtype=np.float64) def add(self, feats: np.ndarray, post: np.ndarray) -> None: gamma = post.sum(axis=0) self.counts += gamma self.linear += post.T @ feats self.quadratic += post.T @ (feats ** 2) class MonophoneModel: def __init__(self, symbols: list[str], dim: int, args: argparse.Namespace): self.symbols = symbols self.dim = dim self.args = args self.state_to_symbol = [] self.state_variants = [] self.symbol_to_states = {} self.states = [] for sym in symbols: indices = [] for _ in range(args.states_per_symbol): state_id = len(self.states) self.states.append(DiagGMM(dim, args)) self.state_to_symbol.append(sym) self.state_variants.append(len(indices)) indices.append(state_id) self.symbol_to_states[sym] = indices def expand(self, tokens: list[str]) -> np.ndarray: seq = [] for token in tokens: seq.extend(self.symbol_to_states[token]) return np.array(seq, dtype=np.int32) def num_gauss(self) -> int: return sum(state.weights.shape[0] for state in self.states) def equal_alignment(num_frames: int, state_sequence: np.ndarray) -> np.ndarray: repeat = max(1, num_frames // len(state_sequence)) aligned = np.repeat(state_sequence, repeat) if aligned.shape[0] >= num_frames: return aligned[:num_frames] tail = np.full(num_frames - aligned.shape[0], state_sequence[-1], dtype=np.int32) return np.concatenate([aligned, tail]) @njit(cache=True) def _viterbi_align_impl(emissions: np.ndarray, log_self: float, log_next: float) -> tuple[np.ndarray, float]: T, S = emissions.shape viterbi = np.empty((T, S), dtype=np.float64) backptr = np.zeros((T, S), dtype=np.int32) for s in range(S): viterbi[0, s] = -np.inf viterbi[0, 0] = emissions[0, 0] for t in range(1, T): for s in range(S): stay = viterbi[t - 1, s] + log_self best = stay prev_state = s if s > 0: advance = viterbi[t - 1, s - 1] + log_next if advance > best: best = advance prev_state = s - 1 viterbi[t, s] = best + emissions[t, s] backptr[t, s] = prev_state alignment = np.zeros(T, dtype=np.int32) state = S - 1 alignment[T - 1] = state for t in range(T - 1, 0, -1): state = backptr[t, state] alignment[t - 1] = state return alignment, viterbi[T - 1, S - 1] def viterbi_align(model: MonophoneModel, feats: np.ndarray, state_sequence: np.ndarray, return_score: bool = False) -> np.ndarray | tuple[np.ndarray, float]: S = len(state_sequence) T = feats.shape[0] emissions = np.zeros((T, S), dtype=np.float64) for idx, state_id in enumerate(state_sequence): log_like = model.states[state_id].log_likelihood(feats) emissions[:, idx] = logsumexp(log_like) log_self = math.log(0.6) log_next = math.log(0.4) alignment, score = _viterbi_align_impl(emissions, log_self, log_next) if not np.isfinite(score): fallback = equal_alignment(T, state_sequence) if return_score: return fallback, float("-inf") return fallback mapped = np.zeros_like(alignment) for t in range(T): mapped[t] = state_sequence[alignment[t]] if return_score: return mapped, float(score) return mapped def accumulate_stats(model: MonophoneModel, utterances: list[Utterance], alignments: dict[str, np.ndarray]) -> tuple[list[GMMAccumulator], float, int]: accumulators = [] for state in model.states: accumulators.append(GMMAccumulator(model.dim, state.weights.shape[0])) total_like = 0.0 total_frames = 0 for utt in utterances: alignment = alignments[utt.utt_id] feats = utt.features for state_id in np.unique(alignment): selector = alignment == state_id frames = feats[selector] if frames.shape[0] == 0: continue state = model.states[state_id] accumulator = accumulators[state_id] frame_log_like = state.accumulate(frames, accumulator) total_like += float(frame_log_like.sum()) total_frames += frames.shape[0] return accumulators, total_like, total_frames def update_model(model: MonophoneModel, accumulators: list[GMMAccumulator], min_count: float) -> np.ndarray: occs = np.zeros(len(model.states), dtype=np.float64) for idx, (state, acc) in enumerate(zip(model.states, accumulators)): occ = state.update(acc, min_count) if occ == 0.0: occ = acc.counts.sum() occs[idx] = occ return occs @dataclass(order=True) class SplitItem: priority: float index: int num_components: int occ_pow: float occ: float def compute_split_targets(state_occs: np.ndarray, target_total: int, power: float, min_count: float) -> list[int]: num_states = state_occs.shape[0] queue: list[SplitItem] = [] for idx in range(num_states): occ = state_occs[idx] occ_pow = occ ** power if occ > 0 else 0.0 priority = -occ_pow heapq.heappush(queue, SplitItem(priority, idx, 1, occ_pow, occ)) total = num_states while queue and total < target_total: item = heapq.heappop(queue) if item.occ_pow == 0: break if (item.num_components + 1) * min_count >= item.occ: item = SplitItem(0.0, item.index, item.num_components, 0.0, item.occ) else: item = SplitItem(-item.occ_pow / (item.num_components + 1), item.index, item.num_components + 1, item.occ_pow, item.occ) total += 1 heapq.heappush(queue, item) targets = [1] * num_states while queue: item = heapq.heappop(queue) targets[item.index] = item.num_components return targets def mixup(model: MonophoneModel, state_occs: np.ndarray, current_target: int, power: float, min_count: float, perturb: float) -> None: current_total = model.num_gauss() if current_target <= current_total: return targets = compute_split_targets(state_occs, current_target, power, min_count) for state, target in zip(model.states, targets): if target > state.weights.shape[0]: state.split_to(target, perturb) def save_model(model: MonophoneModel, args: argparse.Namespace) -> None: args.out_dir.mkdir(parents=True, exist_ok=True) np.savez(args.out_dir / "final_model.npz", symbols=np.array(model.symbols, dtype=object), state_to_symbol=np.array(model.state_to_symbol, dtype=object), weights=np.array([state.weights for state in model.states], dtype=object), means=np.array([state.means for state in model.states], dtype=object), variances=np.array([state.vars for state in model.states], dtype=object), frame_length=args.frame_length, frame_shift=args.frame_shift, num_ceps=args.num_ceps, states_per_symbol=args.states_per_symbol) def save_alignments(model: MonophoneModel, alignments: dict[str, np.ndarray], align_path: Path) -> None: mapping = model.state_to_symbol align_path.parent.mkdir(parents=True, exist_ok=True) with align_path.open("w", encoding="utf-8") as f: for utt_id, state_ids in sorted(alignments.items()): symbols = [ f"{mapping[state_id]}{subscript_number(model.state_variants[state_id] + 1)}" for state_id in state_ids ] f.write(" ".join([utt_id, *symbols])) f.write("\n") def initialize_model(utterances: list[Utterance], args: argparse.Namespace) -> tuple[MonophoneModel, dict[str, np.ndarray], float, int]: symbols = sorted({token for utt in utterances for token in utt.tokens}) dim = utterances[0].features.shape[1] model = MonophoneModel(symbols, dim, args) if args.init is not None: snapshot = load_snapshot(args.init) weights_list = snapshot["weights"] means_list = snapshot["means"] vars_list = snapshot["variances"] for state, weights, means, variances in zip(model.states, weights_list, means_list, vars_list): state.weights = np.asarray(weights, dtype=np.float32) state.means = np.asarray(means, dtype=np.float32) state.vars = np.asarray(variances, dtype=np.float32) state._refresh() alignments = {} for utt in utterances: sequence = model.expand(utt.tokens) alignment, _ = viterbi_align(model, utt.features, sequence, return_score=True) alignments[utt.utt_id] = alignment _, total_like, total_frames = accumulate_stats(model, utterances, alignments) return model, alignments, total_like, total_frames initial_alignments: dict[str, np.ndarray] = {} for utt in utterances: state_sequence = model.expand(utt.tokens) alignment = equal_alignment(utt.features.shape[0], state_sequence) initial_alignments[utt.utt_id] = alignment accumulators, total_like, total_frames = accumulate_stats(model, utterances, initial_alignments) update_model(model, accumulators, min_count=1.0) return model, initial_alignments, total_like, total_frames def train(args: argparse.Namespace) -> None: labels = load_transcriptions(args.text_path) utterances = collect_utterances(args, labels) init_start = perf_counter() model, alignments, init_like, init_frames = initialize_model(utterances, args) init_time = perf_counter() - init_start realign_iters = {int(x) for x in args.realign_iters.split()} num_gauss = model.num_gauss() inc_gauss = max((args.totgauss - num_gauss) // args.max_iter_inc, 1) current_target = num_gauss init_source = "snapshot" if args.init else "flat" if init_frames > 0: avg_like = init_like / init_frames print(f"[train_mono] init: avg loglike/frame={avg_like:.4f} " f"total_like={init_like:.2f} frames={init_frames} " f"num_gauss={model.num_gauss()} source={init_source} " f"| times: init={init_time:.2f}s", flush=True) else: print(f"[train_mono] init: frames=0 source={init_source} " f"| times: init={init_time:.2f}s", flush=True) for iteration in range(1, args.num_iters): iter_start = perf_counter() realign_time = 0.0 if iteration in realign_iters: print(f"[train_mono] iter {iteration:02d}: realigning", flush=True) realign_begin = perf_counter() new_alignments = {} for utt in utterances: sequence = model.expand(utt.tokens) new_alignments[utt.utt_id] = viterbi_align(model, utt.features, sequence) alignments = new_alignments realign_time = perf_counter() - realign_begin print(f"[train_mono] iter {iteration:02d}: realign done in {realign_time:.2f}s", flush=True) accum_start = perf_counter() accumulators, total_like, total_frames = accumulate_stats(model, utterances, alignments) accum_time = perf_counter() - accum_start update_start = perf_counter() state_occs = update_model(model, accumulators, args.mixup_min_count) update_time = perf_counter() - update_start num_gauss_before_mix = model.num_gauss() if total_frames > 0: avg_like = total_like / total_frames else: avg_like = float("nan") mix_time = 0.0 if iteration <= args.max_iter_inc: current_target = min(args.totgauss, current_target + inc_gauss) mix_start = perf_counter() mixup(model, state_occs, current_target, args.power, args.mixup_min_count, args.mixup_perturb) mix_time = perf_counter() - mix_start num_gauss_after_mix = model.num_gauss() if num_gauss_after_mix != num_gauss_before_mix: print(f"[train_mono] iter {iteration:02d}: mixup -> num_gauss {num_gauss_before_mix} -> {num_gauss_after_mix}", flush=True) iter_time = perf_counter() - iter_start print(f"[train_mono] iter {iteration:02d}: avg loglike/frame={avg_like:.4f} " f"total_like={total_like:.2f} frames={total_frames} " f"num_gauss={model.num_gauss()} " f"| times: realign={realign_time:.2f}s accum={accum_time:.2f}s " f"update={update_time:.2f}s mix={mix_time:.2f}s iter={iter_time:.2f}s", flush=True) snapshot = { "weights": [state.weights for state in model.states], "means": [state.means for state in model.states], "variances": [state.vars for state in model.states], } np.save(args.out_dir / f"model_iter_{iteration}.npy", np.array([snapshot], dtype=object), allow_pickle=True) save_alignments(model, alignments, args.out_dir / f"alignments_iter_{iteration:02d}.txt") save_alignments(model, alignments, args.out_dir / "alignments.txt") save_model(model, args) def main() -> None: args = parse_args() args.out_dir.mkdir(parents=True, exist_ok=True) train(args) if __name__ == "__main__": main()