Last active
September 5, 2023 11:59
-
-
Save sukhitashvili/745b7528f0418172e84d4ddb23f666b9 to your computer and use it in GitHub Desktop.
Non-maximum suppression (NMS)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch
def get_iou(pred1: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
    """
    Compute the intersection-over-union (IoU) of two single boxes.

    Args:
        pred1: tensor laid out as [x1, y1, x2, y2, conf, cls]; only the first
            four entries (box corners) are used.
        pred2: same layout as ``pred1``.

    Returns:
        0-d tensor holding the IoU. When the union area is zero (both boxes
        degenerate) the result is 0 rather than NaN.

    Note:
        Areas use ``abs`` but the overlap assumes x1 <= x2 and y1 <= y2;
        boxes with swapped corners get no overlap credit.
    """
    area1 = abs((pred1[2] - pred1[0]) * (pred1[3] - pred1[1]))
    area2 = abs((pred2[2] - pred2[0]) * (pred2[3] - pred2[1]))

    # Corners of the overlap rectangle.
    ix1 = torch.max(pred1[0], pred2[0])
    iy1 = torch.max(pred1[1], pred2[1])
    ix2 = torch.min(pred1[2], pred2[2])
    iy2 = torch.min(pred1[3], pred2[3])

    # A negative width/height means the boxes do not intersect: clamp to zero.
    inter = (ix2 - ix1).clamp(0) * (iy2 - iy1).clamp(0)
    union = area1 + area2 - inter

    iou = inter / union
    # 0/0 for two zero-area boxes would be NaN, and NaN makes every threshold
    # comparison False; report "no overlap" instead.
    return torch.where(union > 0, iou, torch.zeros_like(iou))
def non_max_supression(predictions: np.ndarray, conf_thr: float, iou_thr: float) -> np.ndarray:
    """
    Run greedy, class-aware non-maximum suppression (NMS).

    Predictions with confidence below ``conf_thr`` are dropped. The rest are
    visited in descending confidence order: each visited prediction is kept
    and suppresses every remaining prediction of the *same class* whose IoU
    with it is not below ``iou_thr``. Predictions of different classes never
    suppress each other.

    Args:
        predictions: array-like of rows [x1, y1, x2, y2, conf, cls],
            i.e. shape (N, 6).
        conf_thr: minimum confidence (inclusive) for a prediction to be kept.
        iou_thr: a same-class prediction whose IoU with an already-kept one
            is >= ``iou_thr`` is suppressed.

    Returns:
        np.ndarray of the kept rows, ordered by descending confidence. When
        nothing passes ``conf_thr`` the result has zero rows (the unfiltered
        input is never returned).

    Raises:
        ValueError: if ``predictions`` is non-empty and not 2-D.
    """
    preds = np.asarray(predictions)
    if preds.size == 0:
        return preds.reshape(0, preds.shape[1] if preds.ndim == 2 else 6)
    if preds.ndim != 2:
        raise ValueError(f"predictions must be 2-D (N, 6), got shape {preds.shape}")

    candidates = preds[preds[:, 4] >= conf_thr]
    # Stable sort keeps input order among equal confidences,
    # like sorted(..., reverse=True) did.
    candidates = candidates[np.argsort(-candidates[:, 4], kind="stable")]
    # Convert once instead of building two tensors for every pair.
    boxes = torch.as_tensor(candidates)

    remaining = list(range(len(candidates)))
    keep = []
    while remaining:
        best = remaining.pop(0)
        keep.append(best)
        # Survivors: predictions of another class (column 5), or of the same
        # class that overlap the kept box by less than the threshold.
        remaining = [
            j
            for j in remaining
            if candidates[j, 5] != candidates[best, 5]
            or get_iou(boxes[j], boxes[best]).item() < iou_thr
        ]
    return candidates[np.asarray(keep, dtype=np.intp)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment