#!/usr/bin/env python3
import os, math, argparse
from typing import List, Tuple, Optional, Dict

import torch
from pyannote.audio import Pipeline
import srt
import datetime as dt
import webvtt


# ---------- time helpers ----------

def to_srt_ts(seconds: float) -> dt.timedelta:
    if seconds is None:
        seconds = 0.0
    ms = int(round((seconds - math.floor(seconds)) * 1000))
    return dt.timedelta(seconds=int(seconds), milliseconds=ms)


def from_vtt_timestamp(ts: str) -> float:
    # "HH:MM:SS.mmm" (WebVTT)
    h, m, s = ts.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)


def to_vtt_timestamp(seconds: float) -> str:
    if seconds < 0:
        seconds = 0.0
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = seconds - (h * 3600 + m * 60)
    return f"{h:02d}:{m:02d}:{s:06.3f}"


# ---------- caption I/O ----------

def read_srt(path: str) -> List[Dict]:
    with open(path, "r", encoding="utf-8") as f:
        subs = list(srt.parse(f.read()))
    items = []
    for sub in subs:
        start = sub.start.total_seconds()
        end = sub.end.total_seconds()
        items.append({"index": sub.index, "start": start, "end": end, "text": sub.content})
    return items


def write_srt(path: str, items: List[Dict]) -> None:
    subs = []
    for i, it in enumerate(items, start=1):
        subs.append(
            srt.Subtitle(
                index=i,
                start=to_srt_ts(it["start"]),
                end=to_srt_ts(it["end"]),
                content=it["text"],
            )
        )
    with open(path, "w", encoding="utf-8") as f:
        f.write(srt.compose(subs))


def read_vtt(path: str) -> List[Dict]:
    vtt = webvtt.read(path)
    items = []
    for i, cue in enumerate(vtt, start=1):
        items.append({
            "index": i,
            "start": from_vtt_timestamp(cue.start),
            "end": from_vtt_timestamp(cue.end),
            "text": cue.text,
        })
    return items


def write_vtt(path: str, items: List[Dict]) -> None:
    vtt = webvtt.WebVTT()
    for it in items:
        cue = webvtt.Caption(
            start=to_vtt_timestamp(it["start"]),
            end=to_vtt_timestamp(it["end"]),
            text=it["text"],
        )
        vtt.captions.append(cue)
    vtt.save(path)


# ---------- diarization overlap ----------

def overlap(a: float, b: float, c: float, d: float) -> float:
    # length of intersection of [a,b] and [c,d]
    return max(0.0, min(b, d) - max(a, c))


def label_by_max_overlap(seg_start: float, seg_end: float,
                         diar_segments: List[Tuple[float, float, str]]) -> Optional[str]:
    best, best_len = None, 0.0
    for (s, e, lab) in diar_segments:
        ol = overlap(seg_start, seg_end, s, e)
        if ol > best_len:
            best_len = ol
            best = lab
    return best


# ---------- main ----------

def main():
    ap = argparse.ArgumentParser(description="Add speaker labels to existing SRT/VTT using pyannote diarization.")
    ap.add_argument("--audio_dir", required=True, help="Folder with .wav audio (names must match captions).")
    ap.add_argument("--captions_dir", required=True, help="Folder with .srt or .vtt files.")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--hf_token", required=True, help="Hugging Face token to download the pipeline the first time.")
    ap.add_argument("--num_speakers", type=int, default=None)
    ap.add_argument("--min_speakers", type=int, default=None)
    ap.add_argument("--max_speakers", type=int, default=None)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Instantiate diarization pipeline (GPU optional)
    pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=args.hf_token)
    if torch.cuda.is_available():
        pipe.to(torch.device("cuda"))

    diar_kwargs = {}
    if args.num_speakers is not None:
        diar_kwargs["num_speakers"] = args.num_speakers
    if args.min_speakers is not None:
        diar_kwargs["min_speakers"] = args.min_speakers
    if args.max_speakers is not None:
        diar_kwargs["max_speakers"] = args.max_speakers
    # Build index of captions by basename
    caption_map: Dict[str, Tuple[str, str]] = {}  # base -> (path, kind)
    for fn in os.listdir(args.captions_dir):
        base, ext = os.path.splitext(fn)
        ext = ext.lower()
        if ext in (".srt", ".vtt"):
            caption_map[base] = (os.path.join(args.captions_dir, fn), ext[1:])

    # Process each audio whose base has a caption
    for fn in sorted(os.listdir(args.audio_dir)):
        if not fn.lower().endswith(".wav"):
            continue
        base = os.path.splitext(fn)[0]
        if base not in caption_map:
            print(f"[skip] No caption found for {fn}")
            continue

        audio_path = os.path.join(args.audio_dir, fn)
        cap_path, kind = caption_map[base]
        print(f"[diarize] {fn} | captions: {os.path.basename(cap_path)} ({kind})")

        # Read captions
        items = read_srt(cap_path) if kind == "srt" else read_vtt(cap_path)

        # Run diarization (pyannote handles resampling + mono downmix automatically)
        diar = pipe(audio_path, **diar_kwargs)

        # Collect diarization segments
        diar_segments: List[Tuple[float, float, str]] = []
        for turn, _, label in diar.itertracks(yield_label=True):
            diar_segments.append((turn.start, turn.end, label))
        diar_segments.sort(key=lambda x: x[0])

        # Tag each caption line with max-overlap speaker
        tagged = []
        for it in items:
            speaker = label_by_max_overlap(it["start"], it["end"], diar_segments)
            prefix = f"[{speaker}] " if speaker else ""
            tagged.append({**it, "text": prefix + it["text"]})

        # Write outputs (both formats, regardless of input kind)
        out_srt = os.path.join(args.out_dir, f"{base}.diarized.srt")
        out_vtt = os.path.join(args.out_dir, f"{base}.diarized.vtt")
        write_srt(out_srt, tagged)
        write_vtt(out_vtt, tagged)

        # Also dump RTTM (standard diarization format)
        with open(os.path.join(args.out_dir, f"{base}.rttm"), "w", encoding="utf-8") as rttm:
            diar.write_rttm(rttm)

        print(f"[done] {base}: wrote *.diarized.srt, *.diarized.vtt, *.rttm")


if __name__ == "__main__":
    main()
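
# Example invocation (a sketch: the script filename and directory paths below are
# placeholders, not part of this repo). The Hugging Face token must belong to an
# account that has accepted the pyannote/speaker-diarization-3.1 gated-model terms.
#
#   python diarize_captions.py \
#       --audio_dir ./audio \
#       --captions_dir ./captions \
#       --out_dir ./out \
#       --hf_token hf_xxx \
#       --min_speakers 2 --max_speakers 4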