Revisions

  1. @stephenmcconnachie stephenmcconnachie revised this gist Sep 23, 2025. 1 changed file with 2 additions and 2 deletions.

     4 changes: 2 additions & 2 deletions speech_to_text_streaming_infer_rnnt_timestamps_individual.py

     @@ -234,9 +234,9 @@ def write_transcription_custom(
          Minimal, fast writer (per-file JSON):
          - Writes pred_text and word timestamps ONLY.
          - For timestamps:
     -        * If hyp.timestamp contains 'word': write seconds directly if available;
     +        - If hyp.timestamp contains 'word': write seconds directly if available;
               if offsets, convert to seconds using secs_per_step.
     -        * Else: synthesize words from tokenizer tokens + timestep list and convert to seconds.
     +        - Else: synthesize words from tokenizer tokens + timestep list and convert to seconds.
          - Always writes one pretty-printed JSON per audio file, next to the audio, with .json extension.
          - Respects cfg.overwrite_transcripts per file.
          """
  2. @stephenmcconnachie stephenmcconnachie created this gist Sep 23, 2025.

     723 changes: 723 additions & 0 deletions speech_to_text_streaming_infer_rnnt_timestamps_individual.py

     @@ -0,0 +1,723 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Streaming / buffered RNNT inference with fast word-level timestamps including word text.

Modified from the NVIDIA script in the NeMo ASR examples:
https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py
Based on this issue:
https://github.com/NVIDIA-NeMo/NeMo/issues/14714

Key points:
- Maintains streaming throughput (no switch to the high-level transcribe()).
- Emits only "pred_text" and "word" timestamps (a list of {start, end, word} in seconds).
- Derives word boundaries from tokenizer tokens:
  - SentencePiece: tokens containing '▁' are word starts.
  - WordPiece: tokens without a '##' prefix are word starts.
  - Fallback: treat each token as a word.
- Converts encoder-step offsets to seconds using the model's true stride (secs_per_step).
- Does NOT use normalize_timestamp_output() for synthesized word timings (avoids the 0.1 s/step assumption).

Modified behavior:
- Always writes one pretty-printed JSON per input audio file, saved next to the audio,
  using the same basename with a .json extension (e.g., sample.wav -> sample.json).
- No aggregate manifest (audio.json) is created.
- WER computation removed for simplicity.
"""

import copy
import glob
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Any, List, Tuple

import lightning.pytorch as pl
import torch
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
    GreedyBatchedLabelLoopingComputerBase,
)
from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
from nemo.collections.asr.parts.utils.streaming_utils import (
    AudioBatch,
    ContextSize,
    SimpleAudioDataset,
    StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    setup_model,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging


def make_divisible_by(num, factor: int) -> int:
    return (num // factor) * factor


def _to_serializable(obj: Any) -> Any:
    try:
        import numpy as np
    except Exception:
        np = None

    if torch.is_tensor(obj):
        if obj.ndim == 0:
            return obj.item()
        return _to_serializable(obj.detach().cpu().tolist())
    if np is not None and isinstance(obj, np.generic):
        return obj.item()
    if np is not None and isinstance(obj, np.ndarray):
        return _to_serializable(obj.tolist())
    if isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_serializable(v) for v in obj]
    return obj


def _as_list_of_ints(x: Any) -> Optional[List[int]]:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.detach().cpu()
        if x.ndim == 0:
            return [int(x.item())]
        return [int(v) for v in x.tolist()]
    try:
        import numpy as np
    except Exception:
        np = None
    if np is not None and isinstance(x, np.ndarray):
        if x.ndim == 0:
            return [int(x.item())]
        return [int(v) for v in x.tolist()]
    if isinstance(x, (list, tuple)):
        out = []
        for v in x:
            try:
                out.append(int(v))
            except Exception:
                return None
        return out
    try:
        return [int(x)]
    except Exception:
        return None


def _group_tokens_into_words(tokens: List[str]) -> List[Tuple[int, int]]:
    """
    Return a list of (start_idx, end_idx_inclusive) token spans representing words.
    Heuristics:
    - SentencePiece: a token containing '▁' starts a new word.
    - WordPiece: a token starting with '##' is a continuation; otherwise it starts a new word.
    - Fallback: each token is a word.
    """
    n = len(tokens)
    if n == 0:
        return []

    has_sp_marker = any('▁' in t for t in tokens)
    has_wp_cont = any(t.startswith('##') for t in tokens)

    spans: List[Tuple[int, int]] = []

    if has_sp_marker:
        start = 0
        for i in range(1, n):
            if '▁' in tokens[i]:
                spans.append((start, i - 1))
                start = i
        spans.append((start, n - 1))
        return spans

    if has_wp_cont:
        start = 0
        for i in range(1, n):
            if not tokens[i].startswith('##'):
                spans.append((start, i - 1))
                start = i
        spans.append((start, n - 1))
        return spans

    # Fallback: every token is its own "word"
    spans = [(i, i) for i in range(n)]
    return spans
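# Illustrative (made-up) token sequences and the spans the heuristics above would produce:
#   _group_tokens_into_words(['▁he', 'llo', '▁wor', 'ld'])  ->  [(0, 1), (2, 3)]   # SentencePiece style
#   _group_tokens_into_words(['he', '##llo', 'world'])      ->  [(0, 1), (2, 2)]   # WordPiece style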


def _tokens_to_word(tokens: List[str], start: int, end: int) -> str:
    """
    Build a clean word string from token pieces in [start, end].
    - SentencePiece: drop '▁' markers and concatenate.
    - WordPiece: drop '##' prefixes and concatenate.
    """
    span = tokens[start : end + 1]
    if not span:
        return ""
    if any('▁' in t for t in span):  # SentencePiece style
        pieces = [t.replace('▁', '') for t in span]
        return "".join(pieces)
    # WordPiece style: remove leading '##'
    pieces = [t[2:] if t.startswith('##') else t for t in span]
    return "".join(pieces)


def _convert_word_offsets_list_to_seconds(
    word_list: List[dict],
    secs_per_step: float,
    latency_comp: float = 0.0,
) -> List[dict]:
    """
    Convert a provided 'word' list that might be in step units
    (keys 'start_offset'/'end_offset') into seconds ('start'/'end').
    If entries already have 'start'/'end' in seconds, preserve them.
    """
    out: List[dict] = []
    for w in word_list:
        if "start" in w and "end" in w:
            start_sec = float(w["start"])
            end_sec = float(w["end"])
            if latency_comp:
                start_sec = max(0.0, start_sec - latency_comp)
                end_sec = max(start_sec, end_sec - latency_comp)
            out.append({"start": start_sec, "end": end_sec, "word": w.get("word", w.get("text", ""))})
        elif "start_offset" in w and "end_offset" in w:
            s = int(w["start_offset"])
            e = int(w["end_offset"])
            # inclusive -> exclusive end boundary (+1)
            start_sec = s * secs_per_step
            end_sec = (e + 1) * secs_per_step
            if latency_comp:
                start_sec = max(0.0, start_sec - latency_comp)
                end_sec = max(start_sec, end_sec - latency_comp)
            out.append({"start": float(start_sec), "end": float(end_sec), "word": w.get("word", w.get("text", ""))})
        else:
            # Unknown structure; best-effort pass-through
            out.append({**w})
    return out
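# Illustrative conversion (made-up offsets, assuming secs_per_step = 0.08 and latency_comp = 0.0):
#   [{'start_offset': 10, 'end_offset': 12, 'word': 'hello'}]
#   -> [{'start': 0.8, 'end': 1.04, 'word': 'hello'}]   # end = (12 + 1) * 0.08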


def write_transcription_custom(
    transcriptions,
    cfg,
    model_name: str,
    filepaths: Optional[List[str]] = None,
    compute_langs: bool = False,
    enable_timestamps: bool = False,
    tokenizer=None,
    secs_per_step: Optional[float] = None,
    latency_comp: float = 0.0,
):
    """
    Minimal, fast writer (per-file JSON):
    - Writes pred_text and word timestamps ONLY.
    - For timestamps:
        * If hyp.timestamp contains 'word': write seconds directly if available;
          if offsets, convert to seconds using secs_per_step.
        * Else: synthesize words from tokenizer tokens + timestep list and convert to seconds.
    - Always writes one pretty-printed JSON per audio file, next to the audio, with .json extension.
    - Respects cfg.overwrite_transcripts per file.
    """
    if cfg.append_pred:
        logging.info('Per-file transcripts will include an appended prediction field name.')
        pred_by_model_name = cfg.pred_name_postfix if cfg.pred_name_postfix is not None else model_name
        pred_text_attr_name = 'pred_text_' + pred_by_model_name
    else:
        pred_text_attr_name = 'pred_text'

    # Build pairs of (manifest_item, audio_fp) to align with transcriptions order
    if cfg.audio_dir is not None:
        assert filepaths is not None, "filepaths must be provided when using audio_dir"
        pairs = [(None, fp) for fp in filepaths]
    else:
        pairs = []
        with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    continue
                pairs.append((json.loads(line), None))

    # Determine structure of transcriptions
    if isinstance(transcriptions[0], str):
        best_hyps = transcriptions
        return_hypotheses = False
        beams = None
    elif hasattr(transcriptions[0], "text"):
        best_hyps = transcriptions
        return_hypotheses = True
        beams = None
    elif isinstance(transcriptions[0], list) and hasattr(transcriptions[0][0], "text"):
        best_hyps, beams = [], []
        for hyps in transcriptions:
            best_hyps.append(hyps[0])
            if not cfg.decoding.beam.return_best_hypothesis:
                beam = []
                for hyp in hyps:
                    score = hyp.score.numpy().item() if isinstance(hyp.score, torch.Tensor) else hyp.score
                    beam.append((hyp.text, score))
                beams.append(beam)
        return_hypotheses = True
    else:
        raise TypeError("Unexpected transcription structure")

    written_paths: List[str] = []

    for idx, (manifest_item, audio_fp) in enumerate(pairs):
        # Construct output JSON object for this item
        if manifest_item is None:
            # audio_dir mode
            if not return_hypotheses:
                item = {'audio_filepath': audio_fp, pred_text_attr_name: best_hyps[idx]}
            else:
                hyp = best_hyps[idx]
                item = {'audio_filepath': audio_fp, pred_text_attr_name: hyp.text}

                if enable_timestamps:
                    timestamp_data = getattr(hyp, "timestamp", None)

                    # Case 1: model already provides word-level dict
                    if isinstance(timestamp_data, dict) and "word" in timestamp_data:
                        word_list = timestamp_data["word"]
                        if secs_per_step is not None:
                            try:
                                item['word'] = _to_serializable(
                                    _convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
                                )
                            except Exception as e:
                                logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
                                item['word'] = _to_serializable(word_list)
                        else:
                            item['word'] = _to_serializable(word_list)
                    else:
                        # Case 2: synthesize from tokens + timestep
                        token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
                        if tokenizer is not None and token_ids:
                            try:
                                tokens = tokenizer.ids_to_tokens(token_ids)
                            except Exception:
                                tokens = []
                        else:
                            tokens = []

                        timestep_list = None
                        if isinstance(timestamp_data, dict):
                            timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
                        else:
                            timestep_list = _as_list_of_ints(timestamp_data)

                        if tokens and timestep_list:
                            if len(timestep_list) != len(tokens):
                                m = min(len(timestep_list), len(tokens))
                                tokens = tokens[:m]
                                timestep_list = timestep_list[:m]

                            word_spans = _group_tokens_into_words(tokens)
                            word_entries = []
                            for (s, e) in word_spans:
                                if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
                                    word_text = _tokens_to_word(tokens, s, e)
                                    if word_text == "":
                                        continue
                                    if secs_per_step is None:
                                        # fallback to raw offsets if stride unknown
                                        word_entries.append({
                                            "start_offset": int(timestep_list[s]),
                                            "end_offset": int(timestep_list[e]),
                                            "word": word_text,
                                        })
                                    else:
                                        start_sec = float(timestep_list[s]) * secs_per_step
                                        end_sec = float(timestep_list[e] + 1) * secs_per_step  # inclusive -> exclusive
                                        if latency_comp:
                                            start_sec = max(0.0, start_sec - latency_comp)
                                            end_sec = max(start_sec, end_sec - latency_comp)
                                        word_entries.append({
                                            "start": start_sec,
                                            "end": end_sec,
                                            "word": word_text,
                                        })
                            if word_entries:
                                item['word'] = _to_serializable(word_entries)

            out_path = Path(audio_fp).with_suffix(".json")

        else:
            # dataset_manifest mode
            if not return_hypotheses:
                item = dict(manifest_item)
                item[pred_text_attr_name] = best_hyps[idx]
            else:
                hyp = best_hyps[idx]
                item = dict(manifest_item)
                item[pred_text_attr_name] = hyp.text

                if enable_timestamps:
                    timestamp_data = getattr(hyp, "timestamp", None)

                    if isinstance(timestamp_data, dict) and "word" in timestamp_data:
                        word_list = timestamp_data["word"]
                        if secs_per_step is not None:
                            try:
                                item['word'] = _to_serializable(
                                    _convert_word_offsets_list_to_seconds(word_list, secs_per_step, latency_comp)
                                )
                            except Exception as e:
                                logging.warning(f"Failed converting provided 'word' offsets to seconds: {e}. Writing raw.")
                                item['word'] = _to_serializable(word_list)
                        else:
                            item['word'] = _to_serializable(word_list)
                    else:
                        token_ids = hyp.y_sequence.tolist() if hasattr(hyp, "y_sequence") else []
                        if tokenizer is not None and token_ids:
                            try:
                                tokens = tokenizer.ids_to_tokens(token_ids)
                            except Exception:
                                tokens = []
                        else:
                            tokens = []

                        timestep_list = None
                        if isinstance(timestamp_data, dict):
                            timestep_list = _as_list_of_ints(timestamp_data.get('timestep', None))
                        else:
                            timestep_list = _as_list_of_ints(timestamp_data)

                        if tokens and timestep_list:
                            if len(timestep_list) != len(tokens):
                                m = min(len(timestep_list), len(tokens))
                                tokens = tokens[:m]
                                timestep_list = timestep_list[:m]

                            word_spans = _group_tokens_into_words(tokens)
                            word_entries = []
                            for (s, e) in word_spans:
                                if 0 <= s < len(timestep_list) and 0 <= e < len(timestep_list) and s <= e:
                                    word_text = _tokens_to_word(tokens, s, e)
                                    if word_text == "":
                                        continue
                                    if secs_per_step is None:
                                        word_entries.append({
                                            "start_offset": int(timestep_list[s]),
                                            "end_offset": int(timestep_list[e]),
                                            "word": word_text,
                                        })
                                    else:
                                        start_sec = float(timestep_list[s]) * secs_per_step
                                        end_sec = float(timestep_list[e] + 1) * secs_per_step  # inclusive -> exclusive
                                        if latency_comp:
                                            start_sec = max(0.0, start_sec - latency_comp)
                                            end_sec = max(start_sec, end_sec - latency_comp)
                                        word_entries.append({
                                            "start": start_sec,
                                            "end": end_sec,
                                            "word": word_text,
                                        })
                            if word_entries:
                                item['word'] = _to_serializable(word_entries)

            out_path = Path(item["audio_filepath"]).with_suffix(".json")

        # Ensure parent exists and write pretty JSON (indent=2)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        if (not cfg.overwrite_transcripts) and out_path.exists():
            logging.info(f"Skipping existing transcript (overwrite_transcripts=False): {out_path}")
        else:
            with open(out_path, 'w', encoding='utf-8', newline='\n') as fout:
                json.dump(item, fout, ensure_ascii=False, indent=2)
                fout.write("\n")
            written_paths.append(str(out_path))

    return written_paths, pred_text_attr_name


@dataclass
class TranscriptionConfig:
    model_path: Optional[str] = None
    pretrained_name: Optional[str] = None
    audio_dir: Optional[str] = None
    dataset_manifest: Optional[str] = None

    # Output manifest fields removed; per-file JSON only
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False
    pred_name_postfix: Optional[str] = None
    random_seed: Optional[int] = None

    chunk_secs: float = 2
    left_context_secs: float = 10.0
    right_context_secs: float = 2

    cuda: Optional[int] = None
    allow_mps: bool = True
    compute_dtype: Optional[str] = None
    matmul_precision: str = "high"
    audio_type: str = "wav"

    overwrite_transcripts: bool = True

    decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision(cfg.matmul_precision)

    cfg = OmegaConf.structured(cfg)

    if cfg.random_seed:
        pl.seed_everything(cfg.random_seed)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None

    # setup device
    if cfg.cuda is None:
        if torch.cuda.is_available():
            map_location = torch.device('cuda:0')
        elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logging.warning(
                "MPS device support is experimental. Set PYTORCH_ENABLE_MPS_FALLBACK=1 to avoid failures."
            )
            map_location = torch.device('mps')
        else:
            map_location = torch.device('cpu')
    elif cfg.cuda < 0:
        map_location = torch.device('cpu')
    else:
        map_location = torch.device(f'cuda:{cfg.cuda}')

    if cfg.compute_dtype is None:
        can_use_bfloat16 = map_location.type == "cuda" and torch.cuda.is_bf16_supported()
        compute_dtype = torch.bfloat16 if can_use_bfloat16 else torch.float32
    else:
        assert cfg.compute_dtype in {"float32", "bfloat16", "float16"}
        compute_dtype = getattr(torch, cfg.compute_dtype)

    logging.info(f"Inference device: {map_location}, dtype: {compute_dtype}")

    asr_model, model_name = setup_model(cfg, map_location)

    model_cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(model_cfg.preprocessor, False)
    model_cfg.preprocessor.dither = 0.0
    model_cfg.preprocessor.pad_to = 0
    if model_cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently")
    OmegaConf.set_struct(model_cfg.preprocessor, True)

    # No aggregate output file is computed or used.

    asr_model.freeze()
    asr_model = asr_model.to(asr_model.device)
    asr_model.to(compute_dtype)

    with open_dict(cfg.decoding):
        if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
            raise NotImplementedError("This script supports only `greedy_batch` with the Label-Looping algorithm")
        cfg.decoding.preserve_alignments = True
        cfg.decoding.fused_batch_size = -1
        cfg.decoding.beam.return_best_hypothesis = True
        # Intentionally NOT setting compute_timestamps=True; we synthesize word times ourselves.

    if hasattr(asr_model, 'change_decoding_strategy'):
        if not isinstance(asr_model, EncDecRNNTModel) and not isinstance(asr_model, EncDecHybridRNNTCTCModel):
            raise ValueError("This script supports RNNT models and hybrid RNNT-CTC models with an RNNT decoder")
        if isinstance(asr_model, EncDecRNNTModel):
            asr_model.change_decoding_strategy(cfg.decoding)
        if hasattr(asr_model, 'cur_decoder'):
            asr_model.change_decoding_strategy(cfg.decoding, decoder_type='rnnt')

    if manifest is not None:
        records = read_manifest(manifest)
        manifest_dir = Path(manifest).parent.absolute()
        for record in records:
            record["audio_filepath"] = str(filepath_to_absolute(record["audio_filepath"], manifest_dir))
    else:
        assert filepaths is not None
        records = [{"audio_filepath": audio_file} for audio_file in filepaths]

    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    asr_model.eval()

    decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer

    audio_sample_rate = model_cfg.preprocessor['sample_rate']
    feature_stride_sec = model_cfg.preprocessor['window_stride']
    features_per_sec = 1.0 / feature_stride_sec
    encoder_subsampling_factor = asr_model.encoder.subsampling_factor

    features_frame2audio_samples = make_divisible_by(
        int(audio_sample_rate * feature_stride_sec), factor=encoder_subsampling_factor
    )
    encoder_frame2audio_samples = features_frame2audio_samples * encoder_subsampling_factor

    # Accurate seconds per encoder step (accounts for divisibility correction)
    secs_per_step = encoder_frame2audio_samples / audio_sample_rate
    # Optional UI latency compensation (seconds). Set to cfg.right_context_secs if desired.
    latency_comp = 0.0
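    # Illustrative arithmetic (assuming a 16 kHz model with a 0.01 s window_stride and 8x subsampling):
    #   features_frame2audio_samples = make_divisible_by(int(16000 * 0.01), 8) = 160
    #   encoder_frame2audio_samples  = 160 * 8 = 1280
    #   secs_per_step                = 1280 / 16000 = 0.08 s per encoder step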

    context_encoder_frames = ContextSize(
        left=int(cfg.left_context_secs * features_per_sec / encoder_subsampling_factor),
        chunk=int(cfg.chunk_secs * features_per_sec / encoder_subsampling_factor),
        right=int(cfg.right_context_secs * features_per_sec / encoder_subsampling_factor),
    )
    context_samples = ContextSize(
        left=context_encoder_frames.left * encoder_subsampling_factor * features_frame2audio_samples,
        chunk=context_encoder_frames.chunk * encoder_subsampling_factor * features_frame2audio_samples,
        right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples,
    )
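    # With the same illustrative numbers and the config defaults (left 10 s, chunk 2 s, right 2 s):
    #   context_encoder_frames = ContextSize(left=int(10 * 100 / 8) = 125, chunk=25, right=25)
    #   context_samples        = ContextSize(left=125 * 8 * 160 = 160000, chunk=32000, right=32000)
    #   i.e. 10 s / 2 s / 2 s of audio at 16 kHz, giving a theoretical latency of ~4 s (chunk + right).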

    logging.info(
        "Corrected contexts (sec): "
        f"Left {context_samples.left / audio_sample_rate:.2f}, "
        f"Chunk {context_samples.chunk / audio_sample_rate:.2f}, "
        f"Right {context_samples.right / audio_sample_rate:.2f}"
    )
    logging.info(
        f"Corrected contexts (subsampled encoder frames): Left {context_encoder_frames.left} - "
        f"Chunk {context_encoder_frames.chunk} - Right {context_encoder_frames.right}"
    )
    logging.info(
        f"Corrected contexts (in audio samples): Left {context_samples.left} - "
        f"Chunk {context_samples.chunk} - Right {context_samples.right}"
    )
    latency_secs = (context_samples.chunk + context_samples.right) / audio_sample_rate
    logging.info(f"Theoretical latency: {latency_secs:.2f} seconds")
    logging.info(f"secs_per_step (encoder): {secs_per_step:.6f} s")

    audio_dataset = SimpleAudioDataset(
        audio_filenames=[record["audio_filepath"] for record in records], sample_rate=audio_sample_rate
    )
    audio_dataloader = DataLoader(
        dataset=audio_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        collate_fn=AudioBatch.collate_fn,
        drop_last=False,
        in_order=True,
    )

    with torch.no_grad(), torch.inference_mode():
        all_hyps = []
        for audio_data in tqdm(audio_dataloader):
            audio_batch = audio_data.audio_signals.to(device=map_location)
            audio_batch_lengths = audio_data.audio_signal_lengths.to(device=map_location)
            batch_size = audio_batch.shape[0]
            device = audio_batch.device

            current_batched_hyps: BatchedHyps | None = None
            state = None
            left_sample = 0
            right_sample = min(context_samples.chunk + context_samples.right, audio_batch.shape[1])
            buffer = StreamingBatchedAudioBuffer(
                batch_size=batch_size,
                context_samples=context_samples,
                dtype=audio_batch.dtype,
                device=device,
            )
            rest_audio_lengths = audio_batch_lengths.clone()

            while left_sample < audio_batch.shape[1]:
                chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample
                is_last_chunk_batch = chunk_length >= rest_audio_lengths
                is_last_chunk = right_sample >= audio_batch.shape[1]
                chunk_lengths_batch = torch.where(
                    is_last_chunk_batch,
                    rest_audio_lengths,
                    torch.full_like(rest_audio_lengths, fill_value=chunk_length),
                )
                buffer.add_audio_batch_(
                    audio_batch[:, left_sample:right_sample],
                    audio_lengths=chunk_lengths_batch,
                    is_last_chunk=is_last_chunk,
                    is_last_chunk_batch=is_last_chunk_batch,
                )

                encoder_output, encoder_output_len = asr_model(
                    input_signal=buffer.samples,
                    input_signal_length=buffer.context_size_batch.total(),
                )
                encoder_output = encoder_output.transpose(1, 2)  # [B, T, C]
                encoder_context = buffer.context_size.subsample(factor=encoder_frame2audio_samples)
                encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples)
                encoder_output = encoder_output[:, encoder_context.left :]

                chunk_batched_hyps, _, state = decoding_computer(
                    x=encoder_output,
                    out_len=encoder_context_batch.chunk,
                    prev_batched_state=state,
                )
                if current_batched_hyps is None:
                    current_batched_hyps = chunk_batched_hyps
                else:
                    current_batched_hyps.merge_(chunk_batched_hyps)

                rest_audio_lengths -= chunk_lengths_batch
                left_sample = right_sample
                right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1])

            all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))

    for hyp in all_hyps:
        hyp.text = asr_model.tokenizer.ids_to_text(hyp.y_sequence.tolist())

    written_paths, pred_text_attr_name = write_transcription_custom(
        all_hyps,
        cfg,
        model_name,
        filepaths=filepaths,
        compute_langs=False,
        enable_timestamps=True,
        tokenizer=asr_model.tokenizer,
        secs_per_step=secs_per_step,  # critical: convert step offsets -> seconds
        latency_comp=latency_comp,  # optional UI compensation (default 0.0)
    )

    logging.info(f"Wrote {len(written_paths)} transcript files.")
    if written_paths:
        # Log up to the first 5 paths for reference
        preview = "\n - ".join(written_paths[:5])
        logging.info(f"Sample outputs (up to 5):\n - {preview}")

    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter