Last active
October 26, 2025 23:50
-
-
Save proger/22667a00e6c1991396a481c3835edc0f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Train a GMM and use it to Viterbi-align the unicode character labels of the training set to the audio. | |
| Let wav/ contain a directory of 16-bit PCM wav files, text contain a list of filename-transcription pairs. Usage is: | |
| $ cat > text << EOF | |
| 25153.wav — Я це доскона́ло зна́ю, а ось мене́ ду́же диву́є, як ви мо́жете про це зна́ти! | |
| EOF | |
| $ python -m train_mono wav/ text exp/ | |
| See also: | |
| https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh | |
| """ | |
| import argparse | |
| import heapq | |
| import math | |
| import unicodedata | |
| import wave | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from time import perf_counter | |
| import numpy as np | |
| from numba import njit | |
@dataclass(frozen=True)
class Utterance:
    """One training example: an audio file, its tokenized transcript and features."""
    utt_id: str            # key joining audio to transcript (the wav file name)
    path: Path             # location of the source wav file
    tokens: list[str]      # grapheme tokens with "_" word separators (built by load_transcriptions)
    features: np.ndarray   # (num_frames, feat_dim) MFCC + delta + delta-delta matrix
def parse_args() -> argparse.Namespace:
    """Parse the command-line interface of the monophone trainer.

    Returns an argparse.Namespace holding the dataset paths, the model-growth
    schedule, and the feature-extraction settings.
    """
    # Show argument defaults while keeping the module docstring's formatting
    # intact in --help output.
    class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): pass
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=Formatter)
    parser.add_argument("data_dir", type=Path, help="A directory with audio files")
    parser.add_argument("text_path", type=Path, help="A transcript file formatted like: 'file.wav a brown fox jumped over the lazy dog' (first column is a file name, rest are words)")
    parser.add_argument("out_dir", type=Path, help="An experiment directory where checkpoints and alignments will be stored")
    # Training / model-growth schedule.
    parser.add_argument("--num-iters", type=int, default=40, help="Number of passes over the entire dataset")
    parser.add_argument("--max-iter-inc", type=int, default=30, help="How many iterations to grow the model size for")
    parser.add_argument("--totgauss", type=int, default=1000, help="Total target number of Gaussian components split between all states")
    parser.add_argument("--power", type=float, default=0.25, help="Gaussian states grow proportionally to occupancy ** power")
    parser.add_argument("--mixup-perturb", type=float, default=0.01, help="LBG-style Gaussian mean perturbation factor when growing the model")
    parser.add_argument("--mixup-min-count", type=float, default=3.0, help="Minimum occupancy a state needs to have to split")
    parser.add_argument("--realign-iters", type=str,
                        default="1 2 3 4 5 6 7 8 9 10 12 14 16 18 20 23 26 29 32 35 38",
                        help="Iteration indices where realignment should happen (from egs/wsj)")
    parser.add_argument("--init", type=Path, help="Path to snapshot .npy file to initialize model parameters")
    # Feature-extraction settings.
    parser.add_argument("--frame-length", type=float, default=0.025, help="Frame length in seconds")
    parser.add_argument("--frame-shift", type=float, default=0.010, help="Frame shift in seconds")
    parser.add_argument("--states-per-symbol", type=int, default=3, help="Number of HMM states per symbol")
    parser.add_argument("--num-ceps", type=int, default=13, help="Number of cepstral coefficients")
    parser.add_argument("--num-mels", type=int, default=24, help="Number of mel bands")
    parser.add_argument("--preemphasis", type=float, default=0.97, help="Pre-emphasis factor for waveform")
    parser.add_argument("--var-floor", type=float, default=1e-3, help="Variance floor used during model updates")
    return parser.parse_args()
def read_wave(path: Path) -> tuple[np.ndarray, int]:
    """Load a 16-bit PCM wav file as float32 samples scaled to [-1, 1).

    Stereo input is downmixed to mono by averaging the two channels.
    Returns (samples, sample_rate).
    """
    with wave.open(str(path), "rb") as reader:
        channels = reader.getnchannels()
        assert reader.getsampwidth() == 2, "expect 16-bit PCM"
        rate = reader.getframerate()
        payload = reader.readframes(reader.getnframes())
    samples = np.frombuffer(payload, dtype=np.int16).astype(np.float32)
    if channels == 2:
        samples = samples.reshape(-1, 2).mean(axis=1)
    samples /= 32768.0
    return samples, rate
def frame_signal(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray:
    """Slice a 1-D signal into overlapping frames.

    Frame size is args.frame_length seconds and the hop is args.frame_shift
    seconds; returns an array of shape (num_frames, frame_len).
    """
    window = int(round(args.frame_length * sample_rate))
    hop = int(round(args.frame_shift * sample_rate))
    count = 1 + max(0, (len(signal) - window) // hop)
    needed = (count - 1) * hop + window
    if needed > len(signal):
        # Zero-pad so the final frame is complete.
        signal = np.pad(signal, (0, needed - len(signal)))
    offsets = hop * np.arange(count)[:, None] + np.arange(window)[None, :]
    return signal[offsets]
| def mel_scale(freq: float | np.ndarray) -> np.ndarray: | |
| freq_arr = np.asarray(freq, dtype=np.float64) | |
| return 2595.0 * np.log10(1.0 + freq_arr / 700.0) | |
| def inverse_mel(mel: float | np.ndarray) -> np.ndarray: | |
| mel_arr = np.asarray(mel, dtype=np.float64) | |
| return 700.0 * (10 ** (mel_arr / 2595.0) - 1) | |
def create_mel_filterbank(sample_rate: int, nfft: int, args: argparse.Namespace) -> np.ndarray:
    """Build a (num_mels, nfft // 2 + 1) matrix of triangular mel filters.

    Filter edges are spaced uniformly on the mel scale between 0 Hz and the
    Nyquist frequency; edge bins are clamped so each triangle stays inside
    the spectrum and keeps a non-degenerate rising edge.
    """
    # The mel of 0 Hz is exactly 0, so the lower edge starts at mel 0.
    top_mel = float(2595.0 * np.log10(1.0 + (sample_rate / 2.0) / 700.0))
    mel_points = np.linspace(0.0, top_mel, args.num_mels + 2)
    hz_points = 700.0 * (10 ** (mel_points / 2595.0) - 1)
    edge_bins = np.floor((nfft + 1) * hz_points / sample_rate).astype(int)
    bank = np.zeros((args.num_mels, nfft // 2 + 1), dtype=np.float32)
    for band in range(args.num_mels):
        lo = max(int(edge_bins[band]), 0)
        hi = int(edge_bins[band + 2])
        mid = min(max(int(edge_bins[band + 1]), lo + 1), hi)
        hi = min(hi, nfft // 2)
        if hi <= lo:
            continue  # degenerate filter: leave this band all-zero
        for k in range(lo, mid):
            bank[band, k] = (k - lo) / max(mid - lo, 1)
        for k in range(mid, hi):
            bank[band, k] = (hi - k) / max(hi - mid, 1)
    return bank
def create_dct_matrix(num_mels: int, num_ceps: int) -> np.ndarray:
    """Return an orthonormal DCT-II matrix of shape (num_ceps, num_mels)."""
    cols = np.arange(num_mels)[None, :]
    rows = np.arange(num_ceps)[:, None]
    dct = np.sqrt(2.0 / num_mels) * np.cos((np.pi / num_mels) * (cols + 0.5) * rows)
    dct[0] *= np.sqrt(0.5)  # re-scale the DC row for orthonormality
    return dct
def compute_mfcc(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray:
    """Compute MFCC features for a mono signal.

    Pipeline: pre-emphasis -> Hamming-windowed framing -> power spectrum ->
    mel filterbank -> log -> DCT down to args.num_ceps coefficients.
    Returns a float32 matrix of shape (num_frames, num_ceps).
    """
    # Pre-emphasis boosts high frequencies; the first sample is kept as-is.
    boosted = np.append(signal[0], signal[1:] - args.preemphasis * signal[:-1])
    frames = frame_signal(boosted, sample_rate, args)
    frames *= np.hamming(frames.shape[1])[None, :]
    # FFT size: smallest power of two covering the frame length.
    nfft = 1
    while nfft < frames.shape[1]:
        nfft *= 2
    power = (np.abs(np.fft.rfft(frames, n=nfft)) ** 2) / nfft
    mel_energy = np.dot(power, create_mel_filterbank(sample_rate, nfft, args).T)
    log_mel = np.log(np.maximum(mel_energy, 1e-10))  # floor avoids log(0)
    dct = create_dct_matrix(log_mel.shape[1], args.num_ceps)
    return (log_mel @ dct.T).astype(np.float32)
def compute_deltas(feats: np.ndarray, N: int = 2) -> np.ndarray:
    """Compute delta (regression) features over a +/-N frame window.

    Standard formula: delta[t] = sum_d d * (x[t+d] - x[t-d]) / (2 * sum_d d^2),
    with edge frames replicated. The per-frame Python loop of the original is
    replaced by N whole-array operations; the accumulation order per element
    is unchanged, so results are bit-identical for float input.

    Args:
        feats: (T, D) float feature matrix.
        N: regression window half-width.
    Returns:
        (T, D) array with the same (float) dtype as feats.
    """
    denom = 2 * sum(d * d for d in range(1, N + 1))
    T = feats.shape[0]
    padded = np.pad(feats, ((N, N), (0, 0)), mode="edge")
    acc = np.zeros_like(feats)
    for d in range(1, N + 1):
        # padded[N + t + d] - padded[N + t - d] for every t at once.
        acc += d * (padded[N + d:N + d + T] - padded[N - d:N - d + T])
    return acc / denom
def logsumexp(values: np.ndarray) -> np.ndarray:
    """Numerically stable log(sum(exp(values), axis=1)) for a 2-D array."""
    peak = np.max(values, axis=1, keepdims=True)
    total = np.sum(np.exp(values - peak), axis=1)
    return (peak[:, 0] + np.log(total)).astype(np.float64)
def apply_cmvn(feats: np.ndarray) -> np.ndarray:
    """Per-utterance cepstral mean normalization.

    Note: despite the CMVN name, only the mean is subtracted; the variance
    is left untouched.
    """
    return feats - feats.mean(axis=0, keepdims=True)
| def subscript_number(number: int | float | str) -> str: | |
| translation = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉") | |
| return f"{number}".translate(translation) | |
def extract_features(path: Path, args: argparse.Namespace) -> np.ndarray:
    """Load a wav file and return mean-normalized MFCCs with first- and
    second-order deltas, shape (num_frames, 3 * num_ceps)."""
    samples, rate = read_wave(path)
    base = apply_cmvn(compute_mfcc(samples, rate, args))
    velocity = compute_deltas(base)
    acceleration = compute_deltas(velocity)
    return np.concatenate([base, velocity, acceleration], axis=1)
def normalize_text(text: str) -> str:
    """Casefold and NFC-normalize *text*, dropping punctuation; every
    whitespace character becomes a single space."""
    explicit_punct = {"«", "»", "—", "–", "…", "’", "“", "”"}
    kept: list[str] = []
    for ch in unicodedata.normalize("NFC", text.casefold()):
        # Drop the explicit quote/dash set plus anything Unicode classifies
        # as punctuation (general category P*).
        if ch in explicit_punct or unicodedata.category(ch).startswith("P"):
            continue
        kept.append(" " if ch.isspace() else ch)
    return "".join(kept)
def uni_category(ch: str) -> str:
    """Return the two-letter Unicode general category of *ch* (e.g. 'Ll', 'Po')."""
    category = unicodedata.category(ch)
    return category
def split_graphemes(text: str) -> list[str]:
    """Group characters into grapheme-like clusters: each base character
    absorbs the combining marks that follow it (e.g. 'а' + U+0301 stay
    together)."""
    clusters: list[str] = []
    for ch in text:
        if clusters and unicodedata.combining(ch):
            clusters[-1] += ch
        else:
            clusters.append(ch)
    return clusters
def load_transcriptions(text_path: Path) -> dict[str, list[str]]:
    """Read a 'filename transcription' file into a token-list mapping.

    Each value starts with a "_" word-separator token and a "_" follows every
    word; an empty transcription yields ["_", "_"]. Blank lines are skipped.
    """
    labels: dict[str, list[str]] = {}
    with text_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            key, text = raw.split(maxsplit=1)
            words = normalize_text(text.strip()).split()
            tokens = ["_"]
            for word in words:
                tokens += split_graphemes(word)
                tokens.append("_")
            if not words:
                tokens.append("_")
            labels[key] = tokens
    return labels
def collect_utterances(args: argparse.Namespace, labels: dict[str, list[str]]) -> list[Utterance]:
    """Scan args.data_dir for wav files that have a transcription and build
    Utterance records with extracted features.

    Files without a label entry, or yielding zero frames, are skipped.
    """
    collected: list[Utterance] = []
    for wav_path in sorted(args.data_dir.glob("*.wav")):
        tokens = labels.get(wav_path.name)
        if tokens is None:
            continue
        feats = extract_features(wav_path, args)
        if not feats.shape[0]:
            continue
        collected.append(Utterance(utt_id=wav_path.name, path=wav_path,
                                   tokens=tokens, features=feats))
    return collected
def load_snapshot(path: Path) -> dict:
    """Load a pickled snapshot dict written by the training loop.

    Handles three layouts: a plain dict, a 0-d object array wrapping the
    dict, and a 1-element object array (the format train() writes).
    """
    loaded = np.load(path, allow_pickle=True)
    if isinstance(loaded, dict):
        return loaded
    return loaded.item() if loaded.shape == () else loaded[0]
class DiagGMM:
    """Diagonal-covariance Gaussian mixture model over feature vectors.

    Derived quantities used by log_likelihood() are cached by _refresh() and
    must be recomputed whenever weights/means/vars change.
    """
    def __init__(self, dim: int, args: argparse.Namespace, num_components: int = 1):
        self.dim = dim
        self.args = args
        # Flat start: uniform weights, zero means, unit variances.
        self.weights = np.full(num_components, 1.0 / num_components, dtype=np.float32)
        self.means = np.zeros((num_components, dim), dtype=np.float32)
        self.vars = np.ones((num_components, dim), dtype=np.float32)
        self._refresh()
    def _refresh(self) -> None:
        """Recompute cached log-density terms from weights/means/vars."""
        self.log_weights = np.log(self.weights + 1e-12).astype(np.float32)
        log_two_pi = np.log(2 * np.pi)
        # Per-component -0.5 * (log|Sigma| + dim * log(2*pi)).
        self.log_norm = (-0.5 * (np.log(self.vars).sum(axis=1) + self.dim * log_two_pi)).astype(np.float32)
        self.inv_vars = 1.0 / self.vars
        self.mu_inv_vars = self.means * self.inv_vars
        self.mu2_inv_vars_sum = np.sum(self.means * self.mu_inv_vars, axis=1)
    def log_likelihood(self, feats: np.ndarray) -> np.ndarray:
        """Return (T, num_components) per-component log-densities (incl. log weight).

        Uses sum((x - mu)^2 / var) = x^2/var - 2*x*mu/var + mu^2/var so the
        whole batch reduces to two matrix products against the cached terms.
        """
        term1 = feats ** 2 @ self.inv_vars.T
        term2 = feats @ self.mu_inv_vars.T
        exponent = 0.5 * (term1 - 2.0 * term2 + self.mu2_inv_vars_sum[None, :])
        return self.log_weights + self.log_norm - exponent
    def posterior_and_loglike(self, feats: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return per-frame component posteriors and per-frame total log-likelihoods."""
        if feats.shape[0] == 0:
            return np.zeros((0, self.weights.shape[0]), dtype=np.float32), np.zeros(0, dtype=np.float32)
        log_like = self.log_likelihood(feats)
        # Stabilized softmax over components.
        max_log = np.max(log_like, axis=1, keepdims=True)
        probs = np.exp(log_like - max_log)
        probs_sum = probs.sum(axis=1, keepdims=True)
        post = (probs / probs_sum).astype(np.float32)
        frame_log_like = (max_log[:, 0] + np.log(probs_sum[:, 0])).astype(np.float32)
        return post, frame_log_like
    def accumulate(self, feats: np.ndarray, accumulator: "GMMAccumulator") -> np.ndarray:
        """Add this batch's sufficient statistics to *accumulator*; return frame log-likes."""
        post, frame_log_like = self.posterior_and_loglike(feats)
        accumulator.add(feats, post)
        return frame_log_like
    def update(self, accumulator: "GMMAccumulator", min_count: float) -> float:
        """Maximum-likelihood re-estimation from accumulated statistics.

        Skips the update and returns 0.0 when total occupancy is below
        min_count; variances are floored at args.var_floor. Returns the total
        occupancy otherwise.
        """
        counts = accumulator.counts
        total = counts.sum()
        if total < min_count:
            return 0.0
        # NOTE(review): a component with zero count would divide by zero here;
        # posteriors upstream appear to keep all counts positive -- confirm
        # before reusing this class elsewhere.
        means = accumulator.linear / counts[:, None]
        second = accumulator.quadratic / counts[:, None]
        vars_ = second - means ** 2
        vars_ = np.maximum(vars_, self.args.var_floor)
        weights = counts / total
        self.weights = weights
        self.means = means
        self.vars = vars_
        self._refresh()
        return total
    def split_to(self, target_components: int, perturb_factor: float) -> None:
        """Grow the mixture to *target_components* by repeatedly splitting the
        heaviest component in two, perturbing means by +/- perturb_factor * stddev
        (LBG-style)."""
        while self.weights.shape[0] < target_components:
            idx = int(np.argmax(self.weights))
            mean = self.means[idx]
            var = self.vars[idx]
            weight = self.weights[idx] * 0.5
            perturb = perturb_factor * np.sqrt(var)
            new_mean_pos = mean + perturb
            new_mean_neg = mean - perturb
            # Replace the split component in place and append its twin.
            self.weights[idx] = weight
            self.means[idx] = new_mean_pos
            self.vars[idx] = var
            self.weights = np.concatenate([self.weights, [weight]])
            self.means = np.vstack([self.means, new_mean_neg])
            self.vars = np.vstack([self.vars, var])
        self._refresh()
class GMMAccumulator:
    """Zeroth/first/second-order sufficient statistics for one GMM state,
    accumulated in float64 for numerical safety."""
    def __init__(self, dim: int, num_components: int):
        # Per-component occupancy, sum(x) and sum(x^2).
        self.counts = np.zeros(num_components, dtype=np.float64)
        self.linear = np.zeros((num_components, dim), dtype=np.float64)
        self.quadratic = np.zeros((num_components, dim), dtype=np.float64)
    def add(self, feats: np.ndarray, post: np.ndarray) -> None:
        """Accumulate one batch: feats is (T, dim), post is (T, num_components)."""
        self.counts += post.sum(axis=0)
        self.linear += post.T @ feats
        self.quadratic += post.T @ (feats ** 2)
class MonophoneModel:
    """HMM with args.states_per_symbol left-to-right GMM states per symbol.

    States are numbered globally; lookup tables map state id -> symbol and
    state id -> variant index (position within the symbol).
    """
    def __init__(self, symbols: list[str], dim: int, args: argparse.Namespace):
        self.symbols = symbols
        self.dim = dim
        self.args = args
        self.state_to_symbol: list[str] = []
        self.state_variants: list[int] = []
        self.symbol_to_states: dict[str, list[int]] = {}
        self.states: list[DiagGMM] = []
        for symbol in symbols:
            ids: list[int] = []
            for variant in range(args.states_per_symbol):
                ids.append(len(self.states))
                self.states.append(DiagGMM(dim, args))
                self.state_to_symbol.append(symbol)
                self.state_variants.append(variant)
            self.symbol_to_states[symbol] = ids
    def expand(self, tokens: list[str]) -> np.ndarray:
        """Map a token sequence to its flat sequence of global state ids."""
        flat = [sid for token in tokens for sid in self.symbol_to_states[token]]
        return np.array(flat, dtype=np.int32)
    def num_gauss(self) -> int:
        """Total number of Gaussian components across all states."""
        return sum(gmm.weights.shape[0] for gmm in self.states)
def equal_alignment(num_frames: int, state_sequence: np.ndarray) -> np.ndarray:
    """Spread states evenly over num_frames frames (flat-start alignment).

    Each state gets floor(num_frames / len(state_sequence)) frames (at least
    one); leftover frames are assigned to the final state, and an expansion
    longer than num_frames is truncated.
    """
    per_state = max(1, num_frames // len(state_sequence))
    stretched = np.repeat(state_sequence, per_state)
    if stretched.shape[0] >= num_frames:
        return stretched[:num_frames]
    padding = np.full(num_frames - stretched.shape[0], state_sequence[-1], dtype=np.int32)
    return np.concatenate([stretched, padding])
@njit(cache=True)
def _viterbi_align_impl(emissions: np.ndarray,
                        log_self: float,
                        log_next: float) -> tuple[np.ndarray, float]:
    """Viterbi decode over a strict left-to-right topology (numba-compiled).

    emissions holds (T, S) per-frame log-likelihoods for each position of the
    expanded state sequence. Each frame either stays in the same position
    (log_self) or advances by exactly one (log_next). The path is forced to
    start at position 0 and end at position S-1; when T < S that end state is
    unreachable and the returned score is -inf (caller falls back).
    Returns (per-frame position indices of length T, best-path log score).
    """
    T, S = emissions.shape
    viterbi = np.empty((T, S), dtype=np.float64)
    backptr = np.zeros((T, S), dtype=np.int32)
    # Initialization: only position 0 is permitted at t = 0.
    for s in range(S):
        viterbi[0, s] = -np.inf
    viterbi[0, 0] = emissions[0, 0]
    for t in range(1, T):
        for s in range(S):
            # Best of self-loop vs. advance from the previous position.
            stay = viterbi[t - 1, s] + log_self
            best = stay
            prev_state = s
            if s > 0:
                advance = viterbi[t - 1, s - 1] + log_next
                if advance > best:
                    best = advance
                    prev_state = s - 1
            viterbi[t, s] = best + emissions[t, s]
            backptr[t, s] = prev_state
    # Backtrace from the forced final position S-1.
    alignment = np.zeros(T, dtype=np.int32)
    state = S - 1
    alignment[T - 1] = state
    for t in range(T - 1, 0, -1):
        state = backptr[t, state]
        alignment[t - 1] = state
    return alignment, viterbi[T - 1, S - 1]
def viterbi_align(model: MonophoneModel,
                  feats: np.ndarray,
                  state_sequence: np.ndarray,
                  return_score: bool = False) -> np.ndarray | tuple[np.ndarray, float]:
    """Viterbi-align feats against an expanded state-id sequence.

    Builds the (T, S) emission matrix by evaluating each position's GMM on
    all frames, decodes with fixed self-loop/advance probabilities (0.6/0.4),
    then maps path positions back to global state ids. When no complete path
    exists (e.g. fewer frames than states) an equal split is returned with a
    score of -inf.
    """
    num_positions = len(state_sequence)
    num_frames = feats.shape[0]
    emissions = np.zeros((num_frames, num_positions), dtype=np.float64)
    for position, state_id in enumerate(state_sequence):
        emissions[:, position] = logsumexp(model.states[state_id].log_likelihood(feats))
    alignment, score = _viterbi_align_impl(emissions, math.log(0.6), math.log(0.4))
    if not np.isfinite(score):
        fallback = equal_alignment(num_frames, state_sequence)
        return (fallback, float("-inf")) if return_score else fallback
    mapped = state_sequence[alignment]
    if return_score:
        return mapped, float(score)
    return mapped
def accumulate_stats(model: MonophoneModel,
                     utterances: list[Utterance],
                     alignments: dict[str, np.ndarray]) -> tuple[list[GMMAccumulator], float, int]:
    """Collect GMM sufficient statistics over all aligned frames.

    Frames are grouped by their aligned state id and fed to that state's GMM.
    Returns (per-state accumulators, total log-likelihood, total frame count).
    """
    accumulators = [GMMAccumulator(model.dim, gmm.weights.shape[0]) for gmm in model.states]
    total_like = 0.0
    total_frames = 0
    for utt in utterances:
        alignment = alignments[utt.utt_id]
        for state_id in np.unique(alignment):
            frames = utt.features[alignment == state_id]
            if not frames.shape[0]:
                continue
            log_likes = model.states[state_id].accumulate(frames, accumulators[state_id])
            total_like += float(log_likes.sum())
            total_frames += frames.shape[0]
    return accumulators, total_like, total_frames
def update_model(model: MonophoneModel,
                 accumulators: list[GMMAccumulator],
                 min_count: float) -> np.ndarray:
    """Re-estimate every state's GMM and return per-state occupancies.

    States whose update was skipped (occupancy below min_count) still report
    their raw accumulated count so the mixup schedule can see them.
    """
    occupancies = np.zeros(len(model.states), dtype=np.float64)
    for state_idx, (gmm, acc) in enumerate(zip(model.states, accumulators)):
        occupancy = gmm.update(acc, min_count)
        occupancies[state_idx] = occupancy if occupancy != 0.0 else acc.counts.sum()
    return occupancies
@dataclass(order=True)
class SplitItem:
    """Heap entry for Gaussian allocation, ordered by priority (most negative
    = most deserving of the next split)."""
    priority: float       # -occ_pow / num_components; 0.0 marks a capped state
    index: int            # state index
    num_components: int   # components currently allocated to this state
    occ_pow: float        # occ ** power, or 0.0 once the state is capped
    occ: float            # raw occupancy
def compute_split_targets(state_occs: np.ndarray,
                          target_total: int,
                          power: float,
                          min_count: float) -> list[int]:
    """Decide how many Gaussians each state should own (Kaldi's GetSplitTargets).

    Greedily hands one component at a time to the state with the largest
    occ**power / num_components ratio until target_total components have been
    allocated. A state stops splitting once one more component would drop its
    per-component occupancy below min_count.

    Args:
        state_occs: per-state occupancy counts.
        target_total: desired total number of Gaussians across all states.
        power: occupancy exponent smoothing the allocation.
        min_count: minimum occupancy required per component.
    Returns:
        Per-state component targets, each >= 1.
    """
    num_states = state_occs.shape[0]
    queue: list[SplitItem] = []
    for idx in range(num_states):
        occ = state_occs[idx]
        occ_pow = occ ** power if occ > 0 else 0.0
        heapq.heappush(queue, SplitItem(-occ_pow, idx, 1, occ_pow, occ))
    total = num_states
    while queue and total < target_total:
        item = heapq.heappop(queue)
        if item.occ_pow == 0:
            break  # the best remaining state is capped: nothing can grow further
        if (item.num_components + 1) * min_count >= item.occ:
            # Splitting would starve this state's components: cap it WITHOUT
            # consuming allocation budget. (Bug fix: `total` was previously
            # incremented here as well, counting a non-split against
            # target_total and under-allocating the remaining states.)
            item = SplitItem(0.0, item.index, item.num_components, 0.0, item.occ)
        else:
            item = SplitItem(-item.occ_pow / (item.num_components + 1),
                             item.index,
                             item.num_components + 1,
                             item.occ_pow,
                             item.occ)
            total += 1
        heapq.heappush(queue, item)
    targets = [1] * num_states
    while queue:
        item = heapq.heappop(queue)
        targets[item.index] = item.num_components
    return targets
def mixup(model: MonophoneModel,
          state_occs: np.ndarray,
          current_target: int,
          power: float,
          min_count: float,
          perturb: float) -> None:
    """Grow the model toward current_target total Gaussians by splitting
    state GMMs according to the occupancy-based allocation; no-op when the
    model already has at least that many components."""
    if current_target <= model.num_gauss():
        return
    per_state_targets = compute_split_targets(state_occs, current_target, power, min_count)
    for gmm, goal in zip(model.states, per_state_targets):
        if goal > gmm.weights.shape[0]:
            gmm.split_to(goal, perturb)
| def save_model(model: MonophoneModel, args: argparse.Namespace) -> None: | |
| args.out_dir.mkdir(parents=True, exist_ok=True) | |
| np.savez(args.out_dir / "final_model.npz", | |
| symbols=np.array(model.symbols, dtype=object), | |
| state_to_symbol=np.array(model.state_to_symbol, dtype=object), | |
| weights=np.array([state.weights for state in model.states], dtype=object), | |
| means=np.array([state.means for state in model.states], dtype=object), | |
| variances=np.array([state.vars for state in model.states], dtype=object), | |
| frame_length=args.frame_length, | |
| frame_shift=args.frame_shift, | |
| num_ceps=args.num_ceps, | |
| states_per_symbol=args.states_per_symbol) | |
def save_alignments(model: MonophoneModel,
                    alignments: dict[str, np.ndarray],
                    align_path: Path) -> None:
    """Write one line per utterance: the utt id followed by the aligned
    symbols, each tagged with its 1-based state variant as a Unicode
    subscript (e.g. 'а₂'). Utterances are written in sorted id order."""
    align_path.parent.mkdir(parents=True, exist_ok=True)
    with align_path.open("w", encoding="utf-8") as out:
        for utt_id, state_ids in sorted(alignments.items()):
            labels = (
                model.state_to_symbol[sid] + subscript_number(model.state_variants[sid] + 1)
                for sid in state_ids
            )
            out.write(" ".join([utt_id, *labels]))
            out.write("\n")
def initialize_model(utterances: list[Utterance],
                     args: argparse.Namespace) -> tuple[MonophoneModel, dict[str, np.ndarray], float, int]:
    """Build the monophone model and its first set of alignments.

    With --init, GMM parameters are loaded from the snapshot and alignments
    come from Viterbi decoding; otherwise the model is flat-started from an
    equal split of frames over states, followed by one re-estimation pass.
    Returns (model, alignments, total log-likelihood, total frames).
    """
    symbols = sorted({token for utt in utterances for token in utt.tokens})
    feat_dim = utterances[0].features.shape[1]
    model = MonophoneModel(symbols, feat_dim, args)
    if args.init is not None:
        # Warm start: copy GMM parameters from the snapshot, then Viterbi-align.
        snapshot = load_snapshot(args.init)
        loaded = zip(model.states, snapshot["weights"], snapshot["means"], snapshot["variances"])
        for gmm, weights, means, variances in loaded:
            gmm.weights = np.asarray(weights, dtype=np.float32)
            gmm.means = np.asarray(means, dtype=np.float32)
            gmm.vars = np.asarray(variances, dtype=np.float32)
            gmm._refresh()
        alignments = {}
        for utt in utterances:
            alignment, _ = viterbi_align(model, utt.features, model.expand(utt.tokens),
                                         return_score=True)
            alignments[utt.utt_id] = alignment
        _, total_like, total_frames = accumulate_stats(model, utterances, alignments)
        return model, alignments, total_like, total_frames
    # Flat start: equal alignments, then a single update pass.
    alignments = {
        utt.utt_id: equal_alignment(utt.features.shape[0], model.expand(utt.tokens))
        for utt in utterances
    }
    accumulators, total_like, total_frames = accumulate_stats(model, utterances, alignments)
    update_model(model, accumulators, min_count=1.0)
    return model, alignments, total_like, total_frames
def train(args: argparse.Namespace) -> None:
    """Run the full monophone training loop.

    Each iteration: optionally re-align (Viterbi), accumulate GMM statistics,
    re-estimate the states, then grow the number of Gaussians toward
    --totgauss on a linear schedule. A parameter snapshot and the current
    alignments are written to out_dir every iteration; the final model and
    alignments are saved at the end.
    """
    labels = load_transcriptions(args.text_path)
    utterances = collect_utterances(args, labels)
    init_start = perf_counter()
    model, alignments, init_like, init_frames = initialize_model(utterances, args)
    init_time = perf_counter() - init_start
    # Iteration numbers (1-based) at which alignments are recomputed.
    realign_iters = {int(x) for x in args.realign_iters.split()}
    num_gauss = model.num_gauss()
    # Linear mixup schedule: add inc_gauss components per iteration until --totgauss.
    inc_gauss = max((args.totgauss - num_gauss) // args.max_iter_inc, 1)
    current_target = num_gauss
    init_source = "snapshot" if args.init else "flat"
    if init_frames > 0:
        avg_like = init_like / init_frames
        print(f"[train_mono] init: avg loglike/frame={avg_like:.4f} "
              f"total_like={init_like:.2f} frames={init_frames} "
              f"num_gauss={model.num_gauss()} source={init_source} "
              f"| times: init={init_time:.2f}s",
              flush=True)
    else:
        print(f"[train_mono] init: frames=0 source={init_source} "
              f"| times: init={init_time:.2f}s",
              flush=True)
    # NOTE: iteration counts from 1, so this performs num_iters - 1 passes
    # (initialization serves as pass 0).
    for iteration in range(1, args.num_iters):
        iter_start = perf_counter()
        realign_time = 0.0
        if iteration in realign_iters:
            print(f"[train_mono] iter {iteration:02d}: realigning", flush=True)
            realign_begin = perf_counter()
            new_alignments = {}
            for utt in utterances:
                sequence = model.expand(utt.tokens)
                new_alignments[utt.utt_id] = viterbi_align(model, utt.features, sequence)
            alignments = new_alignments
            realign_time = perf_counter() - realign_begin
            print(f"[train_mono] iter {iteration:02d}: realign done in {realign_time:.2f}s",
                  flush=True)
        # E-step: collect sufficient statistics under the current alignments.
        accum_start = perf_counter()
        accumulators, total_like, total_frames = accumulate_stats(model, utterances, alignments)
        accum_time = perf_counter() - accum_start
        # M-step: re-estimate every state's GMM.
        update_start = perf_counter()
        state_occs = update_model(model, accumulators, args.mixup_min_count)
        update_time = perf_counter() - update_start
        num_gauss_before_mix = model.num_gauss()
        if total_frames > 0:
            avg_like = total_like / total_frames
        else:
            avg_like = float("nan")
        mix_time = 0.0
        if iteration <= args.max_iter_inc:
            # Raise the target and split components accordingly.
            current_target = min(args.totgauss, current_target + inc_gauss)
            mix_start = perf_counter()
            mixup(model, state_occs, current_target, args.power, args.mixup_min_count, args.mixup_perturb)
            mix_time = perf_counter() - mix_start
        num_gauss_after_mix = model.num_gauss()
        if num_gauss_after_mix != num_gauss_before_mix:
            print(f"[train_mono] iter {iteration:02d}: mixup -> num_gauss {num_gauss_before_mix} -> {num_gauss_after_mix}",
                  flush=True)
        iter_time = perf_counter() - iter_start
        print(f"[train_mono] iter {iteration:02d}: avg loglike/frame={avg_like:.4f} "
              f"total_like={total_like:.2f} frames={total_frames} "
              f"num_gauss={model.num_gauss()} "
              f"| times: realign={realign_time:.2f}s accum={accum_time:.2f}s "
              f"update={update_time:.2f}s mix={mix_time:.2f}s iter={iter_time:.2f}s",
              flush=True)
        # Per-iteration snapshot; load_snapshot() reads this 1-element object array.
        snapshot = {
            "weights": [state.weights for state in model.states],
            "means": [state.means for state in model.states],
            "variances": [state.vars for state in model.states],
        }
        np.save(args.out_dir / f"model_iter_{iteration}.npy",
                np.array([snapshot], dtype=object),
                allow_pickle=True)
        save_alignments(model, alignments, args.out_dir / f"alignments_iter_{iteration:02d}.txt")
    # Final artifacts.
    save_alignments(model, alignments, args.out_dir / "alignments.txt")
    save_model(model, args)
def main() -> None:
    """CLI entry point: parse arguments, ensure the output directory exists,
    then run training."""
    cli_args = parse_args()
    cli_args.out_dir.mkdir(parents=True, exist_ok=True)
    train(cli_args)
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment