Skip to content

Instantly share code, notes, and snippets.

@sukhitashvili
Last active September 5, 2023 11:59
Show Gist options
  • Save sukhitashvili/745b7528f0418172e84d4ddb23f666b9 to your computer and use it in GitHub Desktop.
Save sukhitashvili/745b7528f0418172e84d4ddb23f666b9 to your computer and use it in GitHub Desktop.
Non-maximum suppression (NMS)
import numpy as np
import torch
def get_iou(pred1: torch.Tensor, pred2: torch.Tensor):
    """
    Return the intersection-over-union (IoU) of two predicted boxes.

    Each prediction is laid out as [x1, y1, x2, y2, conf, cls]. Only the first
    four entries (the box corners) are read; conf and cls are ignored.

    The areas are taken with abs(), but the overlap is computed from the raw
    corners. Corners are therefore expected to satisfy x1 <= x2 and y1 <= y2,
    and unordered corners simply yield zero overlap.

    Returns:
        A tensor holding overlap / union. This is a 0-d tensor when the inputs
        are 1-D (assumed, not checked). When both boxes have zero area the union
        is 0, and the result is 0/0 (NaN).
    """
    def _area(box):
        return abs((box[2] - box[0]) * (box[3] - box[1]))

    # Corners of the overlapping rectangle (may be "inverted" if disjoint).
    left, top = torch.max(pred1[0], pred2[0]), torch.max(pred1[1], pred2[1])
    right, bottom = torch.min(pred1[2], pred2[2]), torch.min(pred1[3], pred2[3])

    # A negative extent means no overlap along that axis, so floor it at zero.
    overlap = (right - left).clamp(0) * (bottom - top).clamp(0)
    union = _area(pred1) + _area(pred2) - overlap
    return overlap / union
def non_max_supression(predictions: np.ndarray, conf_thr: float, iou_thr: float) -> np.ndarray:
    """
    Apply class-aware greedy non-maximum suppression (NMS) to predictions.

    The algorithm runs in three steps:
      1. Predictions whose confidence is below ``conf_thr`` are discarded.
      2. The survivors are visited in order of decreasing confidence.
      3. Each visited prediction is kept, and it removes every remaining
         prediction of the SAME class whose IoU with it is >= ``iou_thr``.

    Predictions of different classes never suppress each other.

    Args:
        predictions: rows of [x1, y1, x2, y2, conf, cls] (an array or a
            sequence of rows).
        conf_thr: minimum confidence (inclusive) for a prediction to be
            considered at all.
        iou_thr: same-class predictions whose IoU with a kept prediction is
            >= this value are suppressed.

    Returns:
        np.ndarray of the kept rows, ordered by decreasing confidence. If no
        prediction reaches ``conf_thr``, an empty array is returned (zero rows,
        same column count as the input when the input is 2-D).
    """
    candidates = [pred for pred in predictions if pred[4] >= conf_thr]
    # Stable sort: equal-confidence predictions keep their input order.
    candidates.sort(key=lambda pred: pred[4], reverse=True)
    if not candidates:
        # Nothing passed the confidence filter, so return zero rows rather
        # than the unfiltered input.
        return np.asarray(predictions)[:0]

    kept = []
    while candidates:
        best = candidates.pop(0)
        kept.append(best)
        best_tensor = torch.tensor(best)  # hoisted: reused for every comparison below
        candidates = [
            pred for pred in candidates
            # Index 5 is the class id; other classes are never suppressed.
            # The class check comes first so IoU is only computed when needed.
            if pred[5] != best[5]
            or get_iou(torch.tensor(pred), best_tensor).item() < iou_thr
        ]
    return np.array(kept)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment