speech_to_text_streaming_infer_rnnt_timestamps_individual
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Streaming / buffered RNNT inference with fast word-level timestamps, including word text.
Modified from the NVIDIA script in the NeMo ASR examples:
https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py
Based on GitHub issue:
https://github.com/NVIDIA-NeMo/NeMo/issues/14714
Key points:
- Maintains streaming throughput (no switch to high-level transcribe()).
- Emits only "pred_text" and "word" timestamps (list of {start, end, word} in seconds).
- Derives word boundaries from tokenizer tokens:
- SentencePiece: tokens containing '▁' are word starts.
- WordPiece: tokens without '##' prefix are word starts.
- Fallback: treat each token as a word.
- Converts encoder-step offsets to seconds using the model's true stride (secs_per_step).
- Does NOT use normalize_timestamp_output() for synthesized word timings (avoids 0.1 s/step assumption).
Modified behavior:
- Always writes one pretty-printed JSON per input audio file, saved next to the audio,
using the same basename with a .json extension (e.g., sample.wav -> sample.json).
- No aggregate manifest (audio.json) is created.
- WER computation removed for simplicity.
"""
import copy
import glob
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Any, List, Tuple
import lightning.pytorch as pl
import torch
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
GreedyBatchedLabelLoopingComputerBase,
)
from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
from nemo.collections.asr.parts.utils.streaming_utils import (
AudioBatch,
ContextSize,
SimpleAudioDataset,
StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
setup_model,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
def make_divisible_by(num, factor: int) -> int:
return (num // factor) * factor
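# _to_serializable recursively converts torch tensors, numpy scalars/arrays, dicts and
# lists/tuples into plain Python objects so the result can be passed to json.dump.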
def _to_serializable(obj: Any) -> Any:
try:
import numpy as np
except Exception:
np = None
if torch.is_tensor(obj):
if obj.ndim == 0:
return obj.item()
return _to_serializable(obj.detach().cpu().tolist())
if np is not None and isinstance(obj, np.generic):
return obj.item()
if np is not None and isinstance(obj, np.ndarray):
return _to_serializable(obj.tolist())
if isinstance(obj, dict):
return {k: _to_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_to_serializable(v) for v in obj]
return obj
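# _as_list_of_ints normalizes a timestep container (torch tensor, numpy array, list/tuple
# or plain scalar) into a list of ints, returning None when conversion is not possible.
# For example, a 0-dim tensor such as torch.tensor(7) becomes [7].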
def _as_list_of_ints(x: Any) -> Optional[List[int]]:
if x is None:
return None
if torch.is_tensor(x):
x = x.detach().cpu()
if x.ndim == 0:
return [int(x.item())]
return [int(v) for v in x.tolist()]
try:
import numpy as np
except Exception:
np = None
if np is not None and isinstance(x, np.ndarray):
if x.ndim == 0:
return [int(x.item())]
return [int(v) for v in x.tolist()]
if isinstance(x, (list, tuple)):
out = []
for v in x:
try:
out.append(int(v))
except Exception:
return None
return out
try:
return [int(x)]
except Exception:
return None
def _group_tokens_into_words(tokens: List[str]) -> List[Tuple[int, int]]:
"""
Return list of (start_idx, end_idx_inclusive) token spans representing words.
Heuristics:
- SentencePiece: token containing '▁' starts a new word.
- WordPiece: token starting with '##' is a continuation; otherwise starts a new word.
- Fallback: each token is a word.
"""
n = len(tokens)
if n == 0:
return []
has_sp_marker = any('▁' in t for t in tokens)
has_wp_cont = any(t.startswith('##') for t in tokens)
spans: List[Tuple[int, int]] = []
if has_sp_marker:
start = 0
for i in range(1, n):
if '▁' in tokens[i]:
spans.append((start, i - 1))
start = i
spans.append((start, n - 1))
return spans
if has_wp_cont:
start = 0
for i in range(1, n):
if not tokens[i].startswith('##'):
spans.append((start, i - 1))
start = i
spans.append((start, n - 1))
return spans
# Fallback: every token is its own "word"
spans = [(i, i) for i in range(n)]
return spans
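# Illustrative behavior of the grouping heuristics above (token strings are made up):
#   SentencePiece: ['▁hel', 'lo', '▁world']      -> [(0, 1), (2, 2)]
#   WordPiece:     ['hel', '##lo', 'world']      -> [(0, 1), (2, 2)]
#   Fallback:      ['a', 'b', 'c'] (no markers)  -> [(0, 0), (1, 1), (2, 2)]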
def _tokens_to_word(tokens: List[str], start: int, end: int) -> str:
"""
Build a clean word string from token pieces in [start, end].
- SentencePiece: drop '▁' markers and concatenate.
- WordPiece: drop '##' prefixes and concatenate.
"""
span = tokens[start : end + 1]
if not span:
return ""
if any('▁' in t for t in span): # SentencePiece style
pieces = [t.replace('▁', '') for t in span]
return "".join(pieces)
# WordPiece style: remove leading '##'
pieces = [t[2:] if t.startswith('##') else t for t in span]
return "".join(pieces)
def _convert_word_offsets_list_to_seconds(
word_list: List[dict],
secs_per_step: float,
latency_comp: float = 0.0,
) -> List[dict]:
"""
Convert a provided 'word' list that might be in step units
(keys 'start_offset'/'end_offset') into seconds ('start'/'end').
If entries already have 'start'/'end' in seconds, preserve them.
"""
out: List[dict] = []
for w in word_list:
if "start" in w and "end" in w:
start_sec = float(w["start"])
end_sec = float(w["end"])
if latency_comp:
start_sec = max(0.0, start_sec - latency_comp)
end_sec = max(start_sec, end_sec - latency_comp)
out.append({"start": start_sec, "end": end_sec, "word": w.get("word", w.get("text", ""))})
elif "start_offset" in w and "end_offset" in w:
s = int(w["start_offset"])
e = int(w["end_offset"])
# inclusive -> exclusive end boundary (+1)
start_sec = s * secs_per_step
end_sec = (e + 1) * secs_per_step
if latency_comp:
start_sec = max(0.0, start_sec - latency_comp)
end_sec = max(start_sec, end_sec - latency_comp)
out.append({"start": float(start_sec), "end": float(end_sec), "word": w.get("word", w.get("text", ""))})
else:
# Unknown structure; best-effort pass-through
out.append({**w})
return out
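# Worked example (illustrative values): with secs_per_step = 0.08 and latency_comp = 0.0,
#   {'start_offset': 10, 'end_offset': 12, 'word': 'hello'}
# is converted to
#   {'start': 0.8, 'end': 1.04, 'word': 'hello'}  # end uses the exclusive boundary (12 + 1) * 0.08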
def write_transcription_custom(
transcriptions,
cfg,
model_name: str,
filepaths: Optional[List[str]] = None,
compute_langs: bool = False,
enable_timestamps: bool = False,
tokenizer=None,
secs_per_step: Optional[float] = None,
latency_comp: float = 0.0,
):
"""
Minimal, fast writer (per-file JSON):
- Writes pred_text and word timestamps ONLY.
- For timestamps:
- If hyp.timestamp contains 'word': write seconds directly if available;
if offsets, convert to seconds using secs_per_step.
- Else: synthesize words from tokenizer tokens + timestep list and convert to seconds.
- Always writes one pretty-printed JSON per audio file, next to the audio, with .json extension.
- Respects cfg.overwrite_transcripts per file.
"""
if cfg.append_pred:
logging.info('Per-file transcripts will include an appended prediction field name.')
pred_by_model_name = cfg.pred_name_postfix if cfg.pred_name_postfix is not None else model_name
pred_text_attr_name = 'pred_text_' + pred_by_model_name
else:
pred_text_attr_name = 'pred_text'
# Build pairs of (manifest_item, audio_fp) to align with transcriptions order
if cfg.audio_dir is not None:
assert filepaths is not None, "filepaths must be provided when using audio_dir"
pairs = [(None, fp) for fp in filepaths]
else:
pairs = []
with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
for line in fr:
line = line.strip()
if not line:
continue
pairs.append((json.loads(line), None))
# Determine structure of transcriptions
if isinstance(transcriptions[0], str):
best_hyps = transcriptions
return_hypotheses = False
beams = None
elif hasattr(transcriptions[0], "text"):
best_hyps = transcriptions
return_hypotheses = True
beams = None
elif isinstance(transcriptions[0], list) and hasattr(transcriptions[0][0], "text"):
best_hyps, beams = [], []
for hyps in transcriptions:
best_hyps.append(hyps[0])
if not cfg.decoding.beam.return_best_hypothesis:
beam = []
for hyp in hyps:
score = hyp.score.numpy().item() if isinstance(hyp.score, torch.Tensor) else hyp.score
beam.append((hyp.text, score))
beams.append(beam)
return_hypotheses = True
else:
raise TypeError("Unexpected transcription structure")
written_paths: List[str] = []
for idx, (manifest_item, audio_fp) in enumerate(pairs):
# Construct output JSON object for this item
if manifest_item is None:
# audio_dir mode
if not return_hypotheses:
item = {'audio_filepath': audio_fp, pred_text_attr_name: best_hyps[idx]}
else:
hyp = best_hyps[idx]
item = {'audio_filepath': audio_fp, pred_text_attr_name: hyp.text}
if enable_timestamps:
timestamp_data = getattr(hyp, "timestamp", None)
# Case 1: model already provides word-level dict
if isinstance(timestamp_data, dict) and "word" in timestamp_data:
word_list = timestamp_data["word"]
if secs_per_step is not None:
try:
item['word'] = _to_serializable(
_convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
)
except Exception as e:
logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
item['word'] = _to_serializable(word_list)
else:
item['word'] = _to_serializable(word_list)
else:
# Case 2: synthesize from tokens + timestep
token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
if tokenizer is not None and token_ids:
try:
tokens = tokenizer.ids_to_tokens(token_ids)
except Exception:
tokens = []
else:
tokens = []
timestep_list = None
if isinstance(timestamp_data, dict):
timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
else:
timestep_list = _as_list_of_ints(timestamp_data)
if tokens and timestep_list:
if len(timestep_list) != len(tokens):
m = min(len(timestep_list), len(tokens))
tokens = tokens[:m]
timestep_list = timestep_list[:m]
word_spans = _group_tokens_into_words(tokens)
word_entries = []
for (s, e) in word_spans:
if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
word_text = _tokens_to_word(tokens, s, e)
if word_text == "":
continue
if secs_per_step is None:
# fallback to raw offsets if stride unknown
word_entries.append({
"start_offset": int(timestep_list[s]),
"end_offset": int(timestep_list[e]),
"word": word_text
})
else:
start_sec = float(timestep_list[s]) * secs_per_step
end_sec = float(timestep_list[e] + 1) * secs_per_step # inclusive->exclusive
if latency_comp:
start_sec = max(0.0, start_sec - latency_comp)
end_sec = max(start_sec, end_sec - latency_comp)
word_entries.append({
"start": start_sec,
"end": end_sec,
"word": word_text
})
if word_entries:
item['word'] = _to_serializable(word_entries)
out_path = Path(audio_fp).with_suffix(".json")
else:
# dataset_manifest mode
if not return_hypotheses:
item = dict(manifest_item)
item[pred_text_attr_name] = best_hyps[idx]
else:
hyp = best_hyps[idx]
item = dict(manifest_item)
item[pred_text_attr_name] = hyp.text
if enable_timestamps:
timestamp_data = getattr(hyp, "timestamp", None)
if isinstance(timestamp_data, dict) and "word" in timestamp_data:
word_list = timestamp_data["word"]
if secs_per_step is not None:
try:
item['word'] = _to_serializable(
_convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
)
except Exception as e:
logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
item['word'] = _to_serializable(word_list)
else:
item['word'] = _to_serializable(word_list)
else:
token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
if tokenizer is not None and token_ids:
try:
tokens = tokenizer.ids_to_tokens(token_ids)
except Exception:
tokens = []
else:
tokens = []
timestep_list = None
if isinstance(timestamp_data, dict):
timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
else:
timestep_list = _as_list_of_ints(timestamp_data)
if tokens and timestep_list:
if len(timestep_list) != len(tokens):
m = min(len(timestep_list), len(tokens))
tokens = tokens[:m]
timestep_list = timestep_list[:m]
word_spans = _group_tokens_into_words(tokens)
word_entries = []
for (s, e) in word_spans:
if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
word_text = _tokens_to_word(tokens, s, e)
if word_text == "":
continue
if secs_per_step is None:
word_entries.append({
"start_offset": int(timestep_list[s]),
"end_offset": int(timestep_list[e]),
"word": word_text
})
else:
start_sec = float(timestep_list[s]) * secs_per_step
end_sec = float(timestep_list[e] + 1) * secs_per_step # inclusive->exclusive
if latency_comp:
start_sec = max(0.0, start_sec - latency_comp)
end_sec = max(start_sec, end_sec - latency_comp)
word_entries.append({
"start": start_sec,
"end": end_sec,
"word": word_text
})
if word_entries:
item['word'] = _to_serializable(word_entries)
out_path = Path(item["audio_filepath"]).with_suffix(".json")
# Ensure parent exists and write pretty JSON (indent=2)
out_path.parent.mkdir(parents=True, exist_ok=True)
if (not cfg.overwrite_transcripts) and out_path.exists():
logging.info(f"Skipping existing transcript (overwrite_transcripts=False): {out_path}")
else:
with open(out_path, 'w', encoding='utf-8', newline='\n') as fout:
json.dump(item, fout, ensure_ascii=False, indent=2)
fout.write("\n")
written_paths.append(str(out_path))
return written_paths, pred_text_attr_name
@dataclass
class TranscriptionConfig:
model_path: Optional[str] = None
pretrained_name: Optional[str] = None
audio_dir: Optional[str] = None
dataset_manifest: Optional[str] = None
# Output manifest fields removed; per-file JSON only
batch_size: int = 32
num_workers: int = 0
append_pred: bool = False
pred_name_postfix: Optional[str] = None
random_seed: Optional[int] = None
chunk_secs: float = 2
left_context_secs: float = 10.0
right_context_secs: float = 2
cuda: Optional[int] = None
allow_mps: bool = True
compute_dtype: Optional[str] = None
matmul_precision: str = "high"
audio_type: str = "wav"
overwrite_transcripts: bool = True
decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)
@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision(cfg.matmul_precision)
cfg = OmegaConf.structured(cfg)
if cfg.random_seed:
pl.seed_everything(cfg.random_seed)
if cfg.model_path is None and cfg.pretrained_name is None:
raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
if cfg.audio_dir is None and cfg.dataset_manifest is None:
raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
filepaths = None
manifest = cfg.dataset_manifest
if cfg.audio_dir is not None:
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
manifest = None
# setup device
if cfg.cuda is None:
if torch.cuda.is_available():
map_location = torch.device('cuda:0')
elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
logging.warning(
"MPS device support is experimental. Set PYTORCH_ENABLE_MPS_FALLBACK=1 to avoid failures."
)
map_location = torch.device('mps')
else:
map_location = torch.device('cpu')
elif cfg.cuda < 0:
map_location = torch.device('cpu')
else:
map_location = torch.device(f'cuda:{cfg.cuda}')
if cfg.compute_dtype is None:
can_use_bfloat16 = map_location.type == "cuda" and torch.cuda.is_bf16_supported()
compute_dtype = torch.bfloat16 if can_use_bfloat16 else torch.float32
else:
assert cfg.compute_dtype in {"float32", "bfloat16", "float16"}
compute_dtype = getattr(torch, cfg.compute_dtype)
logging.info(f"Inference device: {map_location}, dtype: {compute_dtype}")
asr_model, model_name = setup_model(cfg, map_location)
model_cfg = copy.deepcopy(asr_model._cfg)
OmegaConf.set_struct(model_cfg.preprocessor, False)
model_cfg.preprocessor.dither = 0.0
model_cfg.preprocessor.pad_to = 0
if model_cfg.preprocessor.normalize != "per_feature":
logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")
OmegaConf.set_struct(model_cfg.preprocessor, True)
# No aggregate output file is computed or used.
asr_model.freeze()
asr_model = asr_model.to(asr_model.device)
asr_model.to(compute_dtype)
with open_dict(cfg.decoding):
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
raise NotImplementedError("This script supports only `greedy_batch` with Label-Looping algorithm")
cfg.decoding.preserve_alignments = True
cfg.decoding.fused_batch_size = -1
cfg.decoding.beam.return_best_hypothesis = True
# Intentionally NOT setting compute_timestamps=True; we synthesize word times ourselves.
if hasattr(asr_model, 'change_decoding_strategy'):
if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
raise ValueError("This script supports RNNT model and hybrid RNNT-CTC model with RNNT decoder")
if isinstance(asr_model, EncDecRNNTModel):
asr_model.change_decoding_strategy(cfg.decoding)
if hasattr(asr_model, 'cur_decoder'):
asr_model.change_decoding_strategy(cfg.decoding, decoder_type='rnnt')
if manifest is not None:
records = read_manifest(manifest)
manifest_dir = Path(manifest).parent.absolute()
for record in records:
record["audio_filepath"] = str(filepath_to_absolute(record["audio_filepath"], manifest_dir))
else:
assert filepaths is not None
records = [{"audio_filepath": audio_file} for audio_file in filepaths]
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
audio_sample_rate = model_cfg.preprocessor['sample_rate']
feature_stride_sec = model_cfg.preprocessor['window_stride']
features_per_sec = 1.0 / feature_stride_sec
encoder_subsampling_factor = asr_model.encoder.subsampling_factor
features_frame2audio_samples = make_divisible_by(
int(audio_sample_rate * feature_stride_sec), factor=encoder_subsampling_factor
)
encoder_frame2audio_samples = features_frame2audio_samples * encoder_subsampling_factor
# Accurate seconds per encoder step (accounts for divisibility correction)
secs_per_step = encoder_frame2audio_samples / audio_sample_rate
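# Worked example (assumed, typical 16 kHz FastConformer values: sample_rate = 16000,
# window_stride = 0.01 s, subsampling factor = 8):
#   features_frame2audio_samples = make_divisible_by(int(16000 * 0.01), 8) = 160
#   encoder_frame2audio_samples  = 160 * 8 = 1280
#   secs_per_step                = 1280 / 16000 = 0.08 s per encoder step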
# Optional UI latency compensation (seconds). Set to cfg.right_context_secs if desired.
latency_comp = 0.0
context_encoder_frames = ContextSize(
left=int(cfg.left_context_secs * features_per_sec / encoder_subsampling_factor),
chunk=int(cfg.chunk_secs * features_per_sec / encoder_subsampling_factor),
right=int(cfg.right_context_secs * features_per_sec / encoder_subsampling_factor),
)
context_samples = ContextSize(
left=context_encoder_frames.left * encoder_subsampling_factor * features_frame2audio_samples,
chunk=context_encoder_frames.chunk * encoder_subsampling_factor * features_frame2audio_samples,
right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples,
)
logging.info(
"Corrected contexts (sec): "
f"Left {context_samples.left / audio_sample_rate:.2f}, "
f"Chunk {context_samples.chunk / audio_sample_rate:.2f}, "
f"Right {context_samples.right / audio_sample_rate:.2f}"
)
logging.info(
f"Corrected contexts (subsampled encoder frames): Left {context_encoder_frames.left} - "
f"Chunk {context_encoder_frames.chunk} - Right {context_encoder_frames.right}"
)
logging.info(
f"Corrected contexts (in audio samples): Left {context_samples.left} - "
f"Chunk {context_samples.chunk} - Right {context_samples.right}"
)
latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")
logging.info(f"secs_per_step (encoder): {secs_per_step:.6f} s")
audio_dataset = SimpleAudioDataset(
audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
)
audio_dataloader = DataLoader(
dataset=audio_dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
collate_fn=AudioBatch.collate_fn,
drop_last=False,
in_order=True,
)
with torch.no_grad(), torch.inference_mode():
all_hyps = []
for audio_data in tqdm(audio_dataloader):
audio_batch = audio_data.audio_signals.to(device=map_location)
audio_batch_lengths = audio_data.audio_signal_lengths.to(device=map_location)
batch_size = audio_batch.shape[0]
device = audio_batch.device
current_batched_hyps: BatchedHyps | None = None
state = None
left_sample = 0
right_sample = min(context_samples.chunk + context_samples.right, audio_batch.shape[1])
buffer = StreamingBatchedAudioBuffer(
batch_size=batch_size,
context_samples=context_samples,
dtype=audio_batch.dtype,
device=device,
)
rest_audio_lengths = audio_batch_lengths.clone()
while left_sample < audio_batch.shape[1]:
chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample
is_last_chunk_batch = chunk_length >= rest_audio_lengths
is_last_chunk = right_sample >= audio_batch.shape[1]
chunk_lengths_batch = torch.where(
is_last_chunk_batch,
rest_audio_lengths,
torch.full_like(rest_audio_lengths, fill_value=chunk_length),
)
buffer.add_audio_batch_(
audio_batch[:, left_sample:right_sample],
audio_lengths=chunk_lengths_batch,
is_last_chunk=is_last_chunk,
is_last_chunk_batch=is_last_chunk_batch,
)
encoder_output, encoder_output_len = asr_model(
input_signal=buffer.samples,
input_signal_length=buffer.context_size_batch.total(),
)
encoder_output = encoder_output.transpose(1, 2) # [B, T, C]
encoder_context = buffer.context_size.subsample(factor=encoder_frame2audio_samples)
encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
encoder_output = encoder_output[:, encoder_context.left :]
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output,
out_len=encoder_context_batch.chunk,
prev_batched_state=state,
)
if current_batched_hyps is None:
current_batched_hyps = chunk_batched_hyps
else:
current_batched_hyps.merge_(chunk_batched_hyps)
rest_audio_lengths -= chunk_lengths_batch
left_sample = right_sample
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1])
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
for hyp in all_hyps:
hyp.text = asr_model.tokenizer.ids_to_text(hyp.y_sequence.tolist())
written_paths, pred_text_attr_name = write_transcription_custom(
all_hyps,
cfg,
model_name,
filepaths=filepaths,
compute_langs=False,
enable_timestamps=True,
tokenizer=asr_model.tokenizer,
secs_per_step=secs_per_step, # critical: convert step offsets -> seconds
latency_comp=latency_comp, # optional UI compensation (default 0.0)
)
logging.info(f"Wrote {len(written_paths)} transcript files.")
if written_paths:
# Log up to first 5 paths for reference
preview = "\n - ".join(written_paths[:5])
logging.info(f"Sample outputs (up to 5):\n - {preview}")
return cfg
if __name__ == '__main__':
main() # noqa pylint: disable=no-value-for-parameter