From 4f3d4d50d6b69c1fe0728682d03f93e43960835f Mon Sep 17 00:00:00 2001 From: TrungDinhT Date: Fri, 20 Feb 2026 18:49:49 +0100 Subject: [PATCH] Fix box non max suppression and add test --- tests/test_utils/test_bounding_box_utils.py | 106 ++++++++++++++++++++ yolo/utils/bounding_box_utils.py | 2 +- 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_bounding_box_utils.py b/tests/test_utils/test_bounding_box_utils.py index f5d2f0e63..e7ed9a7c0 100644 --- a/tests/test_utils/test_bounding_box_utils.py +++ b/tests/test_utils/test_bounding_box_utils.py @@ -220,6 +220,112 @@ def test_bbox_nms(): assert allclose(out, exp, atol=1e-4), f"Output: {out} Expected: {exp}" +def test_bbox_nms_float16_precision(): + """ + Test that bbox_nms correctly handles float16 inputs with large coordinates. + + The bug: batched_nms internally shifts boxes by (label * max_coord) to separate groups. + With float16, large coordinates (~3500) combined with large idxs (computed by + `batch_idx + valid_cls * bbox.size(0)` in bbox_nms) cause precision loss, + making overlapping boxes appear non-overlapping → NMS fails to suppress duplicates. + + This test ensures such extreme cases don't break the bbox_nms implementation. 
+ """ + + # Large coordinates simulating real-world high-res image detections (~3500 range) + # Two clusters of heavily overlapping boxes per image, per class + # These SHOULD be suppressed to 1 box per cluster per class + cls_dist = torch.tensor( + [ + [ + # anchor 0-3: cluster A (x≈3428), high conf class 2 + [-10, -10, 2.0], + [-10, -10, 0.8], + [-10, -10, 0.5], + [-10, -10, 0.2], + # anchor 4-7: cluster B (x≈2056), high conf class 2 + [-10, -10, 0.8], + [-10, -10, 1.6], + [-10, -10, 0.3], + [-10, -10, 0.2], + ] + ] * 8, + dtype=torch.float16, + ).to("cuda") + + bbox = torch.tensor( + [ + [ + # Cluster A: tightly overlapping boxes around x≈3428, y≈85-300 + # IoU between any pair >> 0.5, should suppress to 1 + [3428.0, 85.0625, 3500.0, 298.7500], + [3428.0, 85.9375, 3500.0, 295.5000], + [3428.0, 93.0625, 3500.0, 294.0000], + [3428.0, 92.1875, 3500.0, 293.0000], + # Cluster B: tightly overlapping boxes around x≈2056, y≈756-918 + # IoU between any pair >> 0.5, should suppress to 1 + # IoU between cluster A and B = 0.0 (non-overlapping) → both kept + [2056.0, 757.0000, 2392.0, 917.5000], + [2054.0, 756.5000, 2392.0, 918.0000], + [2058.0, 756.0000, 2392.0, 916.5000], + [2054.0, 756.0000, 2392.0, 915.5000], + ] + ] * 8, + dtype=torch.float16, + ).to("cuda") + + nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5, max_bbox=400) + + # Expected: for each image, class 2 has 2 objects (cluster A and cluster B) + # → exactly 2 boxes per image should survive, both class 2 + # The highest scoring box from each cluster is kept (sigmoid of 2.0 and 1.6) + # + # If float16 bug is present: NMS fails to suppress within clusters + # → more boxes survive per image instead of 2 + output = bbox_nms(cls_dist, bbox, nms_cfg) + + for batch_i, result in enumerate(output): + num_kept = result.shape[0] + assert num_kept == 2, ( + f"Image {batch_i}: expected 2 boxes (1 per cluster), got {num_kept}. 
" + f"Float16 precision bug likely causing NMS to fail suppression.\n" + f"Kept boxes:\n{result}" + ) + + kept_classes = result[:, 0] + assert (kept_classes == 2).all(), ( + f"Image {batch_i}: expected all kept boxes to be class 2, got {kept_classes}" + ) + + kept_boxes = result[:, 1:5] + + # One box should be from cluster A (x1≈3426-3428) + # One box should be from cluster B (x1≈2054-2058) + cluster_a_mask = kept_boxes[:, 0] > 3000 + cluster_b_mask = kept_boxes[:, 0] < 3000 + assert cluster_a_mask.sum() == 1, ( + f"Image {batch_i}: expected exactly 1 box from cluster A (x≈3428), " + f"got {cluster_a_mask.sum()}" + ) + assert cluster_b_mask.sum() == 1, ( + f"Image {batch_i}: expected exactly 1 box from cluster B (x≈2056), " + f"got {cluster_b_mask.sum()}" + ) + + # Highest scoring box from each cluster should be kept (sigmoid of 2.0 and 1.6) + kept_scores = result[:, 5] + expected_score_a = torch.tensor(2.0, dtype=torch.float16).sigmoid().item() + expected_score_b = torch.tensor(1.6, dtype=torch.float16).sigmoid().item() + assert abs(kept_scores[cluster_a_mask].item() - expected_score_a) < 1e-2, ( + f"Image {batch_i}: cluster A score mismatch. " + f"Got {kept_scores[cluster_a_mask].item():.4f}, expected {expected_score_a:.4f}" + ) + assert abs(kept_scores[cluster_b_mask].item() - expected_score_b) < 1e-2, ( + f"Image {batch_i}: cluster B score mismatch. 
" + f"Got {kept_scores[cluster_b_mask].item():.4f}, expected {expected_score_b:.4f}" + ) + + def test_calculate_map(): predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]]) # [class, x1, y1, x2, y2] ground_truths = tensor([[0, 50, 50, 150, 150], [0, 30, 30, 100, 100]]) # [class, x1, y1, x2, y2] diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 0357bfdce..252b9f55e 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -464,7 +464,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt valid_con = cls_dist[batch_idx, valid_grid, valid_cls] valid_box = bbox[batch_idx, valid_grid] - nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou) + nms_idx = batched_nms(valid_box.float(), valid_con.float(), batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou) predicts_nms = [] for idx in range(cls_dist.size(0)): instance_idx = nms_idx[idx == batch_idx[nms_idx]]