昇思25天学习打卡营第11天 | FCN图像语义分割

昇思25天学习打卡营第11天 | FCN图像语义分割

文章目录

  • 昇思25天学习打卡营第11天 | FCN图像语义分割
    • FCN模型
    • 数据处理
      • 下载数据集
      • 创建训练集
      • 可视化训练集
    • 网络构建
      • 网络结构
      • 张量操作
    • 训练准备
      • 导入VGG-16部分预训练权重:
      • 损失函数
      • 模型评估指标
    • 模型训练
    • 模型评估
    • 模型推理
    • 总结
    • 打卡

语义分割(semantic segmentation) 常被用于人脸识别、物体检测、医学影像、卫星图像分析、自动驾驶等领域。
语义分割的目的是对图像中的每一个像素进行分类,输出与输入图像大小相同的图像,每个像素代表对应输入像素所属的类别。

fcn-2

FCN模型

全卷积神经网络(Fully Convolutional Networks,FCN)是一种端到端的分割方法,通过进行像素级的预测直接得出与原图大小相等的label map。

FCN主要使用以下三种技术:

  1. 卷积化: 使用VGG-16作为FCN的backbone。
    fcn-3
  2. 上采样: :通过反卷积对特征图进行上采样,以恢复输入图像的分辨率。
    fcn-4
  3. 跳跃结构: 由于最后一层特征图太小,损失过多细节,采用skips结构将更具全局信息的最后一层和更浅层的预测结合,使预测结果获得更多的局部细节。
    fcn-5

数据处理

下载数据集

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"

download(url, "./dataset", kind="tar", replace=True)

创建训练集

import cv2
import mindspore.dataset as ds
import numpy as np


class SegDataset:
    def __init__(
        self,
        image_mean,
        image_std,
        data_file="",
        batch_size=32,
        crop_size=512,
        max_scale=2.0,
        min_scale=0.5,
        ignore_label=255,
        num_classes=21,
        num_readers=2,
        num_parallel_calls=4,
    ):

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        max_scale > min_scale

    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(
            np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE
        )
        sc = np.random.uniform(self.min_scale, self.max_scale)  # 随机缩放
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(
            label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST
        )

        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(
                image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0
            )
            label_out = cv2.copyMakeBorder(
                label_out,
                0,
                pad_h,
                0,
                pad_w,
                cv2.BORDER_CONSTANT,
                value=self.ignore_label,
            )
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[
            offset_h : offset_h + self.crop_size,
            offset_w : offset_w + self.crop_size,
            :,
        ]
        label_out = label_out[
            offset_h : offset_h + self.crop_size, offset_w : offset_w + self.crop_size
        ]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(
            self.data_file,
            columns_list=["data", "label"],
            shuffle=True,
            num_parallel_workers=self.num_readers,
        )
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(
            operations=transforms_list,
            input_columns=["data", "label"],
            output_columns=["data", "label"],
            num_parallel_workers=self.num_parallel_calls,
        )
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset


# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# 实例化Dataset
dataset = SegDataset(
    image_mean=IMAGE_MEAN,
    image_std=IMAGE_STD,
    data_file=DATA_FILE,
    batch_size=train_batch_size,
    crop_size=crop_size,
    max_scale=max_scale,
    min_scale=min_scale,
    ignore_label=ignore_label,
    num_classes=num_classes,
    num_readers=2,
    num_parallel_calls=4,
)

dataset = dataset.get_dataset()

在上面的类SegDataset中:

  • __init__()里初始化了一些参数
  • preprocess_dataset()定义了对输入图像的大量变换,包括:
    • 对图像和标签随机缩放;
    • 图像归一化;
    • 填充或裁剪图片为crop_size × crop_size大小;
    • 随机水平翻转;
    • 将<H,W,C>图像转换为<C,H,W>;
    • 将标签数据转换为int32类型。
  • get_dataset()中定义了数据集和数据变换,设置了并行。

可视化训练集

import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(16, 8))

# 对训练集中的数据进行展示
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # 将图片转换HWC格式后进行展示
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

网络构建

fcn-6

网络结构

  • 第一个卷积块:输入 512 × 512 × 3 512\times512\times3 512×512×3图像,每一个卷积操作后面紧跟一个nn.BatchNorm2d(out_channels)nn.ReLU()
self.conv1 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
  • 第一个池化:最大池化(Max Pooling),输入图像尺寸变为原始图像 1 / 2 1/2 1/2
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第二个卷积块:
self.conv2 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=128,
                out_channels=128,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
  • 第二个池化:变为原始图像 1 / 4 1/4 1/4
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第三个卷积快:
 self.conv3 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
  • 第三个池化:变为原始图像 1 / 8 1/8 1/8
 self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第四个卷积块:
 self.conv4 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=256,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
  • 第四个池化:变为原始图像 1 / 16 1/16 1/16
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第五个卷积块:
 self.conv5 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
  • 第五个池化:变为原始图像 1 / 32 1/32 1/32
 self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第六、七个卷积块:保持图像大小不变,替换全连接层
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=512,
                out_channels=4096,
                kernel_size=7,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(
                in_channels=4096,
                out_channels=4096,
                kernel_size=1,
                weight_init="xavier_uniform",
            ),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                          kernel_size=1, weight_init='xavier_uniform')
  • 反卷积层:
        self.upscore2 = nn.Conv2dTranspose(
            in_channels=self.n_class,
            out_channels=self.n_class,
            kernel_size=4,
            stride=2,
            weight_init="xavier_uniform",
        )
        self.score_pool4 = nn.Conv2d(
            in_channels=512,
            out_channels=self.n_class,
            kernel_size=1,
            weight_init="xavier_uniform",
        )
        self.upscore_pool4 = nn.Conv2dTranspose(
            in_channels=self.n_class,
            out_channels=self.n_class,
            kernel_size=4,
            stride=2,
            weight_init="xavier_uniform",
        )
        self.score_pool3 = nn.Conv2d(
            in_channels=256,
            out_channels=self.n_class,
            kernel_size=1,
            weight_init="xavier_uniform",
        )
        self.upscore8 = nn.Conv2dTranspose(
            in_channels=self.n_class,
            out_channels=self.n_class,
            kernel_size=16,
            stride=8,
            weight_init="xavier_uniform",
        )

张量操作

def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

训练准备

导入VGG-16部分预训练权重:

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)


def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

损失函数

语义分割是对图像中像素点进行分类,仍然属于分类问题,故使用交叉熵损失函数nn.CrossEntropyLoss()

模型评估指标

用来评估训练出来的模型好坏。

  • Pixel Accuracy(PA, 像素精度):标记真确的像素占总像素的比例:
    P A = ∑ i = 0 k p i i ∑ i = 0 k ∑ j = 0 k p i j PA = \frac{\sum_{i=0}^kp_{ii}}{\sum_{i=0}^k\sum_{j=0}^kp_{ij}} PA=i=0kj=0kpiji=0kpii
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train
import numpy as np


class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = (
            np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        )
        return pixel_accuracy
  • Mean Pixel Accuracy(MPA,均像素精度):计算每个类内正确分类的像素比例,然后求平均:
    M P A = 1 k + 1 ∑ i = 0 k p i i ∑ j = 0 k p i j MPA = \frac1{k+1}\sum_{i=0}^k\frac{p_{ii}}{\sum_{j=0}^kp_{ij}} MPA=k+11i=0kj=0kpijpii
class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(
            self.confusion_matrix
        ) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy
  • Mean Intersction over Union(MIoU,均交并比):两个集合(真实值和预测值)的交集和并集的比值。
    M I o U = 1 k + 1 ∑ i = 0 k p i i ∑ j = 0 k p i j + ∑ j = 0 k p j i − p i i MIoU=\frac1{k+1}\sum_{i=0}^k\frac{p_{ii}}{\sum_{j=0}^kp_{ij}+\sum_{j=0}^kp_{ji}-p_{ii}} MIoU=k+11i=0kj=0kpij+j=0kpjipiipii
class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1)
            + np.sum(self.confusion_matrix, axis=0)
            - np.diag(self.confusion_matrix)
        )
        mean_iou = np.nanmean(mean_iou)
        return mean_iou
  • Frequency Weighted Intersection over Union(FWIoU,频权交并比):根据每个类别出现的频率为MIoU设置权重
    F W I o U = 1 ∑ i = 0 k ∑ j = 0 k p i j ∑ i = 0 k p i i ∑ j = 0 k p i j + ∑ j = 0 k p j i − p i i FWIoU=\frac1{\sum_{i=0}^k\sum_{j=0}^kp_{ij}}\sum_{i=0}^k\frac{p_{ii}}{\sum_{j=0}^kp_{ij}+\sum_{j=0}^kp_{ji}-p_{ii}} FWIoU=i=0kj=0kpij1i=0kj=0kpij+j=0kpjipiipii
class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1)
            + np.sum(self.confusion_matrix, axis=0)
            - np.diag(self.confusion_matrix)
        )

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

模型训练

导入VGG-16预训练参数,定义超参数,实例化损失函数、优化器,使用Model接口编译网络,然后开始训练:

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.train import (
    CheckpointConfig,
    LossMonitor,
    Model,
    ModelCheckpoint,
    TimeMonitor,
)

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# 初始化模型结构
net = FCN8s(n_class=21)
# 导入vgg16预训练参数
load_vgg16()
# 计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(
    min_lr, base_lr, total_step, iters_per_epoch, decay_epoch=2
)
lr = Tensor(lr_scheduler[-1])

# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定义优化器
optimizer = nn.Momentum(
    params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001
)
# 定义loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# 初始化模型
if device_target == "Ascend":
    model = Model(
        net,
        loss_fn=loss,
        optimizer=optimizer,
        loss_scale_manager=loss_scale_manager,
        metrics={
            "pixel accuracy": PixelAccuracy(),
            "mean pixel accuracy": PixelAccuracyClass(),
            "mean IoU": MeanIntersectionOverUnion(),
            "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(),
        },
    )
else:
    model = Model(
        net,
        loss_fn=loss,
        optimizer=optimizer,
        metrics={
            "pixel accuracy": PixelAccuracy(),
            "mean pixel accuracy": PixelAccuracyClass(),
            "mean IoU": MeanIntersectionOverUnion(),
            "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(),
        },
    )

# 设置ckpt文件保存的参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(
    save_checkpoint_steps=10, keep_checkpoint_max=keep_checkpoint_max
)
ckpt_callback = ModelCheckpoint(prefix="FCN8s", directory="./ckpt", config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)

模型评估

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(
        net,
        loss_fn=loss,
        optimizer=optimizer,
        loss_scale_manager=loss_scale_manager,
        metrics={
            "pixel accuracy": PixelAccuracy(),
            "mean pixel accuracy": PixelAccuracyClass(),
            "mean IoU": MeanIntersectionOverUnion(),
            "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(),
        },
    )
else:
    model = Model(
        net,
        loss_fn=loss,
        optimizer=optimizer,
        metrics={
            "pixel accuracy": PixelAccuracy(),
            "mean pixel accuracy": PixelAccuracyClass(),
            "mean IoU": MeanIntersectionOverUnion(),
            "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(),
        },
    )

# 实例化Dataset
dataset = SegDataset(
    image_mean=IMAGE_MEAN,
    image_std=IMAGE_STD,
    data_file=DATA_FILE,
    batch_size=train_batch_size,
    crop_size=crop_size,
    max_scale=max_scale,
    min_scale=min_scale,
    ignore_label=ignore_label,
    num_classes=num_classes,
    num_readers=2,
    num_parallel_calls=4,
)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)

模型推理

使用训练的网络进行推理:

import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# 设置超参
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# 推理效果展示(上方为输入图片,下方为推理效果图片)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

总结

这一节介绍了图像语义分割中的FCN网络,从网络的结构开始,介绍了图像数据集的变换和创建,介绍了完整的网络框架和Tensor操作,以及预训练模型的加载。此外,对于模型效果的而评估,还介绍了4个指标。通过这一节,大致学会了从paper的网络结构中创建一个网络模型。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/793886.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

「C++系列」一篇文章说透【存储类】

文章目录 一、C 存储类1. 类的定义2. 对象的创建3. 对象在内存中的布局4. 对象的存储位置 二、auto 存储类1. auto的基本用法2. auto与存储类的关系1) 自动存储类&#xff08;最常见的&#xff09;2) 静态存储类3) 动态存储类&#xff08;通过new&#xff09; 三、register 存储…

SSRF漏洞深入利用与防御方案绕过技巧

文章目录 前言SSRF基础利用1.1 http://内网资源访问1.2 file:///读取内网文件1.3 dict://探测内网端口 SSRF进阶利用2.1 Gopher协议Post请求2.2 Gopher协议文件上传2.3 GopherRedis->RCE2.4 JavaWeb中的适用性&#xff1f; SSRF防御绕过3.1 Url黑名单检测的绕过3.2 Url白名单…

【PHP安装内置扩展】

PHP安装内置扩展 1、首先查看php源码以及查询是否有需要的扩展;本次以zlib扩展为例子 2、进入需要安装的扩展目录,执行命令 cd zlib 执行 make clean 清掉之前的安装的残留文件; 不需要的话直接略过,新安装也略过3、运行phpize,执行/usr/local/php/bin/phpize 注意这个路径一…

【BES2500x系列 -- RTX5操作系统】深入探索CMSIS-RTOS RTX -- 配置篇 -- 初识GPIO --(六)

&#x1f48c; 所属专栏&#xff1a;【BES2500x系列】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f49…

子任务:IT运维的精细化管理之道

在当今的企业运营中&#xff0c;信息技术已成为支撑业务发展的核心力量。根据Gartner的报告&#xff0c;IT服务管理&#xff08;ITSM&#xff09;的有效实施可以显著提升企业的运营效率&#xff0c;降低成本高达15%&#xff0c;同时提高服务交付速度和质量。随着业务的复杂性和…

7个外贸网站模板

Nebula独立站wordpress主题 Nebula奈卜尤拉wordpress主题模板&#xff0c;适合搭建外贸独立站使用的wordpress主题。 https://www.jianzhanpress.com/?p7084 Starling师大林WordPress独立站模板 蓝色橙色风格的WordPress独立站模板&#xff0c;适合做对外贸易的外贸公司搭建…

pytorch 是如何调用 cusolver API 的调用

0&#xff0c;环境 ubuntu 22.04 pytorch 2.3.1 x86 RTX 3080 cuda 12.2 1, 示例代码 以potrs为例&#xff1b; hello_cholesk.py """ hello_cholesky.py step1, Cholesky decompose; step2, inverse A; step3, Cholesky again; python3 hello_cholesky.py -…

Hi3861 OpenHarmony嵌入式应用入门--华为 IoTDA 设备接入

华为云物联网平台&#xff08;IoT 设备接入云服务&#xff09;提供海量设备的接入和管理能力&#xff0c;可以将自己的 IoT 设备 联接到华为云&#xff0c;支撑设备数据采集上云和云端下发命令给设备进行远程控制&#xff0c;配合华为云物联网平台的服 务实现设备与设备之间的控…

【区块链农场】:农场游戏+游戏

我的酒坊是一款非常受玩家欢迎的经营手游,游戏中你需要合理经营一家酒厂,将其做大做强。通过制定合理的战略,例如新建厂房,并采用传统工艺制作,针对不同的人群研制多重口味。

配置sublime的中的C++编译器(.sublime-build),实现C++20在sublime中的使用,小白教程

一&#xff0c;前期准备 首先我们需要准备一下 C 环境&#xff0c;当然如果你觉得你当前的C环境配置好了&#xff0c;并且C的版本也能完成您日常的使用需求&#xff0c;您可以使用下面两条命令对C的版本进行查询 g -vg --version通过返回的版本简单的判断是否能解决您的需求&…

[Godot3.3.3] - 过渡动画

过渡动画 ScreenTransitionAnimation 项目结构 添加场景&#xff0c;根节点为 CanvasLayer2D 并重命名为 ScreenTransition: 添加子节点 ColorRect 和 AnimationPlayer&#xff0c;在 ColorRect 中将颜色(Color)设置为黑色&#xff1a; 找到 Material&#xff0c;新建 Shader…

【Element-UI 表格表头、内容合并单元格】

一、实现效果&#xff1a; &#x1f970; 表头合并行、合并列 &#x1f970; &#x1f970; 表格内容行、合并列 &#x1f970; thead和tbody分别有单独的合并方法 二、关键代码&#xff1a; <el-table size"mini" class"table-th-F4F6FB" align&qu…

初识c++(类与对象——上)

一、类的定义 1、类定义格式 • class为定义类的关键字&#xff0c;Stack为类的名字&#xff0c;{}中为类的主体&#xff0c;注意类定义结束时后面分号不能省 略。类体中内容称为类的成员&#xff1a;类中的变量称为类的属性或成员变量; 类中的函数称为类的方法或 者成员函…

【C++初阶】类和对象(上)

【C初阶】类和对象&#xff08;上&#xff09; &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;C&#x1f96d; &#x1f33c;文章目录&#x1f33c; 1. 面向过程和面向对象初步认识 2. 类的引入 3. 类的定义 4. 类的访问限定符及封…

python 实验八 数据分析与展示

一、实验目的 掌握掌握matplotlib库中pyplot模块的使用。 二、实验环境 Window10&#xff08;x64&#xff09;&#xff0c;Python 3.8&#xff08;x64&#xff09;&#xff0c;PyCharm Community Edition 2020.3.2&#xff08;x64&#xff09; 三、实验内容 现有列表hight…

jmeter分布式(四)

一、gui jmeter的gui主要用来调试脚本 1、先gui创建脚本 先做一个脚本 演示&#xff1a;如何做混合场景的脚本&#xff1f; 用211的业务比例 ①启动数据库服务 数据库服务&#xff1a;包括mysql、redis mysql端口默认3306 netstat -lntp | grep 3306处于监听状态&#xf…

[迫真保姆级教程]在Windows上编译可用的Tesseract OCR in C++ 并部署在Visual Studio与Qt6上

目录 前言 阅前提示 导言 使用基于vcpkg的&#xff0c;于msvc19编译器编译的Tessereact OCR动态库 使用vcpkg辅助我们的编译 正文 使用msys2环境下的&#xff0c;使用mingw64编译器编译的Tessereact OCR动态库 什么是msys2 安装前&#xff0c;我们也许。。。 [Option]…

python作业二

# 二进制转化为十进制 num input("num:")def binaryToDecimal(binaryString):he 0length len(binaryString)for i in range(length):he int(binaryString[i]) * 2 ** (length - i - 1)return heprint(binaryToDecimal(num))代码运行如下&#xff1a; import math…

基于YOLOV8的数粒机-农业应用辣椒种子计数计重双标质量解决方案

一:辣椒种子行业背景调查 中国辣椒年产量稳居世界第一,食辣人口超5亿。中国辣椒全球闻名,小辣椒长成大产业,带动全球食品行业腾飞。 在中国,“辣”是不少地方餐桌上的一大特色。从四川的麻辣火锅到湖南的剁椒鱼头再到陕西的油泼辣子面,由南到北,总有食客对辣有着独一份偏…

【RHCE】系统服务综合实验

一、实验内容 现有主机 node01 和 node02&#xff0c;完成如下需求&#xff1a; 1、在 node01 主机上提供 DNS 和 WEB 服务 2、dns 服务提供本实验所有主机名解析 3、web服务提供 www.rhce.com 虚拟主机 4、该虚拟主机的documentroot目录在 /nfs/rhce 目录 5、该目录由 node02…