【Image Segmentation】【Deep Learning】Running the Official SAM PyTorch Code on Windows 10
Note: I have recently started working on image segmentation. This post records the relevant knowledge and shares the problems I ran into while learning, together with their solutions.
Table of contents
- 【Image Segmentation】【Deep Learning】Running the Official SAM PyTorch Code on Windows 10
- Preface
- Setting up the SAM runtime environment
- 1. Open cmd and run the command below to check the CUDA version number
- 2. Install the GPU build of torch: 【[official site](https://pytorch.org/)】
- 3. My environment setup for reference
- Using the SAM code
- predictor_example
- Step 1: view the test image
- Step 2: display foreground and background annotation points
- Step 3: segment the foreground with annotation points
- Step 4: segment the foreground with a bounding box
- Step 5: segment the foreground with a box and annotation points combined
- Step 6: segment the foreground with multiple bounding boxes
- Step 7: segment the foreground on a batch of images
- automatic_mask_generator_example
- Step 1: automatic mask generation
- Step 2: tuning the automatic mask generation parameters
- Summary
Preface
SAM was proposed by Kirillov et al. at Meta AI in the paper "Segment Anything" 【paper link】. It is a promptable segmentation model: given prompts such as annotation points or boxes, it predicts masks for arbitrary objects in an image, which gives it strong generalization to unseen targets.
Before dissecting the SAM network in detail, the first task is to set up the runtime environment that SAM 【PyTorch demo link】 requires and get the model running for inference and testing; only then is the follow-up work meaningful.
Setting up the SAM runtime environment
The code states its requirements here: Python >= 3.8, PyTorch >= 1.7, and TorchVision >= 0.8.
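To confirm that an existing environment already meets these requirements, a quick check such as the following can be used (a minimal sketch that only prints the installed versions):
import sys
import torch
import torchvision

# print interpreter and library versions to compare against the requirements
print("python:", sys.version.split()[0])        # needs >= 3.8
print("torch:", torch.__version__)              # needs >= 1.7
print("torchvision:", torchvision.__version__)  # needs >= 0.8
print("CUDA available:", torch.cuda.is_available())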
1. Open cmd and run the command below to check the CUDA version number
nvidia-smi
2. Install the GPU build of torch: 【official site】
My CUDA version is 12.1, but the highest CUDA build offered here is 11.8; I picked 11.7 and it works fine, since a newer driver can run binaries built against an older CUDA toolkit.
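For reference, an install command matching the versions pinned in my yaml below (torch 2.0.0+cu117 and friends) would look roughly like this; treat the exact version pins as an example rather than a requirement:
pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1+cu117 --index-url https://download.pytorch.org/whl/cu117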
For torch builds against other CUDA versions, look up the matching install command on the 【previous versions】 page; the remaining packages are installed the way the GitHub source's instructions describe:
cd segment-anything; pip install -e .
# these packages are optional, but I installed them all anyway
pip install opencv-python pycocotools matplotlib onnxruntime onnx
3. My environment setup for reference
Set the environment up according to your own machine. Here I provide the packages from my own installation; assuming your CUDA version is 11.7 or above, I think you can create the environment directly from my yaml.
# export the environment to a yaml file with Anaconda (this step is how I exported my own package list; you can skip it)
conda env export --name SAM >environment.yaml
# create the virtual environment from the yaml
conda env create -f environment.yaml
Handling conda downloads that time out and disconnect
# raise the connection timeout to 100 s and the read timeout to 100 s
conda config --set remote_connect_timeout_secs 100
conda config --set remote_read_timeout_secs 100
Contents of the environment.yaml file. Note that "name: SAM" is the custom virtual environment name. If some packages simply will not install this way, install them individually with pip; the file is only meant as a reference.
name: SAM
channels:
  - conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - bzip2=1.0.8=h8ffe710_4
  - ca-certificates=2022.12.7=h5b45459_0
  - libffi=3.4.2=h8ffe710_5
  - libsqlite=3.40.0=hcfcfb64_0
  - libzlib=1.2.13=hcfcfb64_4
  - openssl=3.1.0=hcfcfb64_0
  - pip=23.0.1=pyhd8ed1ab_0
  - python=3.9.16=h4de0772_0_cpython
  - setuptools=67.6.1=pyhd8ed1ab_0
  - tk=8.6.12=h8ffe710_0
  - tzdata=2023c=h71feb2d_0
  - ucrt=10.0.22621.0=h57928b3_0
  - vc=14.3=hb6edc58_10
  - vs2015_runtime=14.34.31931=h4c5c07a_10
  - wheel=0.40.0=pyhd8ed1ab_0
  - xz=5.2.6=h8d14728_0
  - pip:
    - certifi==2022.12.7
    - charset-normalizer==2.1.1
    - coloredlogs==15.0.1
    - contourpy==1.0.7
    - cycler==0.11.0
    - filelock==3.9.0
    - flatbuffers==23.3.3
    - fonttools==4.39.3
    - humanfriendly==10.0
    - idna==3.4
    - importlib-resources==5.12.0
    - jinja2==3.1.2
    - kiwisolver==1.4.4
    - markupsafe==2.1.2
    - matplotlib==3.7.1
    - mpmath==1.3.0
    - networkx==3.0
    - numpy==1.24.2
    - onnx==1.13.1
    - onnxruntime==1.14.1
    - opencv-python==4.7.0.72
    - packaging==23.0
    - pillow==9.5.0
    - protobuf==3.20.3
    - pycocotools==2.0.6
    - pyparsing==3.0.9
    - pyreadline3==3.4.1
    - python-dateutil==2.8.2
    - requests==2.28.1
    - six==1.16.0
    - sympy==1.11.1
    - torch==2.0.0+cu117
    - torchaudio==2.0.1+cu117
    - torchvision==0.15.1+cu117
    - typing-extensions==4.5.0
    - urllib3==1.26.13
    - zipp==3.15.0
prefix: D:\ProgramData\Anaconda\Miniconda3\envs\SAM
The finished environment should contain exactly the packages listed in the yaml file.
# list all installed packages
pip list
conda list
Using the SAM code
Download the GitHub source code and the provided checkpoint (weight) files.
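The three official checkpoints can be fetched directly; the URLs below come from the repository README (any download tool works, and the model_save/ target folder is just my local convention):
# download the three official checkpoints into model_save/
curl -L -o model_save/sam_vit_b_01ec64.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
curl -L -o model_save/sam_vit_l_0b3195.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
curl -L -o model_save/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth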
predictor_example
The source provides a Jupyter Notebook usage tutorial in the notebooks folder. I now use the official tutorial as a template to test my own dataset.
The predictor_example.ipynb source is in the notebooks directory; you can run and test it locally or view the tutorial directly on GitHub.
Step 1: view the test image
import cv2
import matplotlib.pyplot as plt
image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('on')
plt.show()
Step 2: display foreground and background annotation points
import numpy as np
import matplotlib.pyplot as plt
import cv2

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# (x, y) positions picked with the mouse
# several points can be annotated, so there may be multiple coordinates
input_point = np.array([[230, 194], [182, 63], [339, 279]])
# 1 marks a foreground point, 0 marks a background point
# input_point and input_label correspond one to one
input_label = np.array([1, 1, 0])

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
You can open the image in a paint program to read off pixel coordinates as a guide when placing the annotation points.
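Alternatively, a small OpenCV helper can print coordinates as you click on the image (a minimal sketch; the window name and image path are placeholders):
import cv2

def on_click(event, x, y, flags, param):
    # print the pixel position of every left click
    if event == cv2.EVENT_LBUTTONDOWN:
        print(f"clicked at (x={x}, y={y})")

img = cv2.imread('img.png')
cv2.namedWindow('pick points')
cv2.setMouseCallback('pick points', on_click)
cv2.imshow('pick points', img)
cv2.waitKey(0)
cv2.destroyAllWindows()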
Step 3: segment the foreground with annotation points
A simple call into the model from the source code is enough to segment the foreground. The source provides three models of different sizes; readers can try each one and compare the results.
After I have read through the source code, I will explain it in later posts based on my own understanding.
import numpy as np
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:  # choose the mask color at random or not
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# ------ load the model
# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ------ model loaded

# (x, y) positions picked with the mouse
# several points can be annotated, so there may be multiple coordinates
input_point = np.array([[230, 194], [182, 63], [339, 279]])
# 1 marks a foreground point, 0 marks a background point
# input_point and input_label correspond one to one
input_label = np.array([1, 1, 0])

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
Because multimask_output=True, three candidate results are output here.
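If only a single mask is needed, a common pattern is to keep the candidate with the highest predicted score (a small sketch reusing the masks and scores returned above):
# keep the candidate mask with the highest predicted quality score
best_idx = int(np.argmax(scores))
best_mask = masks[best_idx]
print(f"kept mask {best_idx + 1} with score {scores[best_idx]:.3f}")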
Step 4: segment the foreground with a bounding box
The green box is drawn by the user, and the foreground target is segmented from the boxed region.
import numpy as np
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:  # choose the mask color at random or not
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

def show_box(box, ax):
    # draw the box; x0, y0 is the top-left corner
    x0, y0 = box[0], box[1]
    # w, h are the width and height of the box
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# ------ load the model
# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ------ model loaded

# top-left and bottom-right corners of the box
input_box = np.array([112, 41, 373, 320])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
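If the result should be kept rather than only displayed, the boolean mask can be written out as an image; a small sketch (the output filename mask_box.png is arbitrary):
# convert the boolean mask to 0/255 and save it as a PNG
mask_uint8 = masks[0].astype(np.uint8) * 255
cv2.imwrite('mask_box.png', mask_uint8)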
Step 5: segment the foreground with a box and annotation points combined
For some complex targets, the two kinds of prompts may need to be combined to improve the segmentation accuracy of the foreground.
import numpy as np
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:  # choose the mask color at random or not
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

def show_box(box, ax):
    # draw the box; x0, y0 is the top-left corner
    x0, y0 = box[0], box[1]
    # w, h are the width and height of the box
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# ------ load the model
# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ------ model loaded

# top-left and bottom-right corners of the box
input_box = np.array([112, 41, 373, 320])
# (x, y) positions picked with the mouse
# several points can be annotated, so there may be multiple coordinates
input_point = np.array([[230, 194], [182, 63], [339, 279]])
# 1 marks a foreground point, 0 marks a background point
# input_point and input_label correspond one to one
input_label = np.array([1, 1, 0])

# use the box and the annotation points together
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
Step 6: segment the foreground with multiple bounding boxes
Multiple boxes can correspond to multiple targets, or to different parts of the same target.
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:  # choose the mask color at random or not
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

def show_box(box, ax):
    # draw the box; x0, y0 is the top-left corner
    x0, y0 = box[0], box[1]
    # w, h are the width and height of the box
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# ------ load the model
# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)
# ------ model loaded

# several target boxes
input_boxes = torch.tensor([
    [121, 49, 361, 190],
    [143, 101, 308, 312],
    [366, 116, 451, 233],
], device=predictor.device)

# map the boxes into the model's input frame
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
Step 7: segment the foreground on a batch of images
The source supports batched image input, which greatly improves segmentation throughput.
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:  # choose the mask color at random or not
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # select the foreground annotation points
    pos_points = coords[labels == 1]
    # select the background annotation points
    neg_points = coords[labels == 0]
    # x --> pos_points[:, 0]  y --> pos_points[:, 1]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the foreground points
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)  # draw the background points

def show_box(box, ax):
    # draw the box; x0, y0 is the top-left corner
    x0, y0 = box[0], box[1]
    # w, h are the width and height of the box
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def prepare_image(image, transform, device):
    # resize and convert an image to the CHW tensor the model expects
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device.device)
    return image.permute(2, 0, 1).contiguous()

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

image1 = cv2.imread('img.png')
image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
image2 = cv2.imread('img_1.png')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

# ------ load the model
# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
# ------ model loaded

# several target boxes per image
image1_boxes = torch.tensor([
    [121, 49, 361, 190],
    [143, 101, 308, 312],
    [366, 116, 451, 233],
], device=sam.device)
image2_boxes = torch.tensor([
    [24, 4, 333, 265],
], device=sam.device)

# batched input
batched_input = [
    {
        'image': prepare_image(image1, resize_transform, sam),
        'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
        'original_size': image1.shape[:2]
    },
    {
        'image': prepare_image(image2, resize_transform, sam),
        'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
        'original_size': image2.shape[:2]
    }
]
batched_output = sam(batched_input, multimask_output=False)

# batched output
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
    show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
    show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')
ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
    show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
    show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')
plt.tight_layout()
plt.show()
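Each element of batched_output is a dict: besides 'masks' it also carries the predicted IoU scores and low-resolution logits (key names as documented in the source's Sam.forward). A quick way to inspect them:
# inspect what the batched forward pass returned for the first image
out = batched_output[0]
print(out.keys())              # dict_keys(['masks', 'iou_predictions', 'low_res_logits'])
print(out['masks'].shape)      # one boolean mask per input box
print(out['iou_predictions'])  # predicted mask quality per box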
automatic_mask_generator_example
The source also provides a Jupyter Notebook tutorial for automatic segmentation in the notebooks folder, which needs no annotation points or boxes.
The automatic_mask_generator_example.ipynb source is in the notebooks directory; you can run and test it locally or view the tutorial directly on GitHub.
Step 1: automatic mask generation
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]  # a random color for each mask
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
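Each entry in masks is a dict describing one segmented region; the fields below are the ones listed in the SamAutomaticMaskGenerator documentation:
# inspect the record of the largest generated mask
record = max(masks, key=lambda m: m['area'])
print(record.keys())
# expected keys: 'segmentation', 'area', 'bbox' (XYWH), 'predicted_iou',
# 'point_coords', 'stability_score', 'crop_box'
print("area:", record['area'], "bbox:", record['bbox'])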
Step 2: tuning the automatic mask generation parameters
Automatic mask generation exposes several tunable parameters that control how densely points are sampled and the thresholds for discarding low-quality or duplicate masks. Generation can also be run automatically on crops of the image to improve performance on smaller objects, and post-processing can remove stray pixels and holes.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

image = cv2.imread('img.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# path to the checkpoint file
sam_checkpoint = "model_save/sam_vit_b_01ec64.pth"
# sam_checkpoint = "model_save/sam_vit_h_4b8939.pth"
# sam_checkpoint = "model_save/sam_vit_l_0b3195.pth"
# model type
model_type = "vit_b"
# model_type = "vit_h"
# model_type = "vit_l"
device = "cuda"

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]  # a random color for each mask
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# default version
# mask_generator = SamAutomaticMaskGenerator(sam)
# version with custom parameters
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # requires opencv to run post-processing
)

masks = mask_generator_2.generate(image)

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Summary
This post has introduced the SAM installation process and the basic usage of the official SAM code as simply and thoroughly as possible. In later posts I will explain SAM's principles and code, combining what I learn with my own understanding; for now this is only a rough first pass at using it.