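A few days ago I was working on a segmentation project, and once the model was trained I needed to visualize the segmentation results. Here are two ways to do it.
The original image, label, and predicted mask used in the examples are shown below: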
1. Fill the label, outline the predicted mask
The first method overlays the label as a semi-transparent fill and draws the predicted mask as an outline. The outline is produced with three OpenCV functions: cv2.threshold binarizes the mask, cv2.findContours traces the region boundaries, and cv2.drawContours draws them onto the image.
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transforms

def show_mask(mask, ax, random_color=False):
    # Draw a boolean (h, w) mask on `ax` as a semi-transparent RGBA layer
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.5])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    return mask_image

image = cv2.imread('./data/0.png')                    # BGR, shape (h, w, 3)
label = Image.open("./data/0_gt.jpg").convert("L")    # force single-channel grayscale
mask = Image.open("./data/0_pred.jpg").convert("L")
h, w, _ = image.shape

# Resize the ground truth to the image size and binarize it (threshold 127 removes JPEG noise)
label = transforms.Resize((h, w))(label)
label = np.asarray(label, dtype=np.uint8) > 127

# Resize the prediction, binarize it and trace its contours (OpenCV 4.x returns two values)
mask = transforms.Resize((h, w))(mask)
mask = np.asarray(mask, dtype=np.uint8)
ret_mask, thresh_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
contours_mask, hierarchy_mask = cv2.findContours(thresh_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for cnt_mask in contours_mask:
    cv2.drawContours(image, [cnt_mask], 0, (0, 255, 0), 2)   # green outline

plt.figure(dpi=300)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))   # matplotlib expects RGB, OpenCV loads BGR
show_mask(label, plt.gca())
plt.axis('off')
plt.savefig("output_image.jpg")
Result:
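2. Outline the label, outline the predicted mask
The second method draws both the label and the prediction as outlines, each in a different color. Because cv2.drawContours paints directly onto the image it is given, successive calls accumulate on the same image (much like MATLAB's hold on), so two loops are all it takes.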
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms

image = cv2.imread('./data/0.png')                    # BGR, shape (h, w, 3)
label = Image.open("./data/0_gt.jpg").convert("L")    # force single-channel grayscale
mask = Image.open("./data/0_pred.jpg").convert("L")
h, w, _ = image.shape
label = transforms.Resize((h, w))(label)
label = np.asarray(label, dtype=np.uint8)
mask = transforms.Resize((h, w))(mask)
mask = np.asarray(mask, dtype=np.uint8)

# Binarize both maps and extract their contours (OpenCV 4.x returns two values)
ret_label, thresh_label = cv2.threshold(label, 127, 255, cv2.THRESH_BINARY)
contours_label, hierarchy_label = cv2.findContours(thresh_label, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
ret_mask, thresh_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
contours_mask, hierarchy_mask = cv2.findContours(thresh_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

# drawContours paints onto `image` in place, so the second loop draws on top of the first
for cnt_label in contours_label:
    cv2.drawContours(image, [cnt_label], 0, (0, 0, 255), 2)   # label: red (BGR)
for cnt_mask in contours_mask:
    cv2.drawContours(image, [cnt_mask], 0, (0, 255, 0), 2)    # prediction: green
cv2.imwrite('output_image.jpg', image)
Result: