In [1]:
# !pip3.10 install huggingface-hub==0.23
# !pip3.10 install git+https://github.com/mesolitica/whisper-static-cache
# !pip3.10 uninstall torch -y; pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

In [2]:
import os

# Expose only GPU 0 to this process. This cell runs before the cell that imports torch.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [3]:
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline
from transformers.cache_utils import WhisperStaticCache
import torch
import requests
from datasets import Audio
from transformers import AutoProcessor
from tqdm import tqdm

# Sampling rate (Hz) handed to the audio decoder below.
sr = 16_000
# `datasets.Audio` feature; later cells use it to turn downloaded bytes into sample arrays.
audio = Audio(sampling_rate = sr)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Run configuration shared by the cells below.
device = 'cuda:0'                      # target for the static cache and position tensors
compute_dtype = torch.bfloat16         # dtype the checkpoints are loaded in
model_id = 'openai/whisper-large-v3'   # Hugging Face checkpoint name

In [5]:
# Copy of the checkpoint whose encoder forward is wrapped in torch.compile in a later cell.
# `.cuda()` returns the module itself, so load + move can be a single expression.
model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
).cuda()

# Processor used by later cells for feature extraction and for tokenising/decoding text.
processor = AutoProcessor.from_pretrained(model_id)

Instantiating WhisperSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Second, independent copy of the same checkpoint. It is not passed through torch.compile
# in the cells below and serves as the eager reference in the timing comparisons.
model_normal = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
).cuda()

In [7]:
# Replace the encoder's `forward` on `model` with a compiled version
# (`model_normal` is left untouched).
whisper_encoder = model.model.encoder
whisper_encoder.forward = torch.compile(
    whisper_encoder.forward,
    fullgraph=True,
    mode='reduce-overhead',
)

In [8]:
def decode_one_tokens(
    model,
    proj_out,
    cur_token,
    past_key_values,
    position_ids,
    cache_position,
    out_encoder,
):
    """Run a single greedy decoding step and return the chosen token id(s).

    Args:
        model: Decoder callable. It is invoked with ``cur_token`` and the keyword
            arguments shown below, and must return an indexable whose first
            element is the hidden-state tensor.
        proj_out: Callable applied to the last time step of the hidden states to
            produce the scores that are arg-maxed.
        cur_token: Token ids fed to the decoder for this step.
        past_key_values: Cache object forwarded to the decoder unchanged.
        position_ids: Forwarded to the decoder unchanged.
        cache_position: Forwarded to the decoder unchanged.
        out_encoder: Passed to the decoder as ``encoder_hidden_states``.

    Returns:
        The argmax over the last dimension of ``proj_out(hidden[:, -1:])``; the
        sliced time axis is kept, so dim 1 of the result has length 1.
    """
    decoder_outputs = model(
        cur_token,
        encoder_hidden_states=out_encoder,
        past_key_values=past_key_values,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=True,
        return_dict=False,
    )
    # With return_dict=False the decoder returns a tuple; element 0 holds the hidden states.
    hidden_states = decoder_outputs[0]
    # Only the final time step is needed to pick the next token.
    last_step = hidden_states[:, -1:]
    scores = proj_out(last_step)
    return scores.argmax(dim=-1)

In [9]:
# Rebind the single-step decoder to its compiled form; the generation loops below call it
# through this name.
# NOTE(review): re-running this cell wraps the already-compiled callable a second time.
decode_one_tokens = torch.compile(
    decode_one_tokens,
    fullgraph=True,
    mode="reduce-overhead",
)

In [10]:
# Download the two test clips and decode them to sample arrays via `audio`.
# A timeout keeps a stalled connection from hanging the notebook, and
# `raise_for_status()` fails fast on an HTTP error instead of handing an
# error page to the audio decoder.
r = requests.get(
    'https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3',
    timeout=60,
)
r.raise_for_status()
y = audio.decode_example(audio.encode_example(r.content))['array']

r = requests.get(
    'https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav',
    timeout=60,
)
r.raise_for_status()
y2 = audio.decode_example(audio.encode_example(r.content))['array']

In [11]:
# Extract features for the first clip.
# `sampling_rate` is passed explicitly: the clip was decoded at `sr`, and omitting it
# makes the processor warn that a rate mismatch would go unnoticed.
inputs = processor([y], sampling_rate=sr, return_tensors='pt').to(device)
# Cast to the dtype the models were loaded in (`compute_dtype`) rather than repeating the
# literal, so the two cannot drift apart.
inputs['input_features'] = inputs['input_features'].to(compute_dtype)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [12]:
# Extract features for the second clip.
# `sampling_rate` is passed explicitly: the clip was decoded at `sr`, and omitting it
# makes the processor warn that a rate mismatch would go unnoticed.
inputs2 = processor([y2], sampling_rate=sr, return_tensors='pt').to(device)
# Cast to the dtype the models were loaded in (`compute_dtype`) rather than repeating the
# literal, so the two cannot drift apart.
inputs2['input_features'] = inputs2['input_features'].to(compute_dtype)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [13]:
# Warm up the compiled encoder before any timing cell runs.
#
# Later cells call it in two autograd modes: the stand-alone encoder timing cell runs with
# grad enabled, while the transcription cell runs under `torch.no_grad()`.
# NOTE(review): torch.compile guards on the autograd mode, so a graph warmed up with grad
# enabled is not reused under no_grad. Warming up both modes here keeps the first no-grad
# call from paying compilation cost inside a `%%time` cell.
for grad_enabled in (True, False):
    with torch.set_grad_enabled(grad_enabled):
        for _ in range(3):
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):
                out_encoder = model.model.encoder(inputs['input_features'])



In [14]:
%%time

# Timed forward pass through `model`'s encoder, whose `forward` was replaced by a
# torch.compile'd version in an earlier cell. The cell is kept to the bare call so that
# `%%time` measures nothing else.
# SDPA backends: flash and math enabled, memory-efficient disabled.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):
    out_encoder = model.model.encoder(inputs['input_features'])

CPU times: user 38.2 ms, sys: 324 µs, total: 38.5 ms
Wall time: 38 ms


In [15]:
%%time

# Same timed forward pass through `model_normal`'s encoder (not wrapped in torch.compile
# in the cells shown here), as the eager reference.
# NOTE(review): this overwrites `out_encoder`, so the decoder warm-up cell that follows
# consumes the eager model's encoder output rather than the compiled model's.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):
    out_encoder = model_normal.model.encoder(inputs['input_features'])

CPU times: user 44.1 ms, sys: 4.01 ms, total: 48.1 ms
Wall time: 47.4 ms


In [16]:
# Warm up the compiled single-step decoder (`decode_one_tokens`) by running one complete
# greedy transcription, so that the timed cells further down do not include compilation.
# NOTE(review): `out_encoder` is whatever the previous cell left behind.

with torch.no_grad():
    language = 'en'
    initial_strings = [
        '<|startoftranscript|>',
        f'<|{language}|>',
        '<|transcribe|>'
    ]

    # Prefill: push the whole prompt through the decoder once, without a pre-existing cache.
    labels = processor.tokenizer(
        ''.join(initial_strings),
        add_special_tokens = False,
        return_tensors = 'pt',
    ).to('cuda')['input_ids']
    out_decoder = model.model.decoder(
        labels,
        encoder_hidden_states=out_encoder[0],
        past_key_values = None,
        use_cache = True
    )
    past_key_values = out_decoder.past_key_values
    # Greedy choice of the first generated token, taken from the last prompt position.
    proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)
    # Keep an independent copy of the encoder states for the decoding loop.
    out_encoder = out_encoder[0].clone()

    # Static cache seeded from the prefill cache; reused (via `reset`) by a later cell.
    cache = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)
    # Used below as the first cache position / position id to be written.
    seq_length = past_key_values[0][0].shape[2]
    cache_position = torch.tensor([seq_length], device=device)
    position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)

    # Positions start at `seq_length` and advance by one per step, so this is the number of
    # steps that still fit. Counting prompt *positions* (not prompt strings) keeps the
    # budget correct even if a prompt string tokenises to more than one token.
    max_new_tokens = model.config.max_target_positions - seq_length

    for _ in range(max_new_tokens):
        proj = decode_one_tokens(
            model.model.decoder,
            model.proj_out,
            # NOTE(review): presumably cloned so the compiled call does not read its own
            # previous output buffer — confirm.
            proj.clone(),
            cache,
            position_ids,
            cache_position,
            out_encoder
        )
        labels = torch.concat([labels, proj], axis = -1)
        # Advance in place so the same tensors are passed to the compiled step each time.
        position_ids += 1
        cache_position += 1

        if proj == model.config.eos_token_id:
            break



In [33]:
%%time

with torch.no_grad():

    # Encoder pass on `model` (compiled encoder forward) with flash + math SDPA backends.
    sdpa_flags = dict(enable_flash=True, enable_mem_efficient=False, enable_math=True)
    with torch.backends.cuda.sdp_kernel(**sdpa_flags):
        out_encoder = model.model.encoder(inputs['input_features'])

    # Decoder prompt: start-of-transcript, language and task tokens.
    language = 'en'
    initial_strings = ['<|startoftranscript|>', f'<|{language}|>', '<|transcribe|>']

    prompt_text = ''.join(initial_strings)
    tokenized_prompt = processor.tokenizer(
        prompt_text,
        add_special_tokens=False,
        return_tensors='pt',
    )
    labels = tokenized_prompt.to('cuda')['input_ids']

    # Prefill: run the whole prompt through the decoder once, without a pre-existing cache.
    out_decoder = model.model.decoder(
        labels,
        encoder_hidden_states=out_encoder[0],
        past_key_values=None,
        use_cache=True,
    )
    past_key_values = out_decoder.past_key_values

    # Greedy choice of the first generated token, taken from the last prompt position.
    last_hidden = out_decoder.last_hidden_state[:, -1:]
    proj = model.proj_out(last_hidden).argmax(-1)

    # From here on `out_encoder` is an independent copy of the encoder hidden states.
    out_encoder = out_encoder[0].clone()

CPU times: user 50.5 ms, sys: 56 µs, total: 50.5 ms
Wall time: 49.7 ms


In [34]:
# Re-seed the static cache created in the warm-up cell from the fresh prefill cache.
cache.reset(existing_cache=past_key_values)

# First position the decoding loop will write to, read from axis 2 of the first cached tensor.
seq_length = past_key_values[0][0].size(2)
position_ids = seq_length + torch.arange(proj.shape[1], device=device)
cache_position = torch.tensor([seq_length], device=device)

In [35]:
%%time

# Timed greedy decoding with the compiled single-step decoder and the static cache.
with torch.no_grad():
    # Positions start at `seq_length` (set in the previous cell) and advance by one per
    # step, so this is the number of steps that still fit. Counting prompt *positions*
    # (not prompt strings) keeps the budget correct even if a prompt string tokenises to
    # more than one token.
    max_new_tokens = model.config.max_target_positions - seq_length

    for _ in tqdm(range(max_new_tokens)):
        proj = decode_one_tokens(
            model.model.decoder,
            model.proj_out,
            # NOTE(review): presumably cloned so the compiled call does not read its own
            # previous output buffer — confirm.
            proj.clone(),
            cache,
            position_ids,
            cache_position,
            out_encoder
        )
        labels = torch.concat([labels, proj], axis = -1)
        # Advance in place so the same tensors are passed to the compiled step each time.
        position_ids += 1
        cache_position += 1

        if proj == model.config.eos_token_id:
            break

 16%|████████████▉                                                                  | 73/445 [00:00<00:01, 186.26it/s]

CPU times: user 396 ms, sys: 3.8 ms, total: 400 ms
Wall time: 398 ms





In [36]:
# Turn the generated id sequence (prompt tokens included) back into text for display.
token_ids = labels[0]
processor.tokenizer.decode(token_ids)

'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'

In [41]:
%%time

# Timed encoder pass + prompt prefill for the eager baseline.
# Every model call in this cell goes through `model_normal`, so the baseline timing is not
# a mix of the baseline encoder and the other model's decoder/projection.
with torch.no_grad():

    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):
        out_encoder = model_normal.model.encoder(inputs['input_features'])

    language = 'en'
    initial_strings = [
        '<|startoftranscript|>',
        f'<|{language}|>',
        '<|transcribe|>'
    ]

    labels = processor.tokenizer(
        ''.join(initial_strings),
        add_special_tokens = False,
        return_tensors = 'pt',
    ).to('cuda')['input_ids']
    # Prefill: run the whole prompt through the baseline decoder once, without a
    # pre-existing cache.
    out_decoder = model_normal.model.decoder(
        labels,
        encoder_hidden_states=out_encoder[0],
        past_key_values = None,
        use_cache = True
    )
    past_key_values = out_decoder.past_key_values
    # Greedy choice of the first generated token, taken from the last prompt position.
    proj = model_normal.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)
    # From here on `out_encoder` is an independent copy of the encoder hidden states.
    out_encoder = out_encoder[0].clone()

CPU times: user 54.5 ms, sys: 97 µs, total: 54.6 ms
Wall time: 54.3 ms


In [42]:
# Static cache for the eager baseline, seeded from the prefill cache of the previous cell.
# Built from `model_normal.config` so the baseline does not depend on the other model object.
cache_normal = WhisperStaticCache(model_normal.config, compute_dtype, device, past_key_values)
# First position the decoding loop will write to, read from axis 2 of the first cached tensor.
seq_length = past_key_values[0][0].shape[2]
cache_position = torch.tensor([seq_length], device=device)
position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)

In [43]:
%%time

# Timed greedy decoding for the eager baseline: same static-cache loop as the compiled
# run, but calling `model_normal`'s decoder directly instead of `decode_one_tokens`.
with torch.no_grad():
    # Positions start at `seq_length` (set in the previous cell) and advance by one per
    # step, so this is the number of steps that still fit. Counting prompt *positions*
    # (not prompt strings) keeps the budget correct even if a prompt string tokenises to
    # more than one token.
    max_new_tokens = model_normal.config.max_target_positions - seq_length

    for _ in tqdm(range(max_new_tokens)):
        out_decoder = model_normal.model.decoder(
            proj,
            encoder_hidden_states=out_encoder,
            past_key_values = cache_normal,
            position_ids=position_ids,
            use_cache = True,
            return_dict = False,
            cache_position = cache_position
        )
        # Greedy choice from the last time step of the hidden states.
        proj = torch.argmax(model_normal.proj_out(out_decoder[0][:,-1:]), dim=-1)
        labels = torch.concat([labels, proj], axis = -1)
        position_ids += 1
        cache_position += 1

        # Stop token is read from the baseline model's own config.
        if proj == model_normal.config.eos_token_id:
            break

 16%|████████████▉                                                                  | 73/445 [00:00<00:02, 150.20it/s]

CPU times: user 491 ms, sys: 0 ns, total: 491 ms
Wall time: 490 ms





In [44]:
# Turn the baseline's generated id sequence (prompt tokens included) back into text for display.
token_ids = labels[0]
processor.tokenizer.decode(token_ids)

'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'