import gradio
import gradio_image_annotation
import gradio_imageslider
import spaces
import torch

import src.SegmentAnything2Assist as SegmentAnything2Assist

example_image_annotation = {
    "image": "assets/cars.jpg",
    "boxes": [
        {'label': '+', 'color': (0, 255, 0), 'xmin': 886, 'ymin': 551, 'xmax': 886, 'ymax': 551},
        {'label': '-', 'color': (255, 0, 0), 'xmin': 1239, 'ymin': 576, 'xmax': 1239, 'ymax': 576},
        {'label': '-', 'color': (255, 0, 0), 'xmin': 610, 'ymin': 574, 'xmax': 610, 'ymax': 574},
        {'label': '', 'color': (0, 0, 255), 'xmin': 254, 'ymin': 466, 'xmax': 1347, 'ymax': 1047}
    ]
}

VERBOSE = True

segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = "sam2_hiera_tiny", device = torch.device("cuda"))

__image_point_coords = None
__image_point_labels = None
__image_box = None
__current_mask = None
__current_segment = None
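# The module-level variables above hold state shared between the Gradio callbacks: the prompts
# parsed from the annotator (point coordinates, point labels, optional bounding box) and the
# most recently generated mask/segment overlays. The prompts are rebuilt whenever the
# annotation changes, which also clears the cached overlays.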
def __change_base_model(model_name, device):
    # Swap out the global SAM 2 wrapper when the model or device dropdown changes.
    global segment_anything2assist
    try:
        segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = model_name, device = torch.device(device))
        gradio.Info(f"Model changed to {model_name} on {device}", duration = 5)
    except Exception:
        # gradio.Error is an exception and must be raised to surface in the UI.
        raise gradio.Error("Model could not be changed", duration = 5)
def __post_process_annotator_inputs(value):
    """Parse the annotator payload into SAM 2 prompts: point coordinates, point labels and an optional box."""
    global __image_point_coords, __image_point_labels, __image_box
    global __current_mask, __current_segment

    if VERBOSE:
        print("SegmentAnything2AssistApp::__post_process_annotator_inputs::Called.")
    __current_mask, __current_segment = None, None
    new_boxes = []
    __image_point_coords = []
    __image_point_labels = []
    __image_box = []

    b_has_box = False
    for box in value["boxes"]:
        if box['label'] == '':
            # Only the first unlabeled box is kept and used as the bounding-box prompt.
            if not b_has_box:
                new_box = box.copy()
                new_box['color'] = (0, 0, 255)
                new_boxes.append(new_box)
                b_has_box = True
                __image_box = [
                    box['xmin'],
                    box['ymin'],
                    box['xmax'],
                    box['ymax']
                ]
        elif box['label'] == '+' or box['label'] == '-':
            # '+' and '-' boxes are collapsed to their centre and become positive/negative point prompts.
            new_box = box.copy()
            new_box['color'] = (0, 255, 0) if box['label'] == '+' else (255, 0, 0)
            new_box['xmin'] = int((box['xmin'] + box['xmax']) / 2)
            new_box['ymin'] = int((box['ymin'] + box['ymax']) / 2)
            new_box['xmax'] = new_box['xmin']
            new_box['ymax'] = new_box['ymin']
            new_boxes.append(new_box)
            __image_point_coords.append([new_box['xmin'], new_box['ymin']])
            __image_point_labels.append(1 if box['label'] == '+' else 0)

    if len(__image_box) == 0:
        __image_box = None
    if len(__image_point_coords) == 0:
        __image_point_coords = None
    if len(__image_point_labels) == 0:
        __image_point_labels = None

    if VERBOSE:
        print("SegmentAnything2AssistApp::__post_process_annotator_inputs::Done.")
def __generate_mask(value, mask_threshold, max_hole_area, max_sprinkle_area, image_output_mode):
    global __current_mask, __current_segment
    global __image_point_coords, __image_point_labels, __image_box
    global segment_anything2assist

    # Force post processing of annotated image
    __post_process_annotator_inputs(value)

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Called.")
    mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
        value["image"],
        __image_point_coords,
        __image_point_labels,
        __image_box,
        mask_threshold,
        max_hole_area,
        max_sprinkle_area
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")

    __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(value["image"], mask_chw[0])

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")

    if image_output_mode == "Mask":
        return [value["image"], __current_mask]
    elif image_output_mode == "Segment":
        return [value["image"], __current_segment]
    else:
        gradio.Warning("This is an issue, please report the problem!", duration=5)
        return gradio_imageslider.ImageSlider(render = True)
def __change_output_mode(image_input, radio):
    global __current_mask, __current_segment
    global __image_point_coords, __image_point_labels, __image_box

    if VERBOSE:
        print("SegmentAnything2AssistApp::__change_output_mode::Called.")
    if __current_mask is None or __current_segment is None:
        gradio.Warning("Configuration was changed, generate the mask again", duration=5)
        return gradio_imageslider.ImageSlider(render = True)
    if radio == "Mask":
        return [image_input["image"], __current_mask]
    elif radio == "Segment":
        return [image_input["image"], __current_segment]
    else:
        gradio.Warning("This is an issue, please report the problem!", duration=5)
        return gradio_imageslider.ImageSlider(render = True)
def __generate_multi_mask_output(image, auto_list, auto_mode, auto_bbox_mode):
    global segment_anything2assist

    image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(image, [int(i) - 1 for i in auto_list])
    output_1 = image_with_bbox if auto_bbox_mode else image
    output_2 = mask if auto_mode == "Mask" else segment
    return [output_1, output_2]
def __generate_auto_mask(
    image,
    points_per_side,
    points_per_batch,
    pred_iou_thresh,
    stability_score_thresh,
    stability_score_offset,
    mask_threshold,
    box_nms_thresh,
    crop_n_layers,
    crop_nms_thresh,
    crop_overlay_ratio,
    crop_n_points_downscale_factor,
    min_mask_region_area,
    use_m2m,
    multimask_output,
    output_mode
):
    global segment_anything2assist

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")

    __auto_masks = segment_anything2assist.generate_automatic_masks(
        image,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
        stability_score_offset,
        mask_threshold,
        box_nms_thresh,
        crop_n_layers,
        crop_nms_thresh,
        crop_overlay_ratio,
        crop_n_points_downscale_factor,
        min_mask_region_area,
        use_m2m,
        multimask_output
    )

    if len(__auto_masks) == 0:
        gradio.Warning("No masks generated, please tweak the advanced parameters.", duration = 5)
        return gradio_imageslider.ImageSlider(), \
            gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False), \
            gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
    else:
        choices = [str(i) for i in range(len(__auto_masks))]
        returning_image = __generate_multi_mask_output(image, ["0"], output_mode, False)
        return returning_image, \
            gradio.CheckboxGroup(choices, value = ["0"], label = "Mask List", interactive = True), \
            gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = True)
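# __generate_auto_mask returns three values, matching the click event's outputs below:
# the (input image, overlay) pair for the ImageSlider, an updated Mask List CheckboxGroup,
# and the Show Bounding Box checkbox (both re-enabled only when masks were found).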
with gradio.Blocks() as base_app:
    gradio.Markdown("# SegmentAnything2Assist")
    with gradio.Row():
        with gradio.Column():
            base_model_choice = gradio.Dropdown(
                ['sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_base_plus', 'sam2_hiera_tiny'],
                value = 'sam2_hiera_tiny',
                label = "Model Choice"
            )
        with gradio.Column():
            base_gpu_choice = gradio.Dropdown(
                ['cpu', 'cuda'],
                value = 'cuda',
                label = "Device Choice"
            )
    base_model_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
    base_gpu_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
| with gradio.Tab(label = "Image Segmentation", id = "image_tab") as image_tab: | |
| gradio.Markdown("Image Segmentation", render = True) | |
| with gradio.Column(): | |
| with gradio.Accordion("Image Annotation Documentation", open = False): | |
| gradio.Markdown(""" | |
| Image annotation allows you to mark specific regions of an image with labels. | |
| In this app, you can annotate an image by drawing boxes and assigning labels to them. | |
| The labels can be either '+' or '-'. | |
| To annotate an image, simply click and drag to draw a box around the desired region. | |
| You can add multiple boxes with different labels. | |
| Once you have annotated the image, click the 'Generate Mask' button to generate a mask based on the annotations. | |
| The mask can be either a binary mask or a segmented mask, depending on the selected output mode. | |
| You can switch between the output modes using the radio buttons. | |
| If you make any changes to the annotations or the output mode, you need to regenerate the mask by clicking the button again. | |
| Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area. | |
| These options control the sensitivity and accuracy of the segmentation process. | |
| Experiment with different settings to achieve the desired results. | |
| """) | |
| image_input = gradio_image_annotation.image_annotator(example_image_annotation) | |
| with gradio.Accordion("Advanced Options", open = False): | |
| image_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "SAM Mask Threshold") | |
| image_generate_SAM_max_hole_area = gradio.Slider(0, 1000, 0, label = "SAM Max Hole Area") | |
| image_generate_SAM_max_sprinkle_area = gradio.Slider(0, 1000, 0, label = "SAM Max Sprinkle Area") | |
| image_generate_mask_button = gradio.Button("Generate Mask") | |
| image_output = gradio_imageslider.ImageSlider() | |
| image_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode") | |
| image_input.change(__post_process_annotator_inputs, inputs = [image_input]) | |
| image_generate_mask_button.click(__generate_mask, inputs = [ | |
| image_input, | |
| image_generate_SAM_mask_threshold, | |
| image_generate_SAM_max_hole_area, | |
| image_generate_SAM_max_sprinkle_area, | |
| image_output_mode | |
| ], | |
| outputs = [image_output]) | |
| image_output_mode.change(__change_output_mode, inputs = [image_input, image_output_mode], outputs = [image_output]) | |
| with gradio.Tab(label = "Auto Segmentation", id = "auto_tab"): | |
| gradio.Markdown("Auto Segmentation", render = True) | |
| with gradio.Column(): | |
| with gradio.Accordion("Auto Annotation Documentation", open = False): | |
| gradio.Markdown(""" | |
| """) | |
            auto_input = gradio.Image("assets/cars.jpg")

            with gradio.Accordion("Advanced Options", open = False):
                auto_generate_SAM_points_per_side = gradio.Slider(1, 64, 32, 1, label = "Points Per Side", interactive = True)
                auto_generate_SAM_points_per_batch = gradio.Slider(1, 64, 32, 1, label = "Points Per Batch", interactive = True)
                auto_generate_SAM_pred_iou_thresh = gradio.Slider(0.0, 1.0, 0.8, label = "Pred IOU Threshold", interactive = True)
                auto_generate_SAM_stability_score_thresh = gradio.Slider(0.0, 1.0, 0.95, label = "Stability Score Threshold", interactive = True)
                auto_generate_SAM_stability_score_offset = gradio.Slider(0.0, 1.0, 1.0, label = "Stability Score Offset", interactive = True)
                auto_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "Mask Threshold", interactive = True)
                auto_generate_SAM_box_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Box NMS Threshold", interactive = True)
                auto_generate_SAM_crop_n_layers = gradio.Slider(0, 10, 0, 1, label = "Crop N Layers", interactive = True)
                auto_generate_SAM_crop_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Crop NMS Threshold", interactive = True)
                auto_generate_SAM_crop_overlay_ratio = gradio.Slider(0.0, 1.0, 512 / 1500, label = "Crop Overlay Ratio", interactive = True)
                auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(1, 10, 1, label = "Crop N Points Downscale Factor", interactive = True)
                auto_generate_SAM_min_mask_region_area = gradio.Slider(0, 1000, 0, label = "Min Mask Region Area", interactive = True)
                auto_generate_SAM_use_m2m = gradio.Checkbox(label = "Use M2M", interactive = True)
                auto_generate_SAM_multimask_output = gradio.Checkbox(value = True, label = "Multi Mask Output", interactive = True)
            auto_generate_button = gradio.Button("Generate Auto Mask")

            with gradio.Row():
                with gradio.Column():
                    auto_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode", interactive = True)
                    auto_output_list = gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False)
                    auto_output_bbox = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
                with gradio.Column(scale = 3):
                    auto_output = gradio_imageslider.ImageSlider()
        auto_generate_button.click(
            __generate_auto_mask,
            inputs = [
                auto_input,
                auto_generate_SAM_points_per_side,
                auto_generate_SAM_points_per_batch,
                auto_generate_SAM_pred_iou_thresh,
                auto_generate_SAM_stability_score_thresh,
                auto_generate_SAM_stability_score_offset,
                auto_generate_SAM_mask_threshold,
                auto_generate_SAM_box_nms_thresh,
                auto_generate_SAM_crop_n_layers,
                auto_generate_SAM_crop_nms_thresh,
                auto_generate_SAM_crop_overlay_ratio,
                auto_generate_SAM_crop_n_points_downscale_factor,
                auto_generate_SAM_min_mask_region_area,
                auto_generate_SAM_use_m2m,
                auto_generate_SAM_multimask_output,
                auto_output_mode
            ],
            outputs = [
                auto_output,
                auto_output_list,
                auto_output_bbox
            ]
        )
        auto_output_list.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
        auto_output_bbox.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
        auto_output_mode.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])

if __name__ == "__main__":
    base_app.launch()
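# Note: a minimal way to run the demo locally, assuming this file is saved as the Space's app.py
# and the src.SegmentAnything2Assist package plus the SAM 2 checkpoints are available:
#   python app.py
# On a machine without a GPU, change the default device above from "cuda" to "cpu".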