|
- import glob
- import json
- import os
- import configargparse
- import h5py
- import imageio
- import numpy as np
- import torch
- from tqdm import tqdm
- from PIL import Image
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
-
-
# Run SAM on the GPU when one is available; everything below inherits this.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
def plot_anns(anns, palette):
    """Paint each annotation mask with its corresponding palette colour.

    Args:
        anns: list of dicts, each with a boolean ``'segmentation'`` array of
            shape (H, W). Later annotations overwrite earlier ones where
            masks overlap.
        palette: sequence of RGB triplets, one per annotation (extra
            annotations beyond ``len(palette)`` are left unpainted, matching
            the original behaviour).

    Returns:
        (H, W, 3) uint8 image, or ``None`` if ``anns`` is empty.
    """
    if not anns:
        return None

    h, w = anns[0]['segmentation'].shape[:2]
    img = np.zeros((h, w, 3), dtype=np.uint8)
    # zip stops at the shorter sequence, so a short palette no longer risks
    # an IndexError and a long one is simply truncated.
    for ann, color in zip(anns, palette):
        img[ann['segmentation']] = color
    return img
-
-
@torch.no_grad()
def segment(args):
    """Segment a Replica scene with SAM and save colour-coded instance maps.

    For every selected frame, SAM's automatic mask generator is run on the
    RGB image; each predicted mask is matched to the ground-truth instance
    with the largest pixel intersection and painted with that instance's
    colour. Two PNGs are written per frame: the GT instance map (``*_gt``)
    and the SAM-derived one (``*_ins``).

    Args:
        args: parsed CLI namespace with ``dir_path``, ``scene_id``, ``split``,
            ``model_type``, ``sam_checkpoint`` and ``save_dir_path``.
    """
    scene_dir_path = os.path.join(args.dir_path, args.scene_id)

    # Frames live at rgb/rgb_<idx>.png; every 5th frame is used, with the
    # test split offset by half an interval so train/test interleave.
    image_dir_path = os.path.join(scene_dir_path, 'rgb')
    img_total_num = len(glob.glob(os.path.join(image_dir_path, 'rgb_*.png')))
    img_eval_interval = 5
    if args.split == "train":
        indices = list(range(0, img_total_num, img_eval_interval))
    elif args.split == "test":
        indices = list(range(img_eval_interval // 2, img_total_num, img_eval_interval))
    else:
        # argparse restricts --split, but fail loudly instead of hitting a
        # NameError on `indices` below when called programmatically.
        raise ValueError(f"unknown split: {args.split!r}")

    mask_gt_dir_path = os.path.join(scene_dir_path, 'semantic_instance')
    # Instance-id -> label index mapping; the file handle was previously
    # opened without ever being closed.
    ins2label_path = os.path.join(args.dir_path, 'color_dict.json')
    with open(ins2label_path, 'r') as f:
        ins2label = json.load(f)['replica'][args.scene_id]
    color_f = os.path.join(scene_dir_path, 'ins_rgb.hdf5')
    with h5py.File(color_f, 'r') as f:
        label_rgbs = f['datasets'][:]  # ndarray colour table, one RGB row per label

    scene_id_simple = args.scene_id.replace('_', '')
    suffix = 'vith_grid32_min0_intersection'
    out_dir = f'{args.save_dir_path}/{scene_id_simple}_{suffix}'
    os.makedirs(out_dir, exist_ok=True)

    # Build SAM once, outside the frame loop — the original reloaded the
    # checkpoint and re-created the generator for every single image.
    sam = sam_model_registry[args.model_type](checkpoint=args.sam_checkpoint).to(device)
    # NOTE(review): points_per_side=16 while the output suffix says
    # "grid32" — confirm which sampling grid is actually intended.
    mask_generator = SamAutomaticMaskGenerator(model=sam,
                                               points_per_side=16,
                                               min_mask_region_area=0)

    for i in tqdm(range(len(indices)), desc='Segmenting image'):
        idx = indices[i]
        image_path = os.path.join(image_dir_path, f'rgb_{idx}.png')
        image = np.array(Image.open(image_path))

        masks = mask_generator.generate(image)

        # Build per-instance GT masks and the GT visualisation in one pass.
        mask_gt_path = os.path.join(mask_gt_dir_path, f'semantic_instance_{idx}.png')
        mask_gt = np.array(Image.open(mask_gt_path))
        mask_gt_vis = np.zeros_like(image, dtype=np.uint8)
        anns_gt = []
        palette = []
        for mask_label in np.unique(mask_gt):
            seg = mask_gt == mask_label  # (H, W) bool
            anns_gt.append({'label': mask_label,
                            'segmentation': seg,
                            'area': int(np.count_nonzero(seg))})
            mask_gt_vis[seg] = label_rgbs[ins2label[str(mask_label)]]

        # Match every SAM mask to the GT instance with the largest pixel
        # overlap and reuse that instance's colour.
        for mask in masks:
            mask_seg = mask['segmentation']
            intersections = [np.count_nonzero(mask_seg & ann_gt['segmentation'])
                             for ann_gt in anns_gt]
            label_idx = int(np.argmax(intersections))
            palette.append(label_rgbs[ins2label[str(anns_gt[label_idx]['label'])]])

        ins_map = plot_anns(masks, np.array(palette))

        imageio.imwrite(f'{out_dir}/{i:03d}_gt.png', mask_gt_vis)
        imageio.imwrite(f'{out_dir}/{i:03d}_ins.png', ins_map)
-
-
if __name__ == '__main__':
    # CLI entry point: parse options and run the segmentation pipeline.
    arg_parser = configargparse.ArgumentParser()
    arg_parser.add_argument('--sam_checkpoint', type=str,
                            choices=['sam_vit_h_4b8939.pth', 'sam_vit_b_01ec64.pth', 'sam_vit_l_0b3195.pth'],
                            default='sam_vit_h_4b8939.pth')
    arg_parser.add_argument('--model_type', type=str, choices=['vit_h', 'vit_b', 'vit_l'],
                            default='vit_h')
    arg_parser.add_argument('--dir_path', type=str)
    arg_parser.add_argument('--scene_id', type=str,
                            choices=['office_0', 'office_2', 'office_3', 'office_4',
                                     'room_0', 'room_1', 'room_2',
                                     'all'])
    arg_parser.add_argument('--split', type=str, choices=['train', 'test'])
    arg_parser.add_argument('--save_dir_path', type=str)

    segment(arg_parser.parse_args())
|