Skip to content

Instantly share code, notes, and snippets.

@LukeAI
Last active May 29, 2024 07:32
Show Gist options
  • Select an option

  • Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.

Select an option

Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.

Revisions

  1. LukeAI revised this gist Jun 9, 2023. 1 changed file with 44 additions and 15 deletions.
    59 changes: 44 additions & 15 deletions batch_sam.py
    Original file line number Diff line number Diff line change
    @@ -1,21 +1,39 @@
    #!/usr/bin/env python
    from __future__ import annotations
    import os
    from pathlib import Path
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
    import cv2
    import numpy as np
    import torch
    from tqdm import tqdm

    # config
    in_dir = 'my_images'
    out_dir = 'segmented'
    sam_model = "vit_l"
    sam_check = "sam_vit_l_0b3195.pth"
    # sam_model = "vit_h"
    # sam_check = "sam_vit_h_4b8939.pth"
    # sam_model = "vit_b"
    # sam_check = "sam_vit_b_01ec64.pth"
    #sam_model = "vit_h"
    #sam_check = "sam_vit_h_4b8939.pth"
    #sam_model = "vit_b"
    #sam_check = "sam_vit_b_01ec64.pth"
    device="cuda"
    transparency = 0.3
    max_masks = 300

    # sam generator params
    points_per_batch=64
    points_per_side=64
    pred_iou_thresh=0.86
    stability_score_thresh=0.92
    crop_n_layers=1
    crop_n_points_downscale_factor=2
    min_mask_region_area=100

    # list of random colors
    colors = []
    for i in range(max_masks):
    colors.append(np.random.random((3)))


    def draw_segmentation(anns):
    @@ -24,17 +42,18 @@ def draw_segmentation(anns):
    h, w = anns[0]['segmentation'].shape
    image = np.zeros((h, w, 3), dtype=np.float64)
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    for ann in sorted_anns:
    no_masks = min(len(sorted_anns), max_masks)
    for i in range(no_masks):
    # true/false segmentation
    m = ann['segmentation']
    # random color
    color_mask = np.random.random((3))
    image[m] = color_mask
    seg = sorted_anns[i]['segmentation']

    # set this segmentation a random color
    image[seg] = colors[i]
    return image


    def process_image(img_path, out_path, mask_generator):
    image = cv2.imread('cam_front_top_centre/1686044607746820638.jpg')
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # mask generator wants the default uint8 image
    @@ -45,7 +64,7 @@ def process_image(img_path, out_path, mask_generator):
    seg = draw_segmentation(masks)

    # add segmentation image on top of original image
    image += 0.3 * seg
    image += transparency * seg

    # convert back to uint8 for display/save
    image = (255 * image).astype(np.uint8)
    @@ -62,12 +81,22 @@ def process_image(img_path, out_path, mask_generator):
    # load SAM model + create mask generator
    sam = sam_model_registry[sam_model](checkpoint=sam_check)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    sam = torch.compile(sam)
    mask_generator = SamAutomaticMaskGenerator(sam,
    points_per_side=points_per_side,
    pred_iou_thresh=pred_iou_thresh,
    stability_score_thresh=stability_score_thresh,
    crop_n_layers=crop_n_layers,
    crop_n_points_downscale_factor=crop_n_points_downscale_factor,
    min_mask_region_area=min_mask_region_area)
    # process input directory
    for img in tqdm(os.listdir(in_dir)):
    in_img = os.path.join(in_dir, img)
    out_img = os.path.join(out_dir, img)

    # change extension of output image to .png
    out_img = Path(img).stem + ".png"
    out_img = os.path.join(out_dir, out_img)

    # if we can read/decode this file as an image
    in_img = os.path.join(in_dir, img)
    if cv2.haveImageReader(in_img):
    process_image(in_img, out_img, mask_generator)
  2. LukeAI created this gist Jun 9, 2023.
    73 changes: 73 additions & 0 deletions batch_sam.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,73 @@
    #!/usr/bin/env python
    from __future__ import annotations
    import os
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
    import cv2
    import numpy as np
    from tqdm import tqdm

    # config
    in_dir = 'my_images'
    out_dir = 'segmented'
    sam_model = "vit_l"
    sam_check = "sam_vit_l_0b3195.pth"
    # sam_model = "vit_h"
    # sam_check = "sam_vit_h_4b8939.pth"
    # sam_model = "vit_b"
    # sam_check = "sam_vit_b_01ec64.pth"
    device="cuda"


    def draw_segmentation(anns):
    if len(anns) == 0:
    return
    h, w = anns[0]['segmentation'].shape
    image = np.zeros((h, w, 3), dtype=np.float64)
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    for ann in sorted_anns:
    # true/false segmentation
    m = ann['segmentation']
    # random color
    color_mask = np.random.random((3))
    image[m] = color_mask
    return image


    def process_image(img_path, out_path, mask_generator):
    image = cv2.imread('cam_front_top_centre/1686044607746820638.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # mask generator wants the default uint8 image
    masks = mask_generator.generate(image)

    # convert to float64
    image = image.astype(np.float64) / 255
    seg = draw_segmentation(masks)

    # add segmentation image on top of original image
    image += 0.3 * seg

    # convert back to uint8 for display/save
    image = (255 * image).astype(np.uint8)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # cv2.imshow("my img", image)
    # cv2.waitKey(-1)
    cv2.imwrite(out_path, image)

    if __name__ == "__main__":
    # make sure output dir exists
    if not os.path.exists(out_dir):
    os.makedirs(out_dirs)

    # load SAM model + create mask generator
    sam = sam_model_registry[sam_model](checkpoint=sam_check)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)

    # process input directory
    for img in tqdm(os.listdir(in_dir)):
    in_img = os.path.join(in_dir, img)
    out_img = os.path.join(out_dir, img)
    # if we can read/decode this file as an image
    if cv2.haveImageReader(in_img):
    process_image(in_img, out_img, mask_generator)