#!/usr/bin/env python from __future__ import annotations import os from pathlib import Path from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor import cv2 import numpy as np import torch from tqdm import tqdm # config in_dir = 'my_images' out_dir = 'segmented' sam_model = "vit_l" sam_check = "sam_vit_l_0b3195.pth" #sam_model = "vit_h" #sam_check = "sam_vit_h_4b8939.pth" #sam_model = "vit_b" #sam_check = "sam_vit_b_01ec64.pth" device="cuda" transparency = 0.3 max_masks = 300 # sam generator params points_per_batch=64 points_per_side=64 pred_iou_thresh=0.86 stability_score_thresh=0.92 crop_n_layers=1 crop_n_points_downscale_factor=2 min_mask_region_area=100 # list of random colors colors = [] for i in range(max_masks): colors.append(np.random.random((3))) def draw_segmentation(anns): if len(anns) == 0: return h, w = anns[0]['segmentation'].shape image = np.zeros((h, w, 3), dtype=np.float64) sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) no_masks = min(len(sorted_anns), max_masks) for i in range(no_masks): # true/false segmentation seg = sorted_anns[i]['segmentation'] # set this segmentation a random color image[seg] = colors[i] return image def process_image(img_path, out_path, mask_generator): image = cv2.imread(img_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # mask generator wants the default uint8 image masks = mask_generator.generate(image) # convert to float64 image = image.astype(np.float64) / 255 seg = draw_segmentation(masks) # add segmentation image on top of original image image += transparency * seg # convert back to uint8 for display/save image = (255 * image).astype(np.uint8) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # cv2.imshow("my img", image) # cv2.waitKey(-1) cv2.imwrite(out_path, image) if __name__ == "__main__": # make sure output dir exists if not os.path.exists(out_dir): os.makedirs(out_dirs) # load SAM model + create mask generator sam = sam_model_registry[sam_model](checkpoint=sam_check) sam.to(device=device) sam = torch.compile(sam) mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=points_per_side, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh, crop_n_layers=crop_n_layers, crop_n_points_downscale_factor=crop_n_points_downscale_factor, min_mask_region_area=min_mask_region_area) # process input directory for img in tqdm(os.listdir(in_dir)): # change extension of output image to .png out_img = Path(img).stem + ".png" out_img = os.path.join(out_dir, out_img) # if we can read/decode this file as an image in_img = os.path.join(in_dir, img) if cv2.haveImageReader(in_img): process_image(in_img, out_img, mask_generator)