HACARUS Tech Blog: Segment Anything Tutorial

What is Segment Anything?

The Segment Anything Model (SAM) is a high-performance segmentation model for extracting objects of interest from an image. It lets you specify the objects of interest through simple annotations, such as click points or bounding boxes, and returns versatile, high-quality segmentation results. The flow of segmentation using SAM is shown in the figure below.

SAM Architecture

The SAM architecture converts images into tensors using an image encoder, converts point and bounding box prompts into tensors using a prompt encoder, and feeds these into a mask decoder to generate the desired segmentation mask.
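
Conceptually, the pipeline is three stages: encode the image, encode the prompts, and decode a mask from both. The toy sketch below uses dummy NumPy stand-ins to illustrate the data flow only; the function names and tensor shapes are illustrative assumptions, not SAM's actual modules:

```python
import numpy as np

def image_encoder(image):
    # Toy stand-in: SAM's real image encoder is a ViT that produces a
    # dense image embedding (shape here is illustrative).
    return np.random.rand(256, 64, 64)

def prompt_encoder(points=None, boxes=None):
    # Toy stand-in: embeds point/box prompts into sparse prompt tokens.
    n = (0 if points is None else len(points)) + (0 if boxes is None else len(boxes))
    return np.random.rand(n, 256)

def mask_decoder(image_embedding, prompt_embedding, h, w):
    # Toy stand-in: combines the two embeddings into a binary mask.
    return np.zeros((h, w), dtype=bool)

image = np.zeros((600, 800, 3), dtype=np.uint8)   # dummy H x W x 3 image
img_emb = image_encoder(image)
prm_emb = prompt_encoder(points=[(530, 150)])      # one click point
mask = mask_decoder(img_emb, prm_emb, *image.shape[:2])
print(mask.shape)  # (600, 800)
```

The key property this sketch captures is that the (expensive) image embedding is computed once, after which many cheap prompt encodings can be decoded against it.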

If you are interested in the details, please check out the official GitHub or paper.

The following steps outline how to get started with SAM (each topic is covered in more detail later in this blog):

  1. Environment Setup: Configure your Python environment and install the required libraries, such as PyTorch and opencv-python, as well as SAM specific dependencies.
  2. Segmentation by Point prompt: Learn how to perform SAM segmentation by specifying points on an image.
  3. Segmentation by Bounding Box prompt: Learn how to specify a bounding box for an image and perform SAM segmentation.
  4. Bounding Box Standardization: As an application example of SAM, we introduce an idea for improving annotation quality: a method to automatically standardize the variation in bounding box size among annotators.

How to setup SAM

  1. Prepare an environment for Python 3.8 or later;
  2. Install the following libraries;
    1. pip install opencv-python matplotlib
    2. pip install git+https://github.com/facebookresearch/segment-anything.git
  3. Run inference in a notebook
    1. Download a checkpoint for the trained model. Select the appropriate one from SAM checkpoints.
    2. Load the model using the downloaded checkpoint. This typically requires information defining the model structure and the path to the checkpoint file.
    3. Encode the image, which involves some pre-processing, to convert it into a format the model can understand.
    4. Enter annotation information, such as points or boxes, to specify regions of interest (ROIs) in the image.
    5. Overlay the segmentation mask generated by the model on the image to visually see how it recognizes certain regions.

Define Drawing Functions

Useful functions for drawing are defined as follows:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax, edgecolor='green'):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))  

Image Preparation

Prepare sample images:

image = cv2.imread('images/sushi.jpg')
# resize the image to the expected size if needed
image = cv2.resize(image, (800, 600))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(8,8))
plt.imshow(image)
plt.axis('on')
plt.show()

Preparation for Inference

Load the model to infer the above image.

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "../sam_vit_h_4b8939.pth"
model_type = "vit_h"

# device = "cuda"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

predictor.set_image(image) 

Prompt 1: Specify one point

To segment a particular object in an image (such as a piece of sushi), you can provide the (x, y) coordinates of a point in the image as a prompt. Each point is labeled with a 1 or 0 to indicate whether it is foreground (part of the object of interest, in this case the salmon) or background (not part of the object).

# Set salmon coordinates as prompt
input_point = np.array([[530, 150]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

To run inference on the above prompt, set the multimask_output parameter to True (the default), and SAM will generate three masks as output. These masks represent different interpretations of the prompt; if the prompt is ambiguous, the masks can differ substantially. You will also get a score representing the model’s confidence in each mask.

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
) 
# masks shape: (3, 600, 800): (number_of_masks) x H x W

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

Prompt 2: Specify multiple points

SAM also allows you to give multiple points as prompts at the same time. As an example, the following shows the inference result when you specify one point on each piece of sushi.

points = np.array([
	[530, 150], 
	[370, 180], 
	[250, 180], 
	[300, 300], 
	[520, 290], 
	[600, 250]
])
labels = np.array([1, 1, 1, 1, 1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, _, _ = predictor.predict(
    point_coords=points,
    point_labels=labels,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(8,8))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(points, labels, plt.gca())
plt.show()

Prompt 3: Specify background points with labels

When specifying a point, you also pass a label indicating whether the point is foreground (1) or background (0). In the next two code sections, we first specify only the salmon as foreground, and then additionally mark the rice as background, to extract only the salmon.

# Only a single point, no background specified
input_point = np.array([[530, 150]])
input_label = np.array([1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.show() 

# When specifying the rice part as the background
input_point = np.array([[530, 150], [500, 200]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.show() 

Prompt 4: Specify a bounding box

SAM also accepts bounding box coordinates in [x0, y0, x1, y1] format as a prompt:

input_box = np.array([370, 70, 630, 260]) # [x0, y0, x1, y1]

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.show()

Prompt 5: Specify a bounding box and points simultaneously

You can give a prompt that combines bounding boxes and points, as shown below. Here, we specify a bounding box together with a background point.

input_box = np.array([390, 80, 600, 250]) # [x0, y0, x1, y1]
input_point = np.array([[500, 200]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
# plt.axis('off')
plt.show()

SAM application example: Unifying bounding box sizes

When creating annotation data, variation between annotators can be an issue. In object detection tasks, for example, the size of the bounding box drawn for the same object varies from annotator to annotator. Below, we show how SAM can be used to improve annotation quality by correcting boxes to be accurate and uniform: treating the annotator's bounding box as the prompt (as in Prompt 4), we compute a new, tighter bounding box from the mask that SAM returns.

# find the box of the mask
mask = masks[0]
y, x = np.where(mask)
y0, y1 = y.min(), y.max()
x0, x1 = x.min(), x.max()
shrink_box = np.array([x0, y0, x1, y1])

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_box(shrink_box, plt.gca(), edgecolor='red')
# plt.axis('off')
plt.show()

The green bounding box is the annotator's input; the red one is the tighter box derived from SAM's mask.
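
The mask-to-box computation above generalizes naturally to a reusable helper. A minimal sketch (the function name tighten_box is ours, not part of SAM's API), demonstrated on a synthetic mask so it runs without a model:

```python
import numpy as np

def tighten_box(mask):
    """Return the tight [x0, y0, x1, y1] box around a binary mask."""
    ys, xs = np.where(mask)
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])

# Synthetic example: a 10x10 mask with a filled region in rows 2-5, cols 3-5
mask = np.zeros((10, 10), dtype=bool)
mask[2:6, 3:6] = True
print(tighten_box(mask))  # [3 2 5 5]
```

In an annotation pipeline, you would call this on each mask SAM returns for an annotator-drawn box, replacing the original box with the tightened one.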

Conclusion

Using Segment Anything, we were able to easily obtain high-quality segmentation results from prompts based on points or bounding boxes, and to refine those results further. SAM is already being applied in a variety of areas, from medical imaging to annotation tools such as CVAT, and is expected to continue to develop in the future.
