Skip to content

Instantly share code, notes, and snippets.

@lucapericlp
Created May 10, 2024 13:23
Show Gist options
  • Save lucapericlp/fedc6b1b7bff49ec3eb830560a054939 to your computer and use it in GitHub Desktop.
Save lucapericlp/fedc6b1b7bff49ec3eb830560a054939 to your computer and use it in GitHub Desktop.
Quickly calculate the PESQ between two audio files.
# coding: utf-8
import click
import torchaudio as ta
import torch
import librosa
from pesq import pesq
# Scale factor applied to float waveforms before they are cast to 16-bit
# integers for the PESQ computation.
MAX_WAV_VALUE = 32768.0
@click.command()
@click.argument('gt_path', type=click.Path(exists=True))
@click.argument('pred_path', type=click.Path(exists=True))
@click.option('--src_rate', default=24000, help='Source sampling rate of the audio files.')
@click.option('--dest_rate', default=16000, help='Destination sampling rate to which the audio files will be resampled.')
def compare_files(gt_path, pred_path, src_rate, dest_rate):
    """
    Compute and print the wide-band PESQ score between two audio files.

    GT_PATH is the reference (ground truth) recording and PRED_PATH is the
    degraded/predicted recording. Both files are loaded at the source rate,
    resampled to the destination rate, converted to 16-bit integer samples
    and passed to PESQ in wide-band ("wb") mode.
    """
    # Use the GPU when one is present, but do not require it: the resampling
    # runs just as well (only slower) on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up the resampler
    pesq_resampler = ta.transforms.Resample(src_rate, dest_rate).to(device)

    def load_as_int16(path):
        """Load *path*, resample it and return int16 samples as a numpy array."""
        samples, _ = librosa.load(path, sr=src_rate)
        resampled = pesq_resampler(torch.tensor(samples, device=device))
        scaled = resampled * MAX_WAV_VALUE
        # Full-scale samples (+1.0) and resampler overshoot land outside the
        # int16 range; clamp them so the cast below cannot overflow.
        clipped = torch.clamp(scaled, -MAX_WAV_VALUE, MAX_WAV_VALUE - 1)
        return clipped.short().cpu().numpy()

    # Load, resample and convert the ground truth and prediction files
    gt_int_dest = load_as_int16(gt_path)
    pred_int_dest = load_as_int16(pred_path)

    # Compute PESQ
    the_pesq = pesq(dest_rate, gt_int_dest, pred_int_dest, "wb")
    print(f"PESQ Score: {the_pesq}")
if __name__ == '__main__':
    # compare_files is a click command, so it is called without arguments
    # here; its parameters come from the command line.
    compare_files()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment