Forked from stephenmcconnachie/speech_to_text_streaming_infer_rnnt_timestamps_individual.py
speech_to_text_streaming_infer_rnnt_timestamps_individual
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Streaming / buffered RNNT inference with fast word-level timestamps, including word text.

Modified from the NVIDIA script in the NeMo ASR examples:
https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py

Based on issue:
https://github.com/NVIDIA-NeMo/NeMo/issues/14714

Key points:
- Maintains streaming throughput (no switch to the high-level transcribe()).
- Emits only "pred_text" and "word" timestamps (a list of {start, end, word} in seconds).
- Derives word boundaries from tokenizer tokens:
  - SentencePiece: tokens containing '▁' are word starts.
  - WordPiece: tokens without a '##' prefix are word starts.
  - Fallback: treat each token as a word.
- Converts encoder-step offsets to seconds using the model's true stride (secs_per_step).
- Does NOT use normalize_timestamp_output() for synthesized word timings (avoids the 0.1 s/step assumption).

Modified behavior:
- Always writes one pretty-printed JSON per input audio file, saved next to the audio,
  using the same basename with a .json extension (e.g., sample.wav -> sample.json).
- No aggregate manifest (audio.json) is created.
- WER computation removed for simplicity.
"""
import copy
import glob
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Any, List, Tuple

import lightning.pytorch as pl
import torch
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
    GreedyBatchedLabelLoopingComputerBase,
)
from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
from nemo.collections.asr.parts.utils.streaming_utils import (
    AudioBatch,
    ContextSize,
    SimpleAudioDataset,
    StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    setup_model,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging


def make_divisible_by(num, factor: int) -> int:
    return (num // factor) * factor


def _to_serializable(obj: Any) -> Any:
    try:
        import numpy as np
    except Exception:
        np = None
    if torch.is_tensor(obj):
        if obj.ndim == 0:
            return obj.item()
        return _to_serializable(obj.detach().cpu().tolist())
    if np is not None and isinstance(obj, np.generic):
        return obj.item()
    if np is not None and isinstance(obj, np.ndarray):
        return _to_serializable(obj.tolist())
    if isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_serializable(v) for v in obj]
    return obj


def _as_list_of_ints(x: Any) -> Optional[List[int]]:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.detach().cpu()
        if x.ndim == 0:
            return [int(x.item())]
        return [int(v) for v in x.tolist()]
    try:
        import numpy as np
    except Exception:
        np = None
    if np is not None and isinstance(x, np.ndarray):
        if x.ndim == 0:
            return [int(x.item())]
        return [int(v) for v in x.tolist()]
    if isinstance(x, (list, tuple)):
        out = []
        for v in x:
            try:
                out.append(int(v))
            except Exception:
                return None
        return out
    try:
        return [int(x)]
    except Exception:
        return None


def _group_tokens_into_words(tokens: List[str]) -> List[Tuple[int, int]]:
    """
    Return a list of (start_idx, end_idx_inclusive) token spans representing words.
    Heuristics:
    - SentencePiece: a token containing '▁' starts a new word.
    - WordPiece: a token starting with '##' is a continuation; otherwise it starts a new word.
    - Fallback: each token is a word.
    """
    n = len(tokens)
    if n == 0:
        return []
    has_sp_marker = any('▁' in t for t in tokens)
    has_wp_cont = any(t.startswith('##') for t in tokens)
    spans: List[Tuple[int, int]] = []
    if has_sp_marker:
        start = 0
        for i in range(1, n):
            if '▁' in tokens[i]:
                spans.append((start, i - 1))
                start = i
        spans.append((start, n - 1))
        return spans
    if has_wp_cont:
        start = 0
        for i in range(1, n):
            if not tokens[i].startswith('##'):
                spans.append((start, i - 1))
                start = i
        spans.append((start, n - 1))
        return spans
    # Fallback: every token is its own "word"
    spans = [(i, i) for i in range(n)]
    return spans
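
# Illustrative behavior of the grouping heuristic above (assumed token sequences, not real model output):
#   SentencePiece: ['▁he', 'llo', '▁world']      -> [(0, 1), (2, 2)]
#   WordPiece:     ['hel', '##lo', 'world']      -> [(0, 1), (2, 2)]
#   Fallback:      ['a', 'b', 'c'] (no markers)  -> [(0, 0), (1, 1), (2, 2)]
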

def _tokens_to_word(tokens: List[str], start: int, end: int) -> str:
    """
    Build a clean word string from the token pieces in [start, end].
    - SentencePiece: drop '▁' markers and concatenate.
    - WordPiece: drop '##' prefixes and concatenate.
    """
    span = tokens[start : end + 1]
    if not span:
        return ""
    if any('▁' in t for t in span):  # SentencePiece style
        pieces = [t.replace('▁', '') for t in span]
        return "".join(pieces)
    # WordPiece style: remove leading '##'
    pieces = [t[2:] if t.startswith('##') else t for t in span]
    return "".join(pieces)

def _convert_word_offsets_list_to_seconds(
    word_list: List[dict],
    secs_per_step: float,
    latency_comp: float = 0.0,
) -> List[dict]:
    """
    Convert a provided 'word' list that might be in step units
    (keys 'start_offset'/'end_offset') into seconds ('start'/'end').
    If entries already have 'start'/'end' in seconds, preserve them.
    """
    out: List[dict] = []
    for w in word_list:
        if "start" in w and "end" in w:
            start_sec = float(w["start"])
            end_sec = float(w["end"])
            if latency_comp:
                start_sec = max(0.0, start_sec - latency_comp)
                end_sec = max(start_sec, end_sec - latency_comp)
            out.append({"start": start_sec, "end": end_sec, "word": w.get("word", w.get("text", ""))})
        elif "start_offset" in w and "end_offset" in w:
            s = int(w["start_offset"])
            e = int(w["end_offset"])
            # inclusive -> exclusive end boundary (+1)
            start_sec = s * secs_per_step
            end_sec = (e + 1) * secs_per_step
            if latency_comp:
                start_sec = max(0.0, start_sec - latency_comp)
                end_sec = max(start_sec, end_sec - latency_comp)
            out.append({"start": float(start_sec), "end": float(end_sec), "word": w.get("word", w.get("text", ""))})
        else:
            # Unknown structure; best-effort pass-through
            out.append({**w})
    return out
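
# Worked example (assumed values): with secs_per_step = 0.08 and latency_comp = 0.0, the entry
#   {'start_offset': 10, 'end_offset': 14, 'word': 'hello'}
# becomes
#   {'start': 0.8, 'end': 1.2, 'word': 'hello'}   # end uses the exclusive boundary: (14 + 1) * 0.08
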

def write_transcription_custom(
    transcriptions,
    cfg,
    model_name: str,
    filepaths: Optional[List[str]] = None,
    compute_langs: bool = False,
    enable_timestamps: bool = False,
    tokenizer=None,
    secs_per_step: Optional[float] = None,
    latency_comp: float = 0.0,
):
    """
    Minimal, fast writer (per-file JSON):
    - Writes pred_text and word timestamps ONLY.
    - For timestamps:
      - If hyp.timestamp contains 'word': write seconds directly if available;
        if offsets, convert to seconds using secs_per_step.
      - Else: synthesize words from tokenizer tokens + timestep list and convert to seconds.
    - Always writes one pretty-printed JSON per audio file, next to the audio, with a .json extension.
    - Respects cfg.overwrite_transcripts per file.
    """
    if cfg.append_pred:
        logging.info('Per-file transcripts will include an appended prediction field name.')
        pred_by_model_name = cfg.pred_name_postfix if cfg.pred_name_postfix is not None else model_name
        pred_text_attr_name = 'pred_text_' + pred_by_model_name
    else:
        pred_text_attr_name = 'pred_text'

    # Build pairs of (manifest_item, audio_fp) to align with transcriptions order
    if cfg.audio_dir is not None:
        assert filepaths is not None, "filepaths must be provided when using audio_dir"
        pairs = [(None, fp) for fp in filepaths]
    else:
        pairs = []
        with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    continue
                pairs.append((json.loads(line), None))

    # Determine structure of transcriptions
    if isinstance(transcriptions[0], str):
        best_hyps = transcriptions
        return_hypotheses = False
        beams = None
    elif hasattr(transcriptions[0], "text"):
        best_hyps = transcriptions
        return_hypotheses = True
        beams = None
    elif isinstance(transcriptions[0], list) and hasattr(transcriptions[0][0], "text"):
        best_hyps, beams = [], []
        for hyps in transcriptions:
            best_hyps.append(hyps[0])
            if not cfg.decoding.beam.return_best_hypothesis:
                beam = []
                for hyp in hyps:
                    score = hyp.score.numpy().item() if isinstance(hyp.score, torch.Tensor) else hyp.score
                    beam.append((hyp.text, score))
                beams.append(beam)
        return_hypotheses = True
    else:
        raise TypeError("Unexpected transcription structure")

    written_paths: List[str] = []
    for idx, (manifest_item, audio_fp) in enumerate(pairs):
        # Construct output JSON object for this item
        if manifest_item is None:
            # audio_dir mode
            if not return_hypotheses:
                item = {'audio_filepath': audio_fp, pred_text_attr_name: best_hyps[idx]}
            else:
                hyp = best_hyps[idx]
                item = {'audio_filepath': audio_fp, pred_text_attr_name: hyp.text}
                if enable_timestamps:
                    timestamp_data = getattr(hyp, "timestamp", None)
                    # Case 1: model already provides word-level dict
                    if isinstance(timestamp_data, dict) and "word" in timestamp_data:
                        word_list = timestamp_data["word"]
                        if secs_per_step is not None:
                            try:
                                item['word'] = _to_serializable(
                                    _convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
                                )
                            except Exception as e:
                                logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
                                item['word'] = _to_serializable(word_list)
                        else:
                            item['word'] = _to_serializable(word_list)
                    else:
                        # Case 2: synthesize from tokens + timestep
                        token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
                        if tokenizer is not None and token_ids:
                            try:
                                tokens = tokenizer.ids_to_tokens(token_ids)
                            except Exception:
                                tokens = []
                        else:
                            tokens = []
                        timestep_list = None
                        if isinstance(timestamp_data, dict):
                            timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
                        else:
                            timestep_list = _as_list_of_ints(timestamp_data)
                        if tokens and timestep_list:
                            if len(timestep_list) != len(tokens):
                                m = min(len(timestep_list), len(tokens))
                                tokens = tokens[:m]
                                timestep_list = timestep_list[:m]
                            word_spans = _group_tokens_into_words(tokens)
                            word_entries = []
                            for (s, e) in word_spans:
                                if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
                                    word_text = _tokens_to_word(tokens, s, e)
                                    if word_text == "":
                                        continue
                                    if secs_per_step is None:
                                        # fallback to raw offsets if stride unknown
                                        word_entries.append({
                                            "start_offset": int(timestep_list[s]),
                                            "end_offset": int(timestep_list[e]),
                                            "word": word_text
                                        })
                                    else:
                                        start_sec = float(timestep_list[s]) * secs_per_step
                                        end_sec = float(timestep_list[e] + 1) * secs_per_step  # inclusive -> exclusive
                                        if latency_comp:
                                            start_sec = max(0.0, start_sec - latency_comp)
                                            end_sec = max(start_sec, end_sec - latency_comp)
                                        word_entries.append({
                                            "start": start_sec,
                                            "end": end_sec,
                                            "word": word_text
                                        })
                            if word_entries:
                                item['word'] = _to_serializable(word_entries)
            out_path = Path(audio_fp).with_suffix(".json")
        else:
            # dataset_manifest mode
            if not return_hypotheses:
                item = dict(manifest_item)
                item[pred_text_attr_name] = best_hyps[idx]
            else:
                hyp = best_hyps[idx]
                item = dict(manifest_item)
                item[pred_text_attr_name] = hyp.text
                if enable_timestamps:
                    timestamp_data = getattr(hyp, "timestamp", None)
                    if isinstance(timestamp_data, dict) and "word" in timestamp_data:
                        word_list = timestamp_data["word"]
                        if secs_per_step is not None:
                            try:
                                item['word'] = _to_serializable(
                                    _convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
                                )
                            except Exception as e:
                                logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
                                item['word'] = _to_serializable(word_list)
                        else:
                            item['word'] = _to_serializable(word_list)
                    else:
                        token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
                        if tokenizer is not None and token_ids:
                            try:
                                tokens = tokenizer.ids_to_tokens(token_ids)
                            except Exception:
                                tokens = []
                        else:
                            tokens = []
                        timestep_list = None
                        if isinstance(timestamp_data, dict):
                            timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
                        else:
                            timestep_list = _as_list_of_ints(timestamp_data)
                        if tokens and timestep_list:
                            if len(timestep_list) != len(tokens):
                                m = min(len(timestep_list), len(tokens))
                                tokens = tokens[:m]
                                timestep_list = timestep_list[:m]
                            word_spans = _group_tokens_into_words(tokens)
                            word_entries = []
                            for (s, e) in word_spans:
                                if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
                                    word_text = _tokens_to_word(tokens, s, e)
                                    if word_text == "":
                                        continue
                                    if secs_per_step is None:
                                        word_entries.append({
                                            "start_offset": int(timestep_list[s]),
                                            "end_offset": int(timestep_list[e]),
                                            "word": word_text
                                        })
                                    else:
                                        start_sec = float(timestep_list[s]) * secs_per_step
                                        end_sec = float(timestep_list[e] + 1) * secs_per_step  # inclusive -> exclusive
                                        if latency_comp:
                                            start_sec = max(0.0, start_sec - latency_comp)
                                            end_sec = max(start_sec, end_sec - latency_comp)
                                        word_entries.append({
                                            "start": start_sec,
                                            "end": end_sec,
                                            "word": word_text
                                        })
                            if word_entries:
                                item['word'] = _to_serializable(word_entries)
            out_path = Path(item["audio_filepath"]).with_suffix(".json")

        # Ensure parent exists and write pretty JSON (indent=2)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if (not cfg.overwrite_transcripts) and out_path.exists():
            logging.info(f"Skipping existing transcript (overwrite_transcripts=False): {out_path}")
        else:
            with open(out_path, 'w', encoding='utf-8', newline='\n') as fout:
                json.dump(item, fout, ensure_ascii=False, indent=2)
                fout.write("\n")
            written_paths.append(str(out_path))

    return written_paths, pred_text_attr_name
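
# Shape of a per-file transcript written by write_transcription_custom (illustrative values):
#   sample.json
#   {
#     "audio_filepath": "/data/sample.wav",
#     "pred_text": "hello world",
#     "word": [
#       {"start": 0.8, "end": 1.2, "word": "hello"},
#       {"start": 1.28, "end": 1.6, "word": "world"}
#     ]
#   }
# In dataset_manifest mode the original manifest fields are preserved alongside these keys; with
# append_pred=True the prediction key becomes pred_text_<pred_name_postfix or model_name>.
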

@dataclass
class TranscriptionConfig:
    model_path: Optional[str] = None
    pretrained_name: Optional[str] = None
    audio_dir: Optional[str] = None
    dataset_manifest: Optional[str] = None
    # Output manifest fields removed; per-file JSON only
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False
    pred_name_postfix: Optional[str] = None
    random_seed: Optional[int] = None
    chunk_secs: float = 2.0
    left_context_secs: float = 10.0
    right_context_secs: float = 2.0
    cuda: Optional[int] = None
    allow_mps: bool = True
    compute_dtype: Optional[str] = None
    matmul_precision: str = "high"
    audio_type: str = "wav"
    overwrite_transcripts: bool = True
    decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    cfg = OmegaConf.structured(cfg)

    if cfg.random_seed:
        pl.seed_everything(cfg.random_seed)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None

    # setup device
    if cfg.cuda is None:
        if torch.cuda.is_available():
            map_location = torch.device('cuda:0')
        elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logging.warning(
                "MPS device support is experimental. Set PYTORCH_ENABLE_MPS_FALLBACK=1 to avoid failures."
            )
            map_location = torch.device('mps')
        else:
            map_location = torch.device('cpu')
    elif cfg.cuda < 0:
        map_location = torch.device('cpu')
    else:
        map_location = torch.device(f'cuda:{cfg.cuda}')

    if cfg.compute_dtype is None:
        can_use_bfloat16 = map_location.type == "cuda" and torch.cuda.is_bf16_supported()
        compute_dtype = torch.bfloat16 if can_use_bfloat16 else torch.float32
    else:
        assert cfg.compute_dtype in {"float32", "bfloat16", "float16"}
        compute_dtype = getattr(torch, cfg.compute_dtype)
    logging.info(f"Inference device: {map_location}, dtype: {compute_dtype}")

    asr_model, model_name = setup_model(cfg, map_location)
    model_cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(model_cfg.preprocessor, False)
    model_cfg.preprocessor.dither = 0.0
    model_cfg.preprocessor.pad_to = 0
    if model_cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")
    OmegaConf.set_struct(model_cfg.preprocessor, True)

    # No aggregate output file is computed or used.
    asr_model.freeze()
    asr_model = asr_model.to(asr_model.device)
    asr_model.to(compute_dtype)

    with open_dict(cfg.decoding):
        if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
            raise NotImplementedError("This script supports only `greedy_batch` with the Label-Looping algorithm")
        cfg.decoding.preserve_alignments = True
        cfg.decoding.fused_batch_size = -1
        cfg.decoding.beam.return_best_hypothesis = True
        # Intentionally NOT setting compute_timestamps=True; we synthesize word times ourselves.

    if hasattr(asr_model, 'change_decoding_strategy'):
        if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
            raise ValueError("This script supports RNNT models and hybrid RNNT-CTC models with an RNNT decoder")
        if isinstance(asr_model, EncDecRNNTModel):
            asr_model.change_decoding_strategy(cfg.decoding)
        if hasattr(asr_model, 'cur_decoder'):
            asr_model.change_decoding_strategy(cfg.decoding, decoder_type='rnnt')

    if manifest is not None:
        records = read_manifest(manifest)
        manifest_dir = Path(manifest).parent.absolute()
        for record in records:
            record["audio_filepath"] = str(filepath_to_absolute(record["audio_filepath"], manifest_dir))
    else:
        assert filepaths is not None
        records = [{"audio_filepath": audio_file} for audio_file in filepaths]

    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    asr_model.eval()

    decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer

    audio_sample_rate = model_cfg.preprocessor['sample_rate']
    feature_stride_sec = model_cfg.preprocessor['window_stride']
    features_per_sec = 1.0 / feature_stride_sec
    encoder_subsampling_factor = asr_model.encoder.subsampling_factor
    features_frame2audio_samples = make_divisible_by(
        int(audio_sample_rate * feature_stride_sec), factor=encoder_subsampling_factor
    )
    encoder_frame2audio_samples = features_frame2audio_samples * encoder_subsampling_factor
    # Accurate seconds per encoder step (accounts for divisibility correction)
    secs_per_step = encoder_frame2audio_samples / audio_sample_rate
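
    # Worked example (assumed, typical FastConformer-style preprocessor): sample_rate = 16000,
    # window_stride = 0.01 s, subsampling_factor = 8 give
    #   features_frame2audio_samples = make_divisible_by(160, 8) = 160
    #   encoder_frame2audio_samples  = 160 * 8 = 1280
    #   secs_per_step                = 1280 / 16000 = 0.08 s per encoder step
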
    # Optional UI latency compensation (seconds). Set to cfg.right_context_secs if desired.
    latency_comp = 0.0

    context_encoder_frames = ContextSize(
        left=int(cfg.left_context_secs * features_per_sec / encoder_subsampling_factor),
        chunk=int(cfg.chunk_secs * features_per_sec / encoder_subsampling_factor),
        right=int(cfg.right_context_secs * features_per_sec / encoder_subsampling_factor),
    )
    context_samples = ContextSize(
        left=context_encoder_frames.left * encoder_subsampling_factor * features_frame2audio_samples,
        chunk=context_encoder_frames.chunk * encoder_subsampling_factor * features_frame2audio_samples,
        right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples,
    )
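
    # Worked example (assumed defaults: left 10 s, chunk 2 s, right 2 s, features_per_sec = 100,
    # subsampling_factor = 8, features_frame2audio_samples = 160):
    #   context_encoder_frames = ContextSize(left=125, chunk=25, right=25)
    #   context_samples        = ContextSize(left=160000, chunk=32000, right=32000)
    #   theoretical latency    = (32000 + 32000) / 16000 = 4.0 s (chunk + right context)
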
    logging.info(
        "Corrected contexts (sec): "
        f"Left {context_samples.left / audio_sample_rate:.2f}, "
        f"Chunk {context_samples.chunk / audio_sample_rate:.2f}, "
        f"Right {context_samples.right / audio_sample_rate:.2f}"
    )
    logging.info(
        f"Corrected contexts (subsampled encoder frames): Left {context_encoder_frames.left} - "
        f"Chunk {context_encoder_frames.chunk} - Right {context_encoder_frames.right}"
    )
    logging.info(
        f"Corrected contexts (in audio samples): Left {context_samples.left} - "
        f"Chunk {context_samples.chunk} - Right {context_samples.right}"
    )
    latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
    logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")
    logging.info(f"secs_per_step (encoder): {secs_per_step:.6f} s")

    audio_dataset = SimpleAudioDataset(
        audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
    )
    audio_dataloader = DataLoader(
        dataset=audio_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        collate_fn=AudioBatch.collate_fn,
        drop_last=False,
        in_order=True,
    )

    with torch.no_grad(), torch.inference_mode():
        all_hyps = []
        for audio_data in tqdm(audio_dataloader):
            audio_batch = audio_data.audio_signals.to(device=map_location)
            audio_batch_lengths = audio_data.audio_signal_lengths.to(device=map_location)
            batch_size = audio_batch.shape[0]
            device = audio_batch.device
            current_batched_hyps: BatchedHyps | None = None
            state = None
            left_sample = 0
            right_sample = min(context_samples.chunk + context_samples.right, audio_batch.shape[1])
            buffer = StreamingBatchedAudioBuffer(
                batch_size=batch_size,
                context_samples=context_samples,
                dtype=audio_batch.dtype,
                device=device,
            )
            rest_audio_lengths = audio_batch_lengths.clone()
            while left_sample < audio_batch.shape[1]:
                chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample
                is_last_chunk_batch = chunk_length >= rest_audio_lengths
                is_last_chunk = right_sample >= audio_batch.shape[1]
                chunk_lengths_batch = torch.where(
                    is_last_chunk_batch,
                    rest_audio_lengths,
                    torch.full_like(rest_audio_lengths, fill_value=chunk_length),
                )
                buffer.add_audio_batch_(
                    audio_batch[:, left_sample:right_sample],
                    audio_lengths=chunk_lengths_batch,
                    is_last_chunk=is_last_chunk,
                    is_last_chunk_batch=is_last_chunk_batch,
                )
                encoder_output, encoder_output_len = asr_model(
                    input_signal=buffer.samples,
                    input_signal_length=buffer.context_size_batch.total(),
                )
                encoder_output = encoder_output.transpose(1, 2)  # [B, T, C]
                encoder_context = buffer.context_size.subsample(factor=encoder_frame2audio_samples)
                encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
                encoder_output = encoder_output[:, encoder_context.left :]
                chunk_batched_hyps, _, state = decoding_computer(
                    x=encoder_output,
                    out_len=encoder_context_batch.chunk,
                    prev_batched_state=state,
                )
                if current_batched_hyps is None:
                    current_batched_hyps = chunk_batched_hyps
                else:
                    current_batched_hyps.merge_(chunk_batched_hyps)
                rest_audio_lengths -= chunk_lengths_batch
                left_sample = right_sample
                right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1])
            all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))

    for hyp in all_hyps:
        hyp.text = asr_model.tokenizer.ids_to_text(hyp.y_sequence.tolist())

    written_paths, pred_text_attr_name = write_transcription_custom(
        all_hyps,
        cfg,
        model_name,
        filepaths=filepaths,
        compute_langs=False,
        enable_timestamps=True,
        tokenizer=asr_model.tokenizer,
        secs_per_step=secs_per_step,  # critical: convert step offsets -> seconds
        latency_comp=latency_comp,  # optional UI compensation (default 0.0)
    )
    logging.info(f"Wrote {len(written_paths)} transcript files.")
    if written_paths:
        # Log up to the first 5 paths for reference
        preview = "\n - ".join(written_paths[:5])
        logging.info(f"Sample outputs (up to 5):\n - {preview}")
    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter