Skip to content

Instantly share code, notes, and snippets.

@chamecall
Created March 25, 2025 13:19
Show Gist options
  • Save chamecall/cee5dc2ad31a49a75e658fde50388e62 to your computer and use it in GitHub Desktop.
Save chamecall/cee5dc2ad31a49a75e658fde50388e62 to your computer and use it in GitHub Desktop.

Revisions

  1. chamecall created this gist Mar 25, 2025.
    84 changes: 84 additions & 0 deletions TopBiasedRandomCrop.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,84 @@
    import numpy as np
    import cv2
    from albumentations.core.transforms_interface import DualTransform

    class TopBiasedRandomCrop(DualTransform):
        """Randomly crop an image, biasing the vertical position with a Beta draw.

        The crop width is sampled uniformly between ``min_width`` and 1.0 of the
        image width (likewise ``min_height`` for the height) and truncated to
        whole pixels, so the realised size can fall just below the requested
        minimum fraction. The horizontal offset is uniform over the valid range;
        the vertical offset is ``Beta(beta_alpha, beta_beta)`` scaled to the
        valid range, so the default parameters (1.0, 2.0) favour small offsets,
        i.e. crops near the top of the image.

        Bounding boxes are shifted by the crop origin and clipped to the crop
        extent. They are handled as ``[x_min, y_min, x_max, y_max, *extras]``.

        Args:
            min_width (float): Minimum relative width of the crop (0 < min_width <= 1).
            min_height (float): Minimum relative height of the crop (0 < min_height <= 1).
            beta_alpha (float): Alpha parameter of the Beta distribution (must be > 0).
            beta_beta (float): Beta parameter of the Beta distribution (must be > 0).
            p (float): Probability of applying the transform.

        Raises:
            ValueError: If ``min_width`` or ``min_height`` is outside (0, 1], or
                if ``beta_alpha`` or ``beta_beta`` is not strictly positive.
        """

        def __init__(self, min_width=0.7, min_height=0.5, beta_alpha=1.0, beta_beta=2.0, p=1.0):
            # Pass p by keyword: some albumentations releases declare
            # ``always_apply`` before ``p`` in the base constructor, so a
            # positional argument could bind to the wrong parameter.
            super().__init__(p=p)
            if not (0 < min_width <= 1):
                raise ValueError("min_width must be in the interval (0, 1].")
            if not (0 < min_height <= 1):
                raise ValueError("min_height must be in the interval (0, 1].")
            # np.random.beta rejects non-positive shape parameters; fail at
            # construction time instead of on the first sample.
            if not beta_alpha > 0:
                raise ValueError("beta_alpha must be greater than 0.")
            if not beta_beta > 0:
                raise ValueError("beta_beta must be greater than 0.")
            self.min_width = min_width
            self.min_height = min_height
            self.beta_alpha = beta_alpha
            self.beta_beta = beta_beta

        def get_params_dependent_on_data(self, params, data):
            """Sample crop parameters from ``data["image"]``.

            Thin adapter that forwards the image to
            ``get_params_dependent_on_targets``; ``params`` is not used.
            """
            return self.get_params_dependent_on_targets({"image": data["image"]})

        def get_params_dependent_on_targets(self, params) -> dict:
            """Sample the crop rectangle for the image in ``params["image"]``.

            Returns:
                dict: ``"crop_params"`` as ``[x1, y1, x2, y2]`` in pixels, plus
                ``"rows"`` / ``"cols"`` set to the crop height / width.
            """
            img = params["image"]
            height, width = img.shape[:2]

            # Sample the crop size. Truncation to int could yield 0 on very
            # small images, so keep at least one pixel whenever the image has
            # any extent, and never exceed the image itself.
            crop_width = int(np.random.uniform(self.min_width, 1.0) * width)
            crop_height = int(np.random.uniform(self.min_height, 1.0) * height)
            crop_width = min(max(crop_width, 1), width)
            crop_height = min(max(crop_height, 1), height)

            # Largest top-left offsets that keep the crop inside the image.
            max_x_offset = width - crop_width
            max_y_offset = height - crop_height

            x1 = np.random.randint(0, max_x_offset + 1) if max_x_offset > 0 else 0
            # The Beta sample is drawn unconditionally so the random stream is
            # consumed the same way whether or not there is vertical slack.
            y_sample = np.random.beta(self.beta_alpha, self.beta_beta)
            y1 = int(y_sample * max_y_offset) if max_y_offset > 0 else 0

            crop_params = [x1, y1, x1 + crop_width, y1 + crop_height]
            # rows/cols carry the cropped dimensions; the original author's
            # intent is that bbox filtering then uses the cropped size.
            return {"crop_params": crop_params, "rows": crop_height, "cols": crop_width}

        def apply(self, img, **params):
            """Return the ``crop_params`` window of ``img`` (or ``img`` if absent)."""
            crop_params = params.get("crop_params")
            if crop_params is None:
                return img
            x1, y1, x2, y2 = crop_params
            return img[y1:y2, x1:x2]

        def apply_to_bbox(self, bbox, **params):
            """Shift one bbox by the crop origin and clip it to the crop extent.

            Elements after the first four are appended unchanged. Returns
            ``bbox`` itself when ``crop_params`` is missing.
            """
            # NOTE(review): coordinates are treated as being in the same pixel
            # units as crop_params. Confirm against the albumentations version
            # in use, which may hand normalized boxes to this hook.
            crop_params = params.get("crop_params")
            if crop_params is None:
                return bbox
            x1, y1, x2, y2 = crop_params
            crop_w = x2 - x1
            crop_h = y2 - y1
            new_bbox = [
                np.clip(bbox[0] - x1, 0, crop_w),
                np.clip(bbox[1] - y1, 0, crop_h),
                np.clip(bbox[2] - x1, 0, crop_w),
                np.clip(bbox[3] - y1, 0, crop_h),
            ]
            if len(bbox) > 4:
                new_bbox.extend(bbox[4:])
            return new_bbox

        def apply_to_bboxes(self, bboxes, **params):
            """Apply ``apply_to_bbox`` to every box and return a float32 array.

            An empty input yields a 2-D array with zero rows (keeping the
            column count of a 2-D ndarray input, otherwise 4 columns) so that
            downstream code indexing columns does not fail on a 1-D result.
            """
            transformed = [self.apply_to_bbox(bbox, **params) for bbox in bboxes]
            if not transformed:
                shape = getattr(bboxes, "shape", ())
                n_cols = shape[1] if len(shape) == 2 else 4
                return np.zeros((0, n_cols), dtype=np.float32)
            # Convert to a NumPy array so that later processing (e.g. filtering) works.
            return np.array(transformed, dtype=np.float32)

        def get_transform_init_args_names(self):
            """Return the constructor argument names used to describe this transform."""
            return ("min_width", "min_height", "beta_alpha", "beta_beta")