Skip to content

Instantly share code, notes, and snippets.

@proger
Last active October 26, 2025 23:50
Show Gist options
  • Select an option

  • Save proger/22667a00e6c1991396a481c3835edc0f to your computer and use it in GitHub Desktop.

Select an option

Save proger/22667a00e6c1991396a481c3835edc0f to your computer and use it in GitHub Desktop.
"""
Train a GMM and use it to Viterbi-align the unicode character labels of the training set to the audio.
Let wav/ contain a directory of 16-bit PCM wav files, text contain a list of filename-transcription pairs. Usage is:
$ cat > text << EOF
25153.wav — Я це доскона́ло зна́ю, а ось мене́ ду́же диву́є, як ви мо́жете про це зна́ти!
EOF
$ python -m train_mono wav/ text exp/
See also:
https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/steps/train_mono.sh
"""
import argparse
import heapq
import math
import unicodedata
import wave
from dataclasses import dataclass
from pathlib import Path
from time import perf_counter
import numpy as np
from numba import njit
@dataclass(frozen=True)
class Utterance:
    """One training example: audio location, label tokens and acoustic features."""

    utt_id: str            # transcript key (the wav file's name)
    path: Path             # where the audio lives on disk
    tokens: list[str]      # grapheme tokens; "_" marks word boundaries
    features: np.ndarray   # (num_frames, feat_dim) MFCC+delta+delta-delta matrix
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for monophone GMM training."""

    class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        """Show defaults while keeping the module docstring's line breaks."""

    parser = argparse.ArgumentParser(description=__doc__, formatter_class=Formatter)
    arg = parser.add_argument
    arg("data_dir", type=Path, help="A directory with audio files")
    arg("text_path", type=Path, help="A transcript file formatted like: 'file.wav a brown fox jumped over the lazy dog' (first column is a file name, rest are words)")
    arg("out_dir", type=Path, help="An experiment directory where checkpoints and alignments will be stored")
    arg("--num-iters", type=int, default=40, help="Number of passes over the entire dataset")
    arg("--max-iter-inc", type=int, default=30, help="How many iterations to grow the model size for")
    arg("--totgauss", type=int, default=1000, help="Total target number of Gaussian components split between all states")
    arg("--power", type=float, default=0.25, help="Gaussian states grow proportionally to occupancy ** power")
    arg("--mixup-perturb", type=float, default=0.01, help="LBG-style Gaussian mean perturbation factor when growing the model")
    arg("--mixup-min-count", type=float, default=3.0, help="Minimum occupancy a state needs to have to split")
    arg("--realign-iters", type=str,
        default="1 2 3 4 5 6 7 8 9 10 12 14 16 18 20 23 26 29 32 35 38",
        help="Iteration indices where realignment should happen (from egs/wsj)")
    arg("--init", type=Path, help="Path to snapshot .npy file to initialize model parameters")
    arg("--frame-length", type=float, default=0.025, help="Frame length in seconds")
    arg("--frame-shift", type=float, default=0.010, help="Frame shift in seconds")
    arg("--states-per-symbol", type=int, default=3, help="Number of HMM states per symbol")
    arg("--num-ceps", type=int, default=13, help="Number of cepstral coefficients")
    arg("--num-mels", type=int, default=24, help="Number of mel bands")
    arg("--preemphasis", type=float, default=0.97, help="Pre-emphasis factor for waveform")
    arg("--var-floor", type=float, default=1e-3, help="Variance floor used during model updates")
    return parser.parse_args()
def read_wave(path: Path) -> tuple[np.ndarray, int]:
    """Read a 16-bit PCM wav file as mono float32 samples scaled to [-1, 1).

    Multi-channel audio is downmixed by averaging channels (the original
    handled only stereo; interleaved >2-channel data was silently left as-is).

    Returns:
        (samples, sample_rate)
    Raises:
        ValueError: if the file is not 16-bit PCM.
    """
    with wave.open(str(path), "rb") as wf:
        num_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()
        sample_rate = wf.getframerate()
        num_frames = wf.getnframes()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if sample_width != 2:
            raise ValueError(f"expect 16-bit PCM, got sample width {sample_width}")
        raw = wf.readframes(num_frames)
    data = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
    if num_channels > 1:
        # Downmix any number of interleaved channels to mono.
        data = data.reshape(-1, num_channels).mean(axis=1)
    data /= 32768.0
    return data, sample_rate
def frame_signal(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray:
    """Slice a 1-D waveform into overlapping frames.

    Frame size and hop come from args.frame_length / args.frame_shift (seconds);
    the tail is zero-padded so the last frame is complete.
    """
    win = int(round(args.frame_length * sample_rate))
    hop = int(round(args.frame_shift * sample_rate))
    total = 1 + max(0, (len(signal) - win) // hop)
    needed = (total - 1) * hop + win
    if needed > len(signal):
        signal = np.pad(signal, (0, needed - len(signal)))
    offsets = hop * np.arange(total)[:, None]
    return signal[offsets + np.arange(win)[None, :]]
def mel_scale(freq: float | np.ndarray) -> np.ndarray:
freq_arr = np.asarray(freq, dtype=np.float64)
return 2595.0 * np.log10(1.0 + freq_arr / 700.0)
def inverse_mel(mel: float | np.ndarray) -> np.ndarray:
mel_arr = np.asarray(mel, dtype=np.float64)
return 700.0 * (10 ** (mel_arr / 2595.0) - 1)
def create_mel_filterbank(sample_rate: int, nfft: int, args: argparse.Namespace) -> np.ndarray:
    """Build a (num_mels, nfft // 2 + 1) matrix of triangular mel filters."""
    # HTK mel conversion, inlined: filter edges are equally spaced in mel.
    high_mel = 2595.0 * np.log10(1.0 + (sample_rate / 2.0) / 700.0)
    mel_edges = np.linspace(0.0, high_mel, args.num_mels + 2)
    hz_edges = 700.0 * (10 ** (mel_edges / 2595.0) - 1)
    bin_edges = np.floor((nfft + 1) * hz_edges / sample_rate).astype(int)
    fbank = np.zeros((args.num_mels, nfft // 2 + 1), dtype=np.float32)
    for m in range(1, args.num_mels + 1):
        lo = max(int(bin_edges[m - 1]), 0)
        mid = int(bin_edges[m])
        hi = int(bin_edges[m + 1])
        # Force a non-degenerate triangle that stays inside the spectrum.
        mid = min(max(mid, lo + 1), hi)
        hi = min(hi, nfft // 2)
        if hi <= lo:
            continue
        ramp_up = np.arange(lo, mid)
        fbank[m - 1, lo:mid] = (ramp_up - lo) / max(mid - lo, 1)
        ramp_down = np.arange(mid, hi)
        fbank[m - 1, mid:hi] = (hi - ramp_down) / max(hi - mid, 1)
    return fbank
def create_dct_matrix(num_mels: int, num_ceps: int) -> np.ndarray:
    """Orthonormal DCT-II matrix of shape (num_ceps, num_mels)."""
    bands = np.arange(num_mels)[None, :]
    ceps = np.arange(num_ceps)[:, None]
    dct = np.sqrt(2.0 / num_mels) * np.cos((np.pi / num_mels) * (bands + 0.5) * ceps)
    # Scale the DC row so the matrix is orthonormal.
    dct[0] *= np.sqrt(0.5)
    return dct
def compute_mfcc(signal: np.ndarray, sample_rate: int, args: argparse.Namespace) -> np.ndarray:
    """MFCC pipeline: pre-emphasis, framing, Hamming window, power spectrum,
    mel filterbank, log compression, DCT. Returns a (T, num_ceps) float32 matrix."""
    emphasized = np.append(signal[0], signal[1:] - args.preemphasis * signal[:-1])
    frames = frame_signal(emphasized, sample_rate, args)
    frames *= np.hamming(frames.shape[1])[None, :]
    # FFT size: smallest power of two >= frame length.
    nfft = 1 << max(frames.shape[1] - 1, 0).bit_length()
    power_spec = (np.abs(np.fft.rfft(frames, n=nfft)) ** 2) / nfft
    fbanks = create_mel_filterbank(sample_rate, nfft, args)
    # Floor the energies so the log never sees zero.
    log_mel = np.log(np.maximum(power_spec @ fbanks.T, 1e-10))
    dct = create_dct_matrix(log_mel.shape[1], args.num_ceps)
    return (log_mel @ dct.T).astype(np.float32)
def compute_deltas(feats: np.ndarray, N: int = 2) -> np.ndarray:
    """First-order regression deltas over a +/-N frame window (edges replicated)."""
    denom = 2 * sum(d * d for d in range(1, N + 1))
    padded = np.pad(feats, ((N, N), (0, 0)), mode="edge")
    T = feats.shape[0]
    acc = np.zeros_like(feats)
    # Same per-element sums as the frame-by-frame loop, done one lag at a time.
    for d in range(1, N + 1):
        acc += d * (padded[N + d:N + d + T] - padded[N - d:N - d + T])
    return acc / denom
def logsumexp(values: np.ndarray) -> np.ndarray:
    """Row-wise log(sum(exp(values))) computed with the max-shift trick."""
    row_max = np.max(values, axis=1, keepdims=True)
    shifted = np.exp(values - row_max)
    return (row_max[:, 0] + np.log(shifted.sum(axis=1))).astype(np.float64)
def apply_cmvn(feats: np.ndarray) -> np.ndarray:
    """Per-utterance cepstral mean subtraction (no variance normalisation)."""
    return feats - feats.mean(axis=0, keepdims=True)
def subscript_number(number: int | float | str) -> str:
translation = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
return f"{number}".translate(translation)
def extract_features(path: Path, args: argparse.Namespace) -> np.ndarray:
    """Load audio and return (T, 3 * num_ceps) MFCC + delta + delta-delta features."""
    samples, rate = read_wave(path)
    base = apply_cmvn(compute_mfcc(samples, rate, args))
    d1 = compute_deltas(base)
    d2 = compute_deltas(d1)
    return np.concatenate([base, d1, d2], axis=1)
def normalize_text(text: str) -> str:
    """Casefold and NFC-normalize *text*, drop punctuation, map whitespace to ' '."""
    extra_punct = {"«", "»", "—", "–", "…", "’", "“", "”"}
    kept = []
    for ch in unicodedata.normalize("NFC", text.casefold()):
        # Skip anything in a Unicode punctuation category plus a few typographic marks.
        if unicodedata.category(ch).startswith("P") or ch in extra_punct:
            continue
        kept.append(" " if ch.isspace() else ch)
    return "".join(kept)
def uni_category(ch: str) -> str:
    """Return the Unicode general category of *ch* (e.g. 'Ll', 'Po')."""
    return unicodedata.category(ch)
def split_graphemes(text: str) -> list[str]:
    """Group each base character with its trailing combining marks."""
    clusters: list[str] = []
    pending = ""
    for ch in text:
        if pending and unicodedata.combining(ch):
            pending += ch        # attach the mark to the current base char
        elif pending:
            clusters.append(pending)
            pending = ch
        else:
            pending = ch
    if pending:
        clusters.append(pending)
    return clusters
def load_transcriptions(text_path: Path) -> dict[str, list[str]]:
    """Parse 'file.wav transcript ...' lines into grapheme token lists.

    Each utterance is wrapped in "_" (silence/word-boundary) tokens, with an
    additional "_" after every word.
    """
    labels: dict[str, list[str]] = {}
    with text_path.open("r", encoding="utf-8") as f:
        for raw in f:
            if not raw.strip():
                continue
            key, rest = raw.split(maxsplit=1)
            words = [w for w in normalize_text(rest.strip()).split() if w]
            tokens = ["_"]
            for word in words:
                tokens.extend(split_graphemes(word))
                tokens.append("_")
            if not words:
                # Keep the two-boundary shape even for empty transcripts.
                tokens.append("_")
            labels[key] = tokens
    return labels
def collect_utterances(args: argparse.Namespace, labels: dict[str, list[str]]) -> list[Utterance]:
    """Build an Utterance for every labelled wav file that yields at least one frame."""
    found = []
    for wav_path in sorted(args.data_dir.glob("*.wav")):
        tokens = labels.get(wav_path.name)
        if tokens is None:
            continue  # no transcript for this file
        feats = extract_features(wav_path, args)
        if feats.shape[0] == 0:
            continue  # too short to produce frames
        found.append(Utterance(utt_id=wav_path.name, path=wav_path,
                               tokens=tokens, features=feats))
    return found
def load_snapshot(path: Path) -> dict:
    """Load a pickled model snapshot stored in a .npy object array."""
    data = np.load(path, allow_pickle=True)
    if isinstance(data, dict):
        return data
    # A 0-d object array wraps the dict directly; a 1-d one stores it at index 0.
    return data.item() if data.shape == () else data[0]
class DiagGMM:
    """Diagonal-covariance Gaussian mixture for one HMM state.

    Likelihoods are computed from cached terms (inverse variances and
    mean-scaled inverse variances) that `_refresh()` rebuilds whenever the
    raw parameters change.
    """

    def __init__(self, dim: int, args: argparse.Namespace, num_components: int = 1):
        self.dim = dim
        self.args = args
        # Flat start: uniform weights, zero means, unit variances.
        self.weights = np.full(num_components, 1.0 / num_components, dtype=np.float32)
        self.means = np.zeros((num_components, dim), dtype=np.float32)
        self.vars = np.ones((num_components, dim), dtype=np.float32)
        self._refresh()

    def _refresh(self) -> None:
        """Rebuild the cached terms used by `log_likelihood`."""
        self.log_weights = np.log(self.weights + 1e-12).astype(np.float32)
        log_two_pi = np.log(2 * np.pi)
        self.log_norm = (-0.5 * (np.log(self.vars).sum(axis=1) + self.dim * log_two_pi)).astype(np.float32)
        self.inv_vars = 1.0 / self.vars
        self.mu_inv_vars = self.means * self.inv_vars
        self.mu2_inv_vars_sum = np.sum(self.means * self.mu_inv_vars, axis=1)

    def log_likelihood(self, feats: np.ndarray) -> np.ndarray:
        """Per-frame, per-component weighted log-density for a (T, dim) matrix."""
        # Mahalanobis term expanded: x^2/v - 2*x*mu/v + mu^2/v.
        sq_term = feats ** 2 @ self.inv_vars.T
        cross_term = feats @ self.mu_inv_vars.T
        exponent = 0.5 * (sq_term - 2.0 * cross_term + self.mu2_inv_vars_sum[None, :])
        return self.log_weights + self.log_norm - exponent

    def posterior_and_loglike(self, feats: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return (component posteriors, per-frame total log-likelihood)."""
        num_components = self.weights.shape[0]
        if feats.shape[0] == 0:
            return (np.zeros((0, num_components), dtype=np.float32),
                    np.zeros(0, dtype=np.float32))
        scores = self.log_likelihood(feats)
        peak = np.max(scores, axis=1, keepdims=True)
        shifted = np.exp(scores - peak)
        norm = shifted.sum(axis=1, keepdims=True)
        posteriors = (shifted / norm).astype(np.float32)
        frame_loglike = (peak[:, 0] + np.log(norm[:, 0])).astype(np.float32)
        return posteriors, frame_loglike

    def accumulate(self, feats: np.ndarray, accumulator: "GMMAccumulator") -> np.ndarray:
        """Add sufficient statistics for `feats`; return per-frame log-likelihoods."""
        posteriors, frame_loglike = self.posterior_and_loglike(feats)
        accumulator.add(feats, posteriors)
        return frame_loglike

    def update(self, accumulator: "GMMAccumulator", min_count: float) -> float:
        """ML re-estimation from accumulated stats.

        Returns the occupancy used, or 0.0 when there was too little data and
        the old parameters were kept.
        """
        occupancy = accumulator.counts
        total = occupancy.sum()
        if total < min_count:
            return 0.0
        new_means = accumulator.linear / occupancy[:, None]
        second_moment = accumulator.quadratic / occupancy[:, None]
        # Floor variances so degenerate components keep finite likelihoods.
        new_vars = np.maximum(second_moment - new_means ** 2, self.args.var_floor)
        self.weights = occupancy / total
        self.means = new_means
        self.vars = new_vars
        self._refresh()
        return total

    def split_to(self, target_components: int, perturb_factor: float) -> None:
        """Grow to `target_components` by LBG-style splits of the heaviest component."""
        while self.weights.shape[0] < target_components:
            heaviest = int(np.argmax(self.weights))
            offset = perturb_factor * np.sqrt(self.vars[heaviest])
            half_weight = self.weights[heaviest] * 0.5
            mirrored_mean = self.means[heaviest] - offset
            # Shift the existing component one way, append its mirror image.
            self.weights[heaviest] = half_weight
            self.means[heaviest] = self.means[heaviest] + offset
            self.weights = np.concatenate([self.weights, [half_weight]])
            self.means = np.vstack([self.means, mirrored_mean])
            self.vars = np.vstack([self.vars, self.vars[heaviest]])
        self._refresh()
class GMMAccumulator:
    """Zeroth/first/second-order sufficient statistics for one DiagGMM."""

    def __init__(self, dim: int, num_components: int):
        self.counts = np.zeros(num_components, dtype=np.float64)           # sum of posteriors
        self.linear = np.zeros((num_components, dim), dtype=np.float64)    # sum of post * x
        self.quadratic = np.zeros((num_components, dim), dtype=np.float64) # sum of post * x^2

    def add(self, feats: np.ndarray, post: np.ndarray) -> None:
        """Accumulate a (T, dim) feature block weighted by (T, K) posteriors."""
        self.counts += post.sum(axis=0)
        weighted = post.T
        self.linear += weighted @ feats
        self.quadratic += weighted @ (feats ** 2)
class MonophoneModel:
    """A bank of per-state diagonal GMMs: `states_per_symbol` HMM states per symbol."""

    def __init__(self, symbols: list[str], dim: int, args: argparse.Namespace):
        self.symbols = symbols
        self.dim = dim
        self.args = args
        self.state_to_symbol: list[str] = []  # state id -> owning symbol
        self.state_variants: list[int] = []   # state id -> 0-based position within the symbol
        self.symbol_to_states: dict[str, list[int]] = {}
        self.states: list[DiagGMM] = []
        for symbol in symbols:
            state_ids: list[int] = []
            for variant in range(args.states_per_symbol):
                state_ids.append(len(self.states))
                self.states.append(DiagGMM(dim, args))
                self.state_to_symbol.append(symbol)
                self.state_variants.append(variant)
            self.symbol_to_states[symbol] = state_ids

    def expand(self, tokens: list[str]) -> np.ndarray:
        """Map a token sequence to its flattened HMM state-id sequence."""
        state_ids: list[int] = []
        for token in tokens:
            state_ids.extend(self.symbol_to_states[token])
        return np.array(state_ids, dtype=np.int32)

    def num_gauss(self) -> int:
        """Total Gaussian component count across all states."""
        return sum(state.weights.shape[0] for state in self.states)
def equal_alignment(num_frames: int, state_sequence: np.ndarray) -> np.ndarray:
    """Spread states uniformly over frames; leftover frames stick to the last state."""
    per_state = max(1, num_frames // len(state_sequence))
    stretched = np.repeat(state_sequence, per_state)
    if stretched.shape[0] >= num_frames:
        return stretched[:num_frames]
    filler = np.full(num_frames - stretched.shape[0], state_sequence[-1], dtype=np.int32)
    return np.concatenate([stretched, filler])
@njit(cache=True)
def _viterbi_align_impl(emissions: np.ndarray,
                        log_self: float,
                        log_next: float) -> tuple[np.ndarray, float]:
    # Viterbi over a strictly left-to-right chain: from position s a frame may
    # stay (log_self) or advance to s+1 (log_next); no skips are allowed.
    # emissions[t, s] is the log-likelihood of frame t under sequence position s.
    T, S = emissions.shape
    viterbi = np.empty((T, S), dtype=np.float64)
    backptr = np.zeros((T, S), dtype=np.int32)
    # Paths are forced to start in the first position of the sequence.
    for s in range(S):
        viterbi[0, s] = -np.inf
    viterbi[0, 0] = emissions[0, 0]
    for t in range(1, T):
        for s in range(S):
            stay = viterbi[t - 1, s] + log_self
            best = stay
            prev_state = s
            if s > 0:
                advance = viterbi[t - 1, s - 1] + log_next
                if advance > best:
                    best = advance
                    prev_state = s - 1
            viterbi[t, s] = best + emissions[t, s]
            backptr[t, s] = prev_state
    # Backtrace from the forced final position. When T < S no complete path
    # exists, so the returned score is -inf; the caller falls back to a
    # uniform alignment in that case.
    alignment = np.zeros(T, dtype=np.int32)
    state = S - 1
    alignment[T - 1] = state
    for t in range(T - 1, 0, -1):
        state = backptr[t, state]
        alignment[t - 1] = state
    return alignment, viterbi[T - 1, S - 1]
def viterbi_align(model: MonophoneModel,
                  feats: np.ndarray,
                  state_sequence: np.ndarray,
                  return_score: bool = False) -> np.ndarray | tuple[np.ndarray, float]:
    """Viterbi-align frames to a left-to-right state sequence.

    Returns per-frame state ids (and the path score when `return_score`).
    Falls back to a uniform alignment with -inf score when no finite path
    exists (e.g. fewer frames than states).
    """
    num_states = len(state_sequence)
    num_frames = feats.shape[0]
    emissions = np.zeros((num_frames, num_states), dtype=np.float64)
    for column, state_id in enumerate(state_sequence):
        emissions[:, column] = logsumexp(model.states[state_id].log_likelihood(feats))
    alignment, score = _viterbi_align_impl(emissions, math.log(0.6), math.log(0.4))
    if not np.isfinite(score):
        fallback = equal_alignment(num_frames, state_sequence)
        return (fallback, float("-inf")) if return_score else fallback
    # Map sequence positions back to global state ids.
    mapped = state_sequence[alignment]
    return (mapped, float(score)) if return_score else mapped
def accumulate_stats(model: MonophoneModel,
                     utterances: list[Utterance],
                     alignments: dict[str, np.ndarray]) -> tuple[list[GMMAccumulator], float, int]:
    """Accumulate per-state GMM statistics under fixed alignments.

    Returns (one accumulator per state, total log-likelihood, total frames).
    """
    accumulators = [GMMAccumulator(model.dim, state.weights.shape[0])
                    for state in model.states]
    total_like = 0.0
    total_frames = 0
    for utt in utterances:
        alignment = alignments[utt.utt_id]
        # Group the utterance's frames by the state they were aligned to.
        for state_id in np.unique(alignment):
            frames = utt.features[alignment == state_id]
            if frames.shape[0] == 0:
                continue
            loglikes = model.states[state_id].accumulate(frames, accumulators[state_id])
            total_like += float(loglikes.sum())
            total_frames += frames.shape[0]
    return accumulators, total_like, total_frames
def update_model(model: MonophoneModel,
                 accumulators: list[GMMAccumulator],
                 min_count: float) -> np.ndarray:
    """Re-estimate every state's GMM; return per-state occupancies.

    When an update is skipped for lack of data (update returns 0.0), the raw
    accumulated occupancy is still reported so mixup sees the true counts.
    """
    occupancies = np.zeros(len(model.states), dtype=np.float64)
    for state_idx, (state, acc) in enumerate(zip(model.states, accumulators)):
        occupancy = state.update(acc, min_count)
        occupancies[state_idx] = occupancy if occupancy != 0.0 else acc.counts.sum()
    return occupancies
@dataclass(order=True)
class SplitItem:
    """Heap entry: smaller `priority` (= -occ**power / num_components) splits first."""
    priority: float
    index: int
    num_components: int
    occ_pow: float
    occ: float


def compute_split_targets(state_occs: np.ndarray,
                          target_total: int,
                          power: float,
                          min_count: float) -> list[int]:
    """Decide how many Gaussians each state should get (Kaldi GetSplitTargets style).

    Repeatedly grants one more component to the state with the largest
    occ**power / num_components ratio, until `target_total` components have
    been allotted or no state can grow without dropping below `min_count`
    occupancy per component.
    """
    num_states = state_occs.shape[0]
    queue: list[SplitItem] = []
    for idx in range(num_states):
        occ = state_occs[idx]
        occ_pow = occ ** power if occ > 0 else 0.0
        heapq.heappush(queue, SplitItem(-occ_pow, idx, 1, occ_pow, occ))
    total = num_states
    while queue and total < target_total:
        item = heapq.heappop(queue)
        if item.occ_pow == 0:
            break  # only zero-occupancy / frozen states remain
        if (item.num_components + 1) * min_count >= item.occ:
            # Too little occupancy to support another component: freeze the
            # state (priority 0 sorts last). BUG FIX: the original also ran
            # `total += 1` here, counting a Gaussian that was never granted
            # and so undershooting the requested target.
            item = SplitItem(0.0, item.index, item.num_components, 0.0, item.occ)
        else:
            item = SplitItem(-item.occ_pow / (item.num_components + 1),
                             item.index,
                             item.num_components + 1,
                             item.occ_pow,
                             item.occ)
            total += 1
        heapq.heappush(queue, item)
    targets = [1] * num_states
    while queue:
        item = heapq.heappop(queue)
        targets[item.index] = item.num_components
    return targets
def mixup(model: MonophoneModel,
          state_occs: np.ndarray,
          current_target: int,
          power: float,
          min_count: float,
          perturb: float) -> None:
    """Grow the model toward `current_target` total Gaussians by splitting states."""
    if current_target <= model.num_gauss():
        return  # already at or above the budget
    targets = compute_split_targets(state_occs, current_target, power, min_count)
    for state, target in zip(model.states, targets):
        if target > state.weights.shape[0]:
            state.split_to(target, perturb)
def save_model(model: MonophoneModel, args: argparse.Namespace) -> None:
    """Write the final model (parameters plus feature config) to out_dir/final_model.npz."""
    args.out_dir.mkdir(parents=True, exist_ok=True)

    def stacked(attr: str) -> np.ndarray:
        # Object arrays, since states may hold different numbers of components.
        return np.array([getattr(state, attr) for state in model.states], dtype=object)

    np.savez(args.out_dir / "final_model.npz",
             symbols=np.array(model.symbols, dtype=object),
             state_to_symbol=np.array(model.state_to_symbol, dtype=object),
             weights=stacked("weights"),
             means=stacked("means"),
             variances=stacked("vars"),
             frame_length=args.frame_length,
             frame_shift=args.frame_shift,
             num_ceps=args.num_ceps,
             states_per_symbol=args.states_per_symbol)
def save_alignments(model: MonophoneModel,
                    alignments: dict[str, np.ndarray],
                    align_path: Path) -> None:
    """Write one line per utterance: utt_id followed by per-frame symbol labels.

    Each label is the state's symbol with its 1-based variant as a subscript,
    e.g. 'а₂' for the second HMM state of symbol 'а'.
    """
    align_path.parent.mkdir(parents=True, exist_ok=True)
    with align_path.open("w", encoding="utf-8") as f:
        for utt_id, state_ids in sorted(alignments.items()):
            labels = (f"{model.state_to_symbol[s]}{subscript_number(model.state_variants[s] + 1)}"
                      for s in state_ids)
            f.write(" ".join([utt_id, *labels]) + "\n")
def initialize_model(utterances: list[Utterance],
                     args: argparse.Namespace) -> tuple[MonophoneModel, dict[str, np.ndarray], float, int]:
    """Build the model and produce initial alignments.

    With --init, parameters come from the snapshot and the data is
    Viterbi-aligned immediately. Otherwise a flat-start model is estimated
    from a uniform (equal-spread) alignment followed by one update pass.
    Returns (model, alignments, total log-likelihood, total frames).
    """
    symbols = sorted({token for utt in utterances for token in utt.tokens})
    dim = utterances[0].features.shape[1]
    model = MonophoneModel(symbols, dim, args)
    if args.init is not None:
        snapshot = load_snapshot(args.init)
        loaded = zip(model.states, snapshot["weights"], snapshot["means"], snapshot["variances"])
        for state, weights, means, variances in loaded:
            state.weights = np.asarray(weights, dtype=np.float32)
            state.means = np.asarray(means, dtype=np.float32)
            state.vars = np.asarray(variances, dtype=np.float32)
            state._refresh()
        alignments: dict[str, np.ndarray] = {}
        for utt in utterances:
            alignment, _ = viterbi_align(model, utt.features, model.expand(utt.tokens),
                                         return_score=True)
            alignments[utt.utt_id] = alignment
        _, total_like, total_frames = accumulate_stats(model, utterances, alignments)
        return model, alignments, total_like, total_frames
    # Flat start: spread each utterance's states uniformly over its frames,
    # then run a single accumulate/update pass.
    alignments = {}
    for utt in utterances:
        alignments[utt.utt_id] = equal_alignment(utt.features.shape[0], model.expand(utt.tokens))
    accumulators, total_like, total_frames = accumulate_stats(model, utterances, alignments)
    update_model(model, accumulators, min_count=1.0)
    return model, alignments, total_like, total_frames
def train(args: argparse.Namespace) -> None:
    """Run the full monophone training loop, writing per-iteration snapshots,
    alignment dumps, and the final model into args.out_dir."""
    labels = load_transcriptions(args.text_path)
    utterances = collect_utterances(args, labels)
    init_start = perf_counter()
    model, alignments, init_like, init_frames = initialize_model(utterances, args)
    init_time = perf_counter() - init_start
    # Iterations at which the data is re-aligned with the current model.
    realign_iters = {int(x) for x in args.realign_iters.split()}
    num_gauss = model.num_gauss()
    # Gaussian budget added per growing iteration (at least 1).
    inc_gauss = max((args.totgauss - num_gauss) // args.max_iter_inc, 1)
    current_target = num_gauss
    init_source = "snapshot" if args.init else "flat"
    if init_frames > 0:
        avg_like = init_like / init_frames
        print(f"[train_mono] init: avg loglike/frame={avg_like:.4f} "
              f"total_like={init_like:.2f} frames={init_frames} "
              f"num_gauss={model.num_gauss()} source={init_source} "
              f"| times: init={init_time:.2f}s",
              flush=True)
    else:
        print(f"[train_mono] init: frames=0 source={init_source} "
              f"| times: init={init_time:.2f}s",
              flush=True)
    # NOTE: range(1, num_iters) runs num_iters - 1 EM iterations.
    for iteration in range(1, args.num_iters):
        iter_start = perf_counter()
        realign_time = 0.0
        if iteration in realign_iters:
            print(f"[train_mono] iter {iteration:02d}: realigning", flush=True)
            realign_begin = perf_counter()
            new_alignments = {}
            for utt in utterances:
                sequence = model.expand(utt.tokens)
                new_alignments[utt.utt_id] = viterbi_align(model, utt.features, sequence)
            alignments = new_alignments
            realign_time = perf_counter() - realign_begin
            print(f"[train_mono] iter {iteration:02d}: realign done in {realign_time:.2f}s",
                  flush=True)
        # E-step: collect sufficient statistics under the current alignments.
        accum_start = perf_counter()
        accumulators, total_like, total_frames = accumulate_stats(model, utterances, alignments)
        accum_time = perf_counter() - accum_start
        # M-step: re-estimate each state's GMM.
        update_start = perf_counter()
        state_occs = update_model(model, accumulators, args.mixup_min_count)
        update_time = perf_counter() - update_start
        num_gauss_before_mix = model.num_gauss()
        if total_frames > 0:
            avg_like = total_like / total_frames
        else:
            avg_like = float("nan")
        mix_time = 0.0
        if iteration <= args.max_iter_inc:
            # Grow the Gaussian budget and split heavy states up to it.
            current_target = min(args.totgauss, current_target + inc_gauss)
            mix_start = perf_counter()
            mixup(model, state_occs, current_target, args.power, args.mixup_min_count, args.mixup_perturb)
            mix_time = perf_counter() - mix_start
        num_gauss_after_mix = model.num_gauss()
        if num_gauss_after_mix != num_gauss_before_mix:
            print(f"[train_mono] iter {iteration:02d}: mixup -> num_gauss {num_gauss_before_mix} -> {num_gauss_after_mix}",
                  flush=True)
        iter_time = perf_counter() - iter_start
        print(f"[train_mono] iter {iteration:02d}: avg loglike/frame={avg_like:.4f} "
              f"total_like={total_like:.2f} frames={total_frames} "
              f"num_gauss={model.num_gauss()} "
              f"| times: realign={realign_time:.2f}s accum={accum_time:.2f}s "
              f"update={update_time:.2f}s mix={mix_time:.2f}s iter={iter_time:.2f}s",
              flush=True)
        # Per-iteration checkpoint (loadable later via --init) plus alignments.
        snapshot = {
            "weights": [state.weights for state in model.states],
            "means": [state.means for state in model.states],
            "variances": [state.vars for state in model.states],
        }
        np.save(args.out_dir / f"model_iter_{iteration}.npy",
                np.array([snapshot], dtype=object),
                allow_pickle=True)
        save_alignments(model, alignments, args.out_dir / f"alignments_iter_{iteration:02d}.txt")
    save_alignments(model, alignments, args.out_dir / "alignments.txt")
    save_model(model, args)
def main() -> None:
    """CLI entry point: parse arguments, create the experiment dir, train."""
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)
    train(args)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment