Last active
May 29, 2024 07:32
-
-
Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.
Revisions
-
LukeAI revised this gist
Jun 9, 2023 . 1 changed file with 44 additions and 15 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,21 +1,39 @@ #!/usr/bin/env python from __future__ import annotations import os from pathlib import Path from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor import cv2 import numpy as np import torch from tqdm import tqdm # config in_dir = 'my_images' out_dir = 'segmented' sam_model = "vit_l" sam_check = "sam_vit_l_0b3195.pth" #sam_model = "vit_h" #sam_check = "sam_vit_h_4b8939.pth" #sam_model = "vit_b" #sam_check = "sam_vit_b_01ec64.pth" device="cuda" transparency = 0.3 max_masks = 300 # sam generator params points_per_batch=64 points_per_side=64 pred_iou_thresh=0.86 stability_score_thresh=0.92 crop_n_layers=1 crop_n_points_downscale_factor=2 min_mask_region_area=100 # list of random colors colors = [] for i in range(max_masks): colors.append(np.random.random((3))) def draw_segmentation(anns): @@ -24,17 +42,18 @@ def draw_segmentation(anns): h, w = anns[0]['segmentation'].shape image = np.zeros((h, w, 3), dtype=np.float64) sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) no_masks = min(len(sorted_anns), max_masks) for i in range(no_masks): # true/false segmentation seg = sorted_anns[i]['segmentation'] # set this segmentation a random color image[seg] = colors[i] return image def process_image(img_path, out_path, mask_generator): image = cv2.imread(img_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # mask generator wants the default uint8 image @@ -45,7 +64,7 @@ def process_image(img_path, out_path, mask_generator): seg = draw_segmentation(masks) # add segmentation image on top of original image image += transparency * seg # convert back to uint8 for display/save image = (255 * image).astype(np.uint8) @@ -62,12 +81,22 @@ def process_image(img_path, out_path, mask_generator): # load SAM model + create mask generator sam = sam_model_registry[sam_model](checkpoint=sam_check) sam.to(device=device) sam = torch.compile(sam) mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=points_per_side, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh, crop_n_layers=crop_n_layers, crop_n_points_downscale_factor=crop_n_points_downscale_factor, min_mask_region_area=min_mask_region_area) # process input directory for img in tqdm(os.listdir(in_dir)): # change extension of output image to .png out_img = Path(img).stem + ".png" out_img = os.path.join(out_dir, out_img) # if we can read/decode this file as an image in_img = os.path.join(in_dir, img) if cv2.haveImageReader(in_img): process_image(in_img, out_img, mask_generator) -
LukeAI created this gist
Jun 9, 2023 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,73 @@ #!/usr/bin/env python from __future__ import annotations import os from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor import cv2 import numpy as np from tqdm import tqdm # config in_dir = 'my_images' out_dir = 'segmented' sam_model = "vit_l" sam_check = "sam_vit_l_0b3195.pth" # sam_model = "vit_h" # sam_check = "sam_vit_h_4b8939.pth" # sam_model = "vit_b" # sam_check = "sam_vit_b_01ec64.pth" device="cuda" def draw_segmentation(anns): if len(anns) == 0: return h, w = anns[0]['segmentation'].shape image = np.zeros((h, w, 3), dtype=np.float64) sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) for ann in sorted_anns: # true/false segmentation m = ann['segmentation'] # random color color_mask = np.random.random((3)) image[m] = color_mask return image def process_image(img_path, out_path, mask_generator): image = cv2.imread('cam_front_top_centre/1686044607746820638.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # mask generator wants the default uint8 image masks = mask_generator.generate(image) # convert to float64 image = image.astype(np.float64) / 255 seg = draw_segmentation(masks) # add segmentation image on top of original image image += 0.3 * seg # convert back to uint8 for display/save image = (255 * image).astype(np.uint8) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # cv2.imshow("my img", image) # cv2.waitKey(-1) cv2.imwrite(out_path, image) if __name__ == "__main__": # make sure output dir exists if not os.path.exists(out_dir): os.makedirs(out_dirs) # load SAM model + create mask generator sam = sam_model_registry[sam_model](checkpoint=sam_check) sam.to(device=device) mask_generator = SamAutomaticMaskGenerator(sam) # process input directory for img in tqdm(os.listdir(in_dir)): in_img = os.path.join(in_dir, img) out_img = os.path.join(out_dir, img) # if we can read/decode this file as an image if cv2.haveImageReader(in_img): process_image(in_img, out_img, mask_generator)