Forked from ketanhdoshi/sound_classification_inference.py
Created May 25, 2022 03:49
-
-
Save wac81/2b67373836c48e87d22b17e7ff74742c to your computer and use it in GitHub Desktop.
Sound Classification Inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ----------------------------
# Inference
# ----------------------------
def inference(model, val_dl):
    """Evaluate ``model`` on ``val_dl`` and report classification accuracy.

    Every batch is moved to ``device`` (a name read from module scope, not a
    parameter), standardized with that batch's own mean and standard
    deviation, and passed through the model. The predicted class is the
    index of the highest score along dimension 1.

    Args:
        model: Callable mapping a batch of inputs to per-class scores. If it
            exposes a truthy ``training`` attribute it is switched to
            evaluation mode for the duration of the call and switched back
            to training mode afterwards.
        val_dl: Iterable of batches in which item 0 holds the inputs and
            item 1 the target labels; both must support ``.to(device)``.

    Returns:
        The fraction of predictions equal to their label, or ``0.0`` when
        ``val_dl`` yields no items. The same figure is printed.
    """
    correct_prediction = 0
    total_prediction = 0

    # Layers such as dropout and batch norm behave differently while
    # training, so score in eval mode and put the caller's mode back after.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        # Disable gradient tracking; nothing is back-propagated here.
        with torch.no_grad():
            for data in val_dl:
                # Get the input features and target labels on the device
                inputs, labels = data[0].to(device), data[1].to(device)

                # Normalize the inputs using this batch's statistics
                inputs_m, inputs_s = inputs.mean(), inputs.std()
                inputs = (inputs - inputs_m) / inputs_s

                # Get predictions
                outputs = model(inputs)

                # Get the predicted class with the highest score
                _, prediction = torch.max(outputs, 1)

                # Count of predictions that matched the target label
                correct_prediction += (prediction == labels).sum().item()
                total_prediction += prediction.shape[0]
    finally:
        if was_training:
            model.train()

    # An empty loader would otherwise divide by zero.
    acc = correct_prediction / total_prediction if total_prediction else 0.0
    print(f'Accuracy: {acc:.2f}, Total items: {total_prediction}')
    return acc
# Score the trained model against the validation data loader
inference(model=myModel, val_dl=val_dl)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment