Created
August 6, 2025 04:05
-
-
Save huseinzol05/94f263b26f01fddf4e0c1f25e9dc4cdc to your computer and use it in GitHub Desktop.
Accurate forced alignment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/darshan.r/synthetic-dia/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import torchaudio\n", | |
| "import torch\n", | |
| "import soundfile as sf\n", | |
| "from ctc_forced_aligner import (\n", | |
| " load_audio,\n", | |
| " load_alignment_model,\n", | |
| " generate_emissions,\n", | |
| " preprocess_text,\n", | |
| " get_alignments,\n", | |
| " get_spans,\n", | |
| " postprocess_results,\n", | |
| ")\n", | |
| "device = 'cuda'\n", | |
| "alignment_model, alignment_tokenizer = load_alignment_model(\n", | |
| " device,\n", | |
| " dtype=torch.float16 if device == \"cuda\" else torch.float32,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/example-speaker/93113.mp3\n", | |
| "gen_text = \"The tomato turned red because it was ripe! It's a natural process as the tomato matures and becomes ready to eat.\"\n", | |
| "language = 'ms'\n", | |
| "y, sr = sf.read('93113.mp3')\n", | |
| "new_wav = torch.from_numpy(y)\n", | |
| "audio_waveform = torchaudio.functional.resample(\n", | |
| " new_wav, orig_freq=44100, new_freq=16000\n", | |
| ").type(torch.float16).cuda()\n", | |
| "emissions, stride = generate_emissions(\n", | |
| " alignment_model, audio_waveform, batch_size=1\n", | |
| ")\n", | |
| "tokens_starred, text_starred = preprocess_text(\n", | |
| " gen_text,\n", | |
| " romanize=True,\n", | |
| " language=language,\n", | |
| ")\n", | |
| "tokens_starred.append('<star>')\n", | |
| "text_starred.append('<star>')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "segments, scores, blank_token = get_alignments(\n", | |
| " emissions,\n", | |
| " tokens_starred,\n", | |
| " alignment_tokenizer,\n", | |
| ")\n", | |
| "spans = get_spans(tokens_starred, segments, blank_token)\n", | |
| "word_timestamps = postprocess_results(text_starred, spans, stride, scores)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "tokenizer = alignment_tokenizer\n", | |
| "tokens = tokens_starred\n", | |
| "dictionary = tokenizer.get_vocab()\n", | |
| "dictionary = {k: v for k, v in dictionary.items()}\n", | |
| "dictionary_rev = {v: k for k, v in dictionary.items()}\n", | |
| "dictionary[\"<star>\"] = len(dictionary)\n", | |
| "blank_id = dictionary.get(\"<blank>\", tokenizer.pad_token_id)\n", | |
| "\n", | |
| "token_indices = [\n", | |
| " dictionary[c] for c in \" \".join(tokens).split(\" \") if c in dictionary\n", | |
| "]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "alignments, scores = torchaudio.functional.forced_align(emissions[None].cpu(), torch.tensor([token_indices]))\n", | |
| "alignments, scores = alignments[0], scores[0]\n", | |
| "token_spans = torchaudio.functional.merge_tokens(alignments, scores)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "groups = []\n", | |
| "current_group = []\n", | |
| "\n", | |
| "for span in token_spans:\n", | |
| " current_group.append(span)\n", | |
| " if span.token == 31:\n", | |
| " groups.append(current_group)\n", | |
| " current_group = []\n", | |
| "\n", | |
| "if len(current_group):\n", | |
| " groups.append(current_group)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[{'text': 'the',\n", | |
| " 'start': 23,\n", | |
| " 'end': 27,\n", | |
| " 'score': np.float64(-0.6402994791666666)},\n", | |
| " {'text': 'tomato',\n", | |
| " 'start': 28,\n", | |
| " 'end': 47,\n", | |
| " 'score': np.float64(-0.4833577473958333)},\n", | |
| " {'text': 'turned',\n", | |
| " 'start': 50,\n", | |
| " 'end': 65,\n", | |
| " 'score': np.float64(-0.6833902994791666)},\n", | |
| " {'text': 'red', 'start': 66, 'end': 79, 'score': np.float64(-0.312255859375)},\n", | |
| " {'text': 'because',\n", | |
| " 'start': 83,\n", | |
| " 'end': 98,\n", | |
| " 'score': np.float64(-0.23644147600446427)},\n", | |
| " {'text': 'it',\n", | |
| " 'start': 99,\n", | |
| " 'end': 101,\n", | |
| " 'score': np.float64(-2.0826568603515625)},\n", | |
| " {'text': 'was',\n", | |
| " 'start': 102,\n", | |
| " 'end': 111,\n", | |
| " 'score': np.float64(-0.04644775390625)},\n", | |
| " {'text': 'ripe',\n", | |
| " 'start': 116,\n", | |
| " 'end': 147,\n", | |
| " 'score': np.float64(-0.10152626037597656)},\n", | |
| " {'text': \"it's\",\n", | |
| " 'start': 163,\n", | |
| " 'end': 168,\n", | |
| " 'score': np.float64(-0.2649574279785156)},\n", | |
| " {'text': 'a',\n", | |
| " 'start': 169,\n", | |
| " 'end': 172,\n", | |
| " 'score': np.float64(-0.0223541259765625)},\n", | |
| " {'text': 'natural',\n", | |
| " 'start': 174,\n", | |
| " 'end': 190,\n", | |
| " 'score': np.float64(-0.03738205773489816)},\n", | |
| " {'text': 'process',\n", | |
| " 'start': 191,\n", | |
| " 'end': 213,\n", | |
| " 'score': np.float64(-0.4380640302385603)},\n", | |
| " {'text': 'as',\n", | |
| " 'start': 216,\n", | |
| " 'end': 221,\n", | |
| " 'score': np.float64(-0.01033782958984375)},\n", | |
| " {'text': 'the',\n", | |
| " 'start': 222,\n", | |
| " 'end': 227,\n", | |
| " 'score': np.float64(-0.08664449055989583)},\n", | |
| " {'text': 'tomato',\n", | |
| " 'start': 228,\n", | |
| " 'end': 245,\n", | |
| " 'score': np.float64(-0.16687647501627603)},\n", | |
| " {'text': 'matures',\n", | |
| " 'start': 247,\n", | |
| " 'end': 277,\n", | |
| " 'score': np.float64(-0.30379159109933035)},\n", | |
| " {'text': 'and',\n", | |
| " 'start': 279,\n", | |
| " 'end': 284,\n", | |
| " 'score': np.float64(-0.009943644205729166)},\n", | |
| " {'text': 'becomes',\n", | |
| " 'start': 285,\n", | |
| " 'end': 301,\n", | |
| " 'score': np.float64(-0.005782740456717355)},\n", | |
| " {'text': 'ready',\n", | |
| " 'start': 302,\n", | |
| " 'end': 312,\n", | |
| " 'score': np.float64(-0.00582733154296875)},\n", | |
| " {'text': 'to',\n", | |
| " 'start': 314,\n", | |
| " 'end': 319,\n", | |
| " 'score': np.float64(-0.0202484130859375)},\n", | |
| " {'text': 'eat',\n", | |
| " 'start': 322,\n", | |
| " 'end': 329,\n", | |
| " 'score': np.float64(-0.00466156005859375)}]" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "merged = []\n", | |
| "for g in groups:\n", | |
| " temp, score = [], []\n", | |
| " if not isinstance(g, list):\n", | |
| " g = [g]\n", | |
| " if len(g) == 1 and g[0].token == 31:\n", | |
| " continue\n", | |
| " for g_ in g:\n", | |
| " if g_.token == 31:\n", | |
| " continue\n", | |
| " score.append(g_.score)\n", | |
| " temp.append(dictionary_rev[g_.token])\n", | |
| " if len(temp):\n", | |
| " merged.append({\n", | |
| " 'text': ''.join(temp),\n", | |
| " 'start': g[0].start,\n", | |
| " 'end': g[-1].start + (g[-1].end - g[-1].start) // 2,\n", | |
| " 'score': np.mean(score)\n", | |
| " })\n", | |
| "merged" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(21, 21)" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "text_nonstar = [w for w in text_starred if w != '<star>']\n", | |
| "len(merged), len(text_nonstar)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[{'text': 'The',\n", | |
| " 'start': 0.45880519480519477,\n", | |
| " 'end': 0.5385974025974026,\n", | |
| " 'score': -0.6402994791666666},\n", | |
| " {'text': 'tomato',\n", | |
| " 'start': 0.5585454545454546,\n", | |
| " 'end': 0.9375584415584415,\n", | |
| " 'score': -0.4833577473958333},\n", | |
| " {'text': 'turned',\n", | |
| " 'start': 0.9974025974025974,\n", | |
| " 'end': 1.2966233766233768,\n", | |
| " 'score': -0.6833902994791666},\n", | |
| " {'text': 'red',\n", | |
| " 'start': 1.3165714285714287,\n", | |
| " 'end': 1.5758961038961041,\n", | |
| " 'score': -0.312255859375},\n", | |
| " {'text': 'because',\n", | |
| " 'start': 1.6556883116883114,\n", | |
| " 'end': 1.954909090909091,\n", | |
| " 'score': -0.23644147600446427},\n", | |
| " {'text': 'it',\n", | |
| " 'start': 1.9748571428571429,\n", | |
| " 'end': 2.0147532467532465,\n", | |
| " 'score': -2.0826568603515625},\n", | |
| " {'text': 'was',\n", | |
| " 'start': 2.0347012987012985,\n", | |
| " 'end': 2.2142337662337663,\n", | |
| " 'score': -0.04644775390625},\n", | |
| " {'text': 'ripe!',\n", | |
| " 'start': 2.313974025974026,\n", | |
| " 'end': 2.932363636363636,\n", | |
| " 'score': -0.10152626037597656},\n", | |
| " {'text': \"It's\",\n", | |
| " 'start': 3.2515324675324675,\n", | |
| " 'end': 3.351272727272727,\n", | |
| " 'score': -0.2649574279785156},\n", | |
| " {'text': 'a',\n", | |
| " 'start': 3.371220779220779,\n", | |
| " 'end': 3.431064935064935,\n", | |
| " 'score': -0.0223541259765625},\n", | |
| " {'text': 'natural',\n", | |
| " 'start': 3.470961038961039,\n", | |
| " 'end': 3.79012987012987,\n", | |
| " 'score': -0.03738205773489816},\n", | |
| " {'text': 'process',\n", | |
| " 'start': 3.810077922077922,\n", | |
| " 'end': 4.248935064935065,\n", | |
| " 'score': -0.4380640302385603},\n", | |
| " {'text': 'as',\n", | |
| " 'start': 4.308779220779221,\n", | |
| " 'end': 4.40851948051948,\n", | |
| " 'score': -0.01033782958984375},\n", | |
| " {'text': 'the',\n", | |
| " 'start': 4.428467532467533,\n", | |
| " 'end': 4.528207792207792,\n", | |
| " 'score': -0.08664449055989583},\n", | |
| " {'text': 'tomato',\n", | |
| " 'start': 4.548155844155844,\n", | |
| " 'end': 4.887272727272727,\n", | |
| " 'score': -0.16687647501627603},\n", | |
| " {'text': 'matures',\n", | |
| " 'start': 4.927168831168831,\n", | |
| " 'end': 5.52561038961039,\n", | |
| " 'score': -0.30379159109933035},\n", | |
| " {'text': 'and',\n", | |
| " 'start': 5.5655064935064935,\n", | |
| " 'end': 5.665246753246753,\n", | |
| " 'score': -0.009943644205729166},\n", | |
| " {'text': 'becomes',\n", | |
| " 'start': 5.685194805194805,\n", | |
| " 'end': 6.004363636363636,\n", | |
| " 'score': -0.005782740456717355},\n", | |
| " {'text': 'ready',\n", | |
| " 'start': 6.0243116883116885,\n", | |
| " 'end': 6.223792207792209,\n", | |
| " 'score': -0.00582733154296875},\n", | |
| " {'text': 'to',\n", | |
| " 'start': 6.263688311688312,\n", | |
| " 'end': 6.363428571428572,\n", | |
| " 'score': -0.0202484130859375},\n", | |
| " {'text': 'eat.',\n", | |
| " 'start': 6.423272727272727,\n", | |
| " 'end': 6.56290909090909,\n", | |
| " 'score': -0.00466156005859375}]" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "alignment = []\n", | |
| "for i in range(len(merged)):\n", | |
| " alignment.append({\n", | |
| " 'text': text_nonstar[i],\n", | |
| " 'start': (merged[i]['start'] * len(y) / emissions.shape[0]) / sr,\n", | |
| " 'end': (merged[i]['end'] * len(y) / emissions.shape[0]) / sr,\n", | |
| " 'score': float(merged[i]['score'])\n", | |
| " })\n", | |
| "alignment" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <audio controls=\"controls\" >\n", | |
| " <source src=\"data:audio/wav;base64,\" type=\"audio/wav\" />\n", | |
| " Your browser does not support the audio element.\n", | |
| " </audio>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.lib.display.Audio object>" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import IPython.display as ipd\n", | |
| "ipd.Audio(y[int(alignment[8]['start'] * sr): int(alignment[8]['end'] * sr)], rate = sr)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 {'text': 'The', 'start': 0.45880519480519477, 'end': 0.5385974025974026, 'score': -0.6402994791666666}\n", | |
| "1 {'text': 'tomato', 'start': 0.5585454545454546, 'end': 0.9375584415584415, 'score': -0.4833577473958333}\n", | |
| "2 {'text': 'turned', 'start': 0.9974025974025974, 'end': 1.2966233766233768, 'score': -0.6833902994791666}\n", | |
| "3 {'text': 'red', 'start': 1.3165714285714287, 'end': 1.5758961038961041, 'score': -0.312255859375}\n", | |
| "4 {'text': 'because', 'start': 1.6556883116883114, 'end': 1.954909090909091, 'score': -0.23644147600446427}\n", | |
| "5 {'text': 'it', 'start': 1.9748571428571429, 'end': 2.0147532467532465, 'score': -2.0826568603515625}\n", | |
| "6 {'text': 'was', 'start': 2.0347012987012985, 'end': 2.2142337662337663, 'score': -0.04644775390625}\n", | |
| "7 {'text': 'ripe!', 'start': 2.313974025974026, 'end': 2.932363636363636, 'score': -0.10152626037597656}\n", | |
| "8 {'text': \"It's\", 'start': 3.2515324675324675, 'end': 3.351272727272727, 'score': -0.2649574279785156}\n", | |
| "9 {'text': 'a', 'start': 3.371220779220779, 'end': 3.431064935064935, 'score': -0.0223541259765625}\n", | |
| "10 {'text': 'natural', 'start': 3.470961038961039, 'end': 3.79012987012987, 'score': -0.03738205773489816}\n", | |
| "11 {'text': 'process', 'start': 3.810077922077922, 'end': 4.248935064935065, 'score': -0.4380640302385603}\n", | |
| "12 {'text': 'as', 'start': 4.308779220779221, 'end': 4.40851948051948, 'score': -0.01033782958984375}\n", | |
| "13 {'text': 'the', 'start': 4.428467532467533, 'end': 4.528207792207792, 'score': -0.08664449055989583}\n", | |
| "14 {'text': 'tomato', 'start': 4.548155844155844, 'end': 4.887272727272727, 'score': -0.16687647501627603}\n", | |
| "15 {'text': 'matures', 'start': 4.927168831168831, 'end': 5.52561038961039, 'score': -0.30379159109933035}\n", | |
| "16 {'text': 'and', 'start': 5.5655064935064935, 'end': 5.665246753246753, 'score': -0.009943644205729166}\n", | |
| "17 {'text': 'becomes', 'start': 5.685194805194805, 'end': 6.004363636363636, 'score': -0.005782740456717355}\n", | |
| "18 {'text': 'ready', 'start': 6.0243116883116885, 'end': 6.223792207792209, 'score': -0.00582733154296875}\n", | |
| "19 {'text': 'to', 'start': 6.263688311688312, 'end': 6.363428571428572, 'score': -0.0202484130859375}\n", | |
| "20 {'text': 'eat.', 'start': 6.423272727272727, 'end': 6.56290909090909, 'score': -0.00466156005859375}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for i in range(len(alignment)):\n", | |
| " print(i, alignment[i])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "dia", | |
| "language": "python", | |
| "name": "dia" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment