政安晨:【Keras机器学习示例演绎】(八)—— 利用 PointNet 进行点云分割

目录

简介

导入

下载数据集

加载数据集

构建数据集

预处理

创建 TensorFlow 数据集

PointNet 模型

排列不变性

变换不变性

点之间的相互作用

实例化模型

训练

直观了解培训情况

推论

最后说明


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实现基于点网的点云分割模型。

简介


点云 "是存储几何形状数据的一种重要数据结构类型。由于其格式不规则,在用于深度学习应用之前,通常要将其转换为规则的三维体素网格或图像集合,这一步骤会使数据变得不必要的庞大。PointNet 系列模型通过直接消耗点云解决了这一问题,同时尊重点数据的包络不变性属性。PointNet 系列模型提供了一个简单、统一的架构,适用于从物体分类、部件分割到场景语义解析等各种应用。

在本示例中,我们演示了用于形状分割的 PointNet 架构的实施。

导入

import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob

import tensorflow as tf  # For tf.data
import keras
from keras import layers

import matplotlib.pyplot as plt

下载数据集


ShapeNet 数据集是为建立一个注释丰富的大规模三维图形数据集而持续开展的一项工作。ShapeNetCore 是完整 ShapeNet 数据集的一个子集,其中包含干净的单个三维模型以及人工验证的类别和排列注释。它涵盖 55 个常见物体类别,拥有约 51,300 个独特的三维模型。

在这个例子中,我们使用了 PASCAL 3D+ 的 12 个对象类别之一,它是 ShapenetCore 数据集的一部分。

dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

加载数据集


我们对数据集元数据进行解析,以便轻松地将模型类别映射到各自的目录中,并将分割类别映射到颜色中,从而实现可视化。

with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)
{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}

在本例中,我们训练 PointNet 对飞机模型的部件进行分割。

points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3

构建数据集


我们根据飞机点云及其标签生成以下内存数据结构:

× point_clouds 是一个 np.array 对象列表,以 x、y 和 z 坐标的形式表示点云数据。轴 0 表示点云中的点数,轴 1 表示坐标。all_labels 是一个列表,以字符串形式表示每个坐标的标签(主要用于可视化目的)。
× test_point_clouds 与 point_clouds 格式相同,但没有相应的点云标签。
× all_labels 是一个 np.array 对象列表,表示每个坐标的点云标签,与 point_clouds 列表相对应。
× point_cloud_labels 是一个 np.array 对象列表,表示每个坐标的点云标签,以单击编码的形式表示,与 point_clouds 列表相对应。

point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue

    # Get the file-id of the current point cloud for parsing its
    # labels.
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    # Point clouds having labels will be our training samples.
    try:
        label_map = ["none"] * num_labels
        for label in LABELS:
            for i, data in enumerate(label_data[label]):
                label_map[i] = label if data == 1 else label_map[i]
        label_data = [
            LABELS.index(label) if label != "none" else len(LABELS)
            for label in label_map
        ]
        # Apply one-hot encoding to the dense label representation.
        label_data = keras.utils.to_categorical(label_data, num_classes=len(LABELS) + 1)

        point_clouds.append(point_cloud)
        point_cloud_labels.append(label_data)
        all_labels.append(label_map)
    except KeyError:
        test_point_clouds.append(point_cloud)
100%|██████████████████████████████████████████████████████████████████████| 4045/4045 [01:30<00:00, 44.54it/s]

接下来,我们看看刚刚生成的内存阵列中的一些样本:

for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    print(f"point_clouds[{i}].shape:", point_clouds[0].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[0].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )

演绎展示:
 

point_clouds[333].shape: (2571, 3)
point_cloud_labels[333].shape: (2571, 5)
all_labels[333][0]: tail    point_cloud_labels[333][0]: [0. 0. 1. 0. 0.] 
all_labels[333][1]: wing    point_cloud_labels[333][1]: [1. 0. 0. 0. 0.] 
all_labels[333][2]: tail    point_cloud_labels[333][2]: [0. 0. 1. 0. 0.] 
all_labels[333][3]: engine  point_cloud_labels[333][3]: [0. 0. 0. 1. 0.] 
all_labels[333][4]: wing    point_cloud_labels[333][4]: [1. 0. 0. 0. 0.] 
point_clouds[3273].shape: (2571, 3)
point_cloud_labels[3273].shape: (2571, 5)
all_labels[3273][0]: body   point_cloud_labels[3273][0]: [0. 1. 0. 0. 0.] 
all_labels[3273][1]: body   point_cloud_labels[3273][1]: [0. 1. 0. 0. 0.] 
all_labels[3273][2]: tail   point_cloud_labels[3273][2]: [0. 0. 1. 0. 0.] 
all_labels[3273][3]: wing   point_cloud_labels[3273][3]: [1. 0. 0. 0. 0.] 
all_labels[3273][4]: wing   point_cloud_labels[3273][4]: [1. 0. 0. 0. 0.] 
point_clouds[929].shape: (2571, 3)
point_cloud_labels[929].shape: (2571, 5)
all_labels[929][0]: body    point_cloud_labels[929][0]: [0. 1. 0. 0. 0.] 
all_labels[929][1]: tail    point_cloud_labels[929][1]: [0. 0. 1. 0. 0.] 
all_labels[929][2]: wing    point_cloud_labels[929][2]: [1. 0. 0. 0. 0.] 
all_labels[929][3]: tail    point_cloud_labels[929][3]: [0. 0. 1. 0. 0.] 
all_labels[929][4]: body    point_cloud_labels[929][4]: [0. 1. 0. 0. 0.] 
point_clouds[496].shape: (2571, 3)
point_cloud_labels[496].shape: (2571, 5)
all_labels[496][0]: body    point_cloud_labels[496][0]: [0. 1. 0. 0. 0.] 
all_labels[496][1]: body    point_cloud_labels[496][1]: [0. 1. 0. 0. 0.] 
all_labels[496][2]: body    point_cloud_labels[496][2]: [0. 1. 0. 0. 0.] 
all_labels[496][3]: wing    point_cloud_labels[496][3]: [1. 0. 0. 0. 0.] 
all_labels[496][4]: body    point_cloud_labels[496][4]: [0. 1. 0. 0. 0.] 
point_clouds[3508].shape: (2571, 3)
point_cloud_labels[3508].shape: (2571, 5)
all_labels[3508][0]: body   point_cloud_labels[3508][0]: [0. 1. 0. 0. 0.] 
all_labels[3508][1]: body   point_cloud_labels[3508][1]: [0. 1. 0. 0. 0.] 
all_labels[3508][2]: body   point_cloud_labels[3508][2]: [0. 1. 0. 0. 0.] 
all_labels[3508][3]: body   point_cloud_labels[3508][3]: [0. 1. 0. 0. 0.] 
all_labels[3508][4]: body   point_cloud_labels[3508][4]: [0. 1. 0. 0. 0.] 

现在,让我们将一些点云及其标签可视化。

def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

预处理


需要注意的是,我们加载的所有点云都由数量不等的点组成,这使得我们很难将它们集中在一起。为了解决这个问题,我们从每个点云中随机抽取固定数量的点。我们还对点云进行了归一化处理,以使数据与比例尺无关。

for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)
    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])
    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels
100%|█████████████████████████████████████████████████████████████████████| 3694/3694 [00:08<00:00, 446.45it/s]

让我们将采样和归一化的点云及其相应的标签可视化。

visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

创建 TensorFlow 数据集


我们为训练数据和验证数据创建 tf.data.Dataset 对象。我们还通过随机抖动来增强训练点云。

def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.001, 0.001, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
Num train point clouds: 2955
Num train point cloud labels: 2955
Num val point clouds: 739
Num val point cloud labels: 739
Train Dataset: <_ParallelMapDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))>
Validation Dataset: <_BatchDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))>

PointNet 模型


下图描述了 PointNet 型号系列的内部结构:

鉴于 PointNet 将无序的坐标集作为输入数据,因此其架构需要与点云数据的以下特性相匹配:

排列不变性


鉴于点云数据的非结构化性质,由 n 个点组成的扫描有 n 种排列组合。随后的数据处理必须对不同的表示方法保持不变。为了使 PointNet 不受输入排列的影响,我们在将 n 个输入点映射到高维空间后使用了对称函数(如最大池化)。这样就得到了一个全局特征向量,旨在捕捉 n 个输入点的总体特征。
全局特征向量与局部点特征一起用于分割。

变换不变性


如果物体发生某些变换,如平移或缩放,分割输出结果应保持不变。对于给定的输入点云,我们会应用适当的刚性或仿射变换来实现姿态归一化。由于 n 个输入点中的每个点都表示为一个向量,并独立映射到嵌入空间,因此应用几何变换只需将每个点与一个变换矩阵相乘即可。这就是空间变换器网络概念的由来。

构成 T-Net 的操作是受 PointNet 高级架构的启发。MLP(或全连接层)用于将输入点独立且相同地映射到高维空间;最大池化用于编码全局特征向量,然后通过全连接层降低其维度。最后全连接层的输入相关特征与全局可训练权重和偏置相结合,形成一个 3 乘 3 的变换矩阵。

点之间的相互作用


相邻点之间的相互作用往往蕴含着有用的信息(即不应孤立地处理单个点)。分类只需利用全局特征,而分割则必须能够利用局部点特征和全局点特征。

既然知道了 PointNet 模型的组成要素,我们就可以实现该模型了。我们首先要实现基本模块,即卷积模块和多层感知器模块。

def conv_block(x, filters, name):
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x, filters, name):
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)

我们实施了一个正则化器(取自本例),以加强特征空间的正交性。这是确保变换后的特征幅度不会有太大变化所必需的。

class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = keras.ops.eye(num_features)

    def __call__(self, x):
        x = keras.ops.reshape(x, (-1, self.num_features, self.num_features))
        xxt = keras.ops.tensordot(x, x, axes=(2, 2))
        xxt = keras.ops.reshape(xxt, (-1, self.num_features, self.num_features))
        return keras.ops.sum(self.l2reg * keras.ops.square(xxt - self.identity))

    def get_config(self):
        config = super().get_config()
        config.update({"num_features": self.num_features, "l2reg_strength": self.l2reg})
        return config

下一个部分是我们之前介绍过的转换网络。

def transformation_net(inputs, num_features, name):
    """
    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs, num_features, name):
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])

最后,我们将上述模块拼接在一起,实现分割模型。

def get_shape_segmentation_model(num_points, num_classes):
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = keras.ops.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)

实例化模型

x, y = next(iter(train_dataset))

num_points = x.shape[1]
num_classes = y.shape[-1]

segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()

演绎展示: 

Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, None, 3)   │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │     256 │ input_layer[0][0]    │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │     256 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │   8,320 │ input_transformatio… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │     512 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │ 132,096 │ input_transformatio… │
│ (Conv1D)            │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │   4,096 │ input_transformatio… │
│ (BatchNormalizatio… │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │       0 │ input_transformatio… │
│ (Activation)        │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 1024)      │       0 │ input_transformatio… │
│ (GlobalMaxPooling1… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │ 524,800 │ global_max_pooling1… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │   2,048 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │ 131,328 │ input_transformatio… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │   1,024 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 9)         │   2,313 │ input_transformatio… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape (Reshape)   │ (None, 3, 3)      │       0 │ input_transformatio… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 3)   │       0 │ input_layer[0][0],   │
│ (Dot)               │                   │         │ reshape[0][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_conv    │ (None, None, 64)  │     256 │ input_transformatio… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_batch_… │ (None, None, 64)  │     256 │ features_64_conv[0]… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_relu    │ (None, None, 64)  │       0 │ features_64_batch_n… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_conv │ (None, None, 128) │   8,320 │ features_64_relu[0]… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_bat… │ (None, None, 128) │     512 │ features_128_1_conv… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_relu │ (None, None, 128) │       0 │ features_128_1_batc… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_conv │ (None, None, 128) │  16,512 │ features_128_1_relu… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_bat… │ (None, None, 128) │     512 │ features_128_2_conv… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_relu │ (None, None, 128) │       0 │ features_128_2_batc… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │   8,256 │ features_128_2_relu… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │     256 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │   8,320 │ transformed_feature… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │     512 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │ 132,096 │ transformed_feature… │
│ (Conv1D)            │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │   4,096 │ transformed_feature… │
│ (BatchNormalizatio… │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │       0 │ transformed_feature… │
│ (Activation)        │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 1024)      │       0 │ transformed_feature… │
│ (GlobalMaxPooling1… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │ 524,800 │ global_max_pooling1… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │   2,048 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │ 131,328 │ transformed_feature… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │   1,024 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 16384)     │ 4,210,… │ transformed_feature… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape_1 (Reshape) │ (None, 128, 128)  │       0 │ transformed_feature… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │       0 │ features_128_2_relu… │
│ (Dot)               │                   │         │ reshape_1[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_conv   │ (None, None, 512) │  66,048 │ transformed_feature… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_batch… │ (None, None, 512) │   2,048 │ features_512_conv[0… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_relu   │ (None, None, 512) │       0 │ features_512_batch_… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │ 1,050,… │ features_512_relu[0… │
│ (Conv1D)            │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │   8,192 │ pre_maxpool_block_c… │
│ (BatchNormalizatio… │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │       0 │ pre_maxpool_block_b… │
│ (Activation)        │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_features     │ (None, None,      │       0 │ pre_maxpool_block_r… │
│ (MaxPooling1D)      │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ tile (Tile)         │ (None, None,      │       0 │ global_features[0][… │
│                     │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_input  │ (None, None,      │       0 │ features_64_relu[0]… │
│ (Concatenate)       │ 3008)             │         │ features_128_1_relu… │
│                     │                   │         │ features_128_2_relu… │
│                     │                   │         │ transformed_feature… │
│                     │                   │         │ features_512_relu[0… │
│                     │                   │         │ tile[0][0]           │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │ 385,152 │ segmentation_input[… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │     512 │ segmentation_featur… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │       0 │ segmentation_featur… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_head   │ (None, None, 5)   │     645 │ segmentation_featur… │
│ (Conv1D)            │                   │         │                      │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 7,370,062 (28.11 MB)
 Trainable params: 7,356,110 (28.06 MB)
 Non-trainable params: 13,952 (54.50 KB)

训练


对于训练,作者建议使用学习率计划,即每 20 个历元将初始学习率减半。在本例中,我们使用的是 5 个历元。

steps_per_epoch = total_training_examples // BATCH_SIZE
total_training_steps = steps_per_epoch * EPOCHS
print(f"Steps per epoch: {steps_per_epoch}.")
print(f"Total training steps: {total_training_steps}.")

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.003,
    decay_steps=steps_per_epoch * 5,
    decay_rate=0.5,
    staircase=True,
)

steps = range(total_training_steps)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()

演绎展示:
 

Steps per epoch: 92.
Total training steps: 5520.

最后,我们实现了一个用于运行实验和启动模型训练的实用程序。

def run_experiment(epochs):
    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)
Epoch 1/60
  2/93 [37m━━━━━━━━━━━━━━━━━━━━  7s 86ms/step - accuracy: 0.1427 - loss: 48748.8203

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699916678.434176   90326 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

 93/93 ━━━━━━━━━━━━━━━━━━━━ 53s 259ms/step - accuracy: 0.3739 - loss: 27980.7305 - val_accuracy: 0.4340 - val_loss: 10361231.0000
Epoch 2/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 48s 82ms/step - accuracy: 0.6355 - loss: 339.9151 - val_accuracy: 0.3820 - val_loss: 19069320.0000
Epoch 3/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6695 - loss: 281.5728 - val_accuracy: 0.2859 - val_loss: 15993839.0000
Epoch 4/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6812 - loss: 253.0939 - val_accuracy: 0.2287 - val_loss: 9633191.0000
Epoch 5/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6873 - loss: 231.1317 - val_accuracy: 0.3030 - val_loss: 6001454.0000
Epoch 6/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.6860 - loss: 216.6793 - val_accuracy: 0.0620 - val_loss: 1945100.8750
Epoch 7/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6947 - loss: 210.2683 - val_accuracy: 0.4539 - val_loss: 7908162.5000
Epoch 8/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7014 - loss: 203.2560 - val_accuracy: 0.4035 - val_loss: 17741164.0000
Epoch 9/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7006 - loss: 197.3710 - val_accuracy: 0.1900 - val_loss: 34120616.0000
Epoch 10/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7047 - loss: 192.0777 - val_accuracy: 0.3391 - val_loss: 33157422.0000
Epoch 11/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7102 - loss: 188.4875 - val_accuracy: 0.3394 - val_loss: 4630613.5000
Epoch 12/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7186 - loss: 184.9940 - val_accuracy: 0.1662 - val_loss: 487790.1250
Epoch 13/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7175 - loss: 182.7206 - val_accuracy: 0.1602 - val_loss: 70590.3203
Epoch 14/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7159 - loss: 180.5028 - val_accuracy: 0.1631 - val_loss: 16990.2324
Epoch 15/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7201 - loss: 180.1674 - val_accuracy: 0.2318 - val_loss: 4992.7783
Epoch 16/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7222 - loss: 176.5523 - val_accuracy: 0.6246 - val_loss: 647.5634
Epoch 17/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7291 - loss: 175.6139 - val_accuracy: 0.6551 - val_loss: 324.0956
Epoch 18/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7285 - loss: 175.0228 - val_accuracy: 0.6430 - val_loss: 257.9340
Epoch 19/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7300 - loss: 172.7668 - val_accuracy: 0.6399 - val_loss: 253.2745
Epoch 20/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7316 - loss: 172.9001 - val_accuracy: 0.6084 - val_loss: 232.9293
Epoch 21/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7364 - loss: 170.8767 - val_accuracy: 0.6451 - val_loss: 191.7183
Epoch 22/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7395 - loss: 171.4525 - val_accuracy: 0.6825 - val_loss: 180.2473
Epoch 23/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7392 - loss: 170.1975 - val_accuracy: 0.6095 - val_loss: 180.3243
Epoch 24/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7362 - loss: 169.2144 - val_accuracy: 0.6017 - val_loss: 178.3013
Epoch 25/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7409 - loss: 169.2571 - val_accuracy: 0.6582 - val_loss: 178.3481
Epoch 26/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7415 - loss: 167.7480 - val_accuracy: 0.6808 - val_loss: 177.8774
Epoch 27/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7440 - loss: 167.7844 - val_accuracy: 0.7131 - val_loss: 176.5841
Epoch 28/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7423 - loss: 167.5307 - val_accuracy: 0.6891 - val_loss: 176.1687
Epoch 29/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7409 - loss: 166.4581 - val_accuracy: 0.7136 - val_loss: 174.9417
Epoch 30/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7419 - loss: 165.9243 - val_accuracy: 0.7407 - val_loss: 173.0663
Epoch 31/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7471 - loss: 166.9746 - val_accuracy: 0.7454 - val_loss: 172.9663
Epoch 32/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7472 - loss: 165.9707 - val_accuracy: 0.7480 - val_loss: 173.9868
Epoch 33/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7443 - loss: 165.9368 - val_accuracy: 0.7076 - val_loss: 174.4526
Epoch 34/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7496 - loss: 165.5322 - val_accuracy: 0.7441 - val_loss: 174.6099
Epoch 35/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7453 - loss: 164.2007 - val_accuracy: 0.7469 - val_loss: 174.2793
Epoch 36/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7503 - loss: 165.3418 - val_accuracy: 0.7469 - val_loss: 174.0812
Epoch 37/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7491 - loss: 164.4796 - val_accuracy: 0.7524 - val_loss: 173.9656
Epoch 38/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7489 - loss: 164.4573 - val_accuracy: 0.7516 - val_loss: 175.3401
Epoch 39/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7437 - loss: 163.4484 - val_accuracy: 0.7532 - val_loss: 173.8172
Epoch 40/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7507 - loss: 163.6720 - val_accuracy: 0.7537 - val_loss: 173.9127
Epoch 41/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7506 - loss: 164.0555 - val_accuracy: 0.7556 - val_loss: 173.0979
Epoch 42/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7517 - loss: 164.1554 - val_accuracy: 0.7562 - val_loss: 172.8895
Epoch 43/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7527 - loss: 164.6351 - val_accuracy: 0.7567 - val_loss: 173.0476
Epoch 44/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7505 - loss: 164.1568 - val_accuracy: 0.7571 - val_loss: 172.2751
Epoch 45/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7500 - loss: 163.8129 - val_accuracy: 0.7579 - val_loss: 171.8897
Epoch 46/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7534 - loss: 163.6473 - val_accuracy: 0.7577 - val_loss: 172.5457
Epoch 47/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 163.7318 - val_accuracy: 0.7580 - val_loss: 172.2256
Epoch 48/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7517 - loss: 163.3274 - val_accuracy: 0.7575 - val_loss: 172.3276
Epoch 49/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7511 - loss: 163.5069 - val_accuracy: 0.7581 - val_loss: 171.2155
Epoch 50/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7507 - loss: 163.7366 - val_accuracy: 0.7578 - val_loss: 171.1100
Epoch 51/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7519 - loss: 163.1190 - val_accuracy: 0.7580 - val_loss: 171.7971
Epoch 52/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 81ms/step - accuracy: 0.7510 - loss: 162.7351 - val_accuracy: 0.7579 - val_loss: 171.9780
Epoch 53/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 162.9639 - val_accuracy: 0.7577 - val_loss: 171.6770
Epoch 54/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7530 - loss: 162.7419 - val_accuracy: 0.7578 - val_loss: 170.5556
Epoch 55/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7515 - loss: 163.2893 - val_accuracy: 0.7582 - val_loss: 171.9172
Epoch 56/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7505 - loss: 164.2843 - val_accuracy: 0.7584 - val_loss: 171.9182
Epoch 57/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.6679 - val_accuracy: 0.7587 - val_loss: 173.7610
Epoch 58/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7523 - loss: 163.3332 - val_accuracy: 0.7585 - val_loss: 172.5207
Epoch 59/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7529 - loss: 162.4575 - val_accuracy: 0.7586 - val_loss: 171.6861
Epoch 60/60
 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.9523 - val_accuracy: 0.7586 - val_loss: 172.3012

直观了解培训情况

def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")

推论

validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)

演绎展示:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Validation prediction shape: (32, 1024, 5)
Index selected: 26

最后说明


如果您有兴趣了解有关此主题的更多信息,您可能会发现本资料库非常有用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/563071.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【PCL】教程 implicit_shape_model.cpp 3D点云数据的对象识别 利用隐式形状模型进行训练和识别...

ism_test_cat.pcd 参数&#xff1a;ism_train_cat.pcd 0 ism_train_horse.pcd 1 ism_train_lioness.pcd 2 ism_train_michael.pcd 3 ism_train_wolf.pcd 4 ism_test_cat.pcd 0 这里红点表示对应感兴趣类别的对象预测中心 ./ism_t…

字节FE:JavaScript学习路线图

JavaScript简介 JavaScript是一种高级的、解释执行的编程语言。它是互联网的三大核心技术之一&#xff0c;与HTML和CSS一同工作&#xff0c;用于创建交互式的网页。JavaScript被所有现代网页浏览器支持而不需要任何插件。它可以增强用户界面和网页的交互性&#xff0c;可以进行…

【讲解下Spring Boot单元测试】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

FineBi中创建自定义的图表

FineBi中增加自己的自定义图表组件,比如: 的相关笔记: 1 获取有哪些BI自定义图表组件:http://localhost:8080/webroot/decision/v5/plugin/custom/component/list?_=1713667435473[{"name": "图表DEMO_EK","chartType": "amap_demo&q…

GO环境及入门案例

文章目录 简介一、win GO开发环境安装二、Linux go运行环境二、GO代码入门2.1 导包案例2.2 赋值2.3 变量、函数2.4 三方库使用 简介 go不是面向对象语言&#xff0c; 其指针、结构体等比较像C&#xff0c;知名的go 开源项目有docker k8s prometheus node-exporter等 一、win …

如何在3dMax中快速打包mzp 文件?

如何在3dMax中创建mzp 文件&#xff1f; 我喜欢将我的Maxscript脚本发布为mzp文件。这是一个为3dMax构建的自解压zip文件。在mzp文件中&#xff0c;您可以捆绑Maxscript脚本文件、图片、预设或其他文件&#xff0c;并链接安装时执行的特殊操作。 在3dMax中使用大型脚本时&…

耐高温300度锅炉轴承,江苏鲁岳轴承制造的行业标杆

自润滑轴承-产品类型-耐高温轴承-不锈钢轴承-江苏鲁岳轴承制造有限公司。锅炉轴承&#xff0c;耐高温至200度-800度。 江苏鲁岳轴承制造有限公司&#xff0c;一家专注于锅炉轴承和耐高温轴承的研发与生产的企业&#xff0c;致力于为客户提供高质量、高性能的轴承解决方案。其中…

LeetCode题练习与总结:矩阵置零--73

一、题目描述 给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,1,1],[1,0,1],[1,1,1]] 输出&#xff1a;[[1,0,1],[0,0,0],[1,0,1]]示例 2&#xf…

Linux-内存文件

1. 基础IO操作 1.1 c语言的IO接口 fopen&#xff1a;打开一个文件&#xff0c;按照指定方式 参数&#xff1a;filename 文件名&#xff0c;也可以是路径&#xff0c;mode&#xff1a;打开方式 返回打开的文件指针 fread&#xff1a;从指定流中读数据 参数&#xff1a;从FIL…

Selenium web自动化测试环境搭建

Selenium web自动化环境搭建主要要经历以下几个步骤&#xff1a; 1、安装python 在python官网&#xff1a;Welcome to Python.org&#xff0c;根据各自对应平台如&#xff1a;windows&#xff0c;下载相应的python版本。 ​ 下载成功后&#xff0c;点击安装包&#xff0c;一直…

解释一下“暂存区”的概念,在Git中它扮演什么角色?

文章目录 暂存区在Git中的概念与作用什么是暂存区&#xff08;Staging Area&#xff09;暂存区的位置和结构 暂存区在Git工作流程中的角色1. 分离工作区与版本库的交互示例代码与操作步骤示例1&#xff1a;将工作区的修改添加至暂存区 2. 控制提交内容的粒度示例2&#xff1a;分…

【Linux】虚拟机与Xshell及VS Code的连接

一、基础环境 虚拟机&#xff1a;VMware Workstation Pro 虚拟机镜像&#xff1a;ubuntu-18.04.5-desktop-amd64.iso 其他&#xff1a;Xshell 6、Xftp 6、Visual Studio Code 上述软件的安装操作不再赘述&#xff0c;CSDN上有大量的优秀博文&#xff0c;可参考&#xff1a;详细…

安装和部署maven

准备工作 maven下载地址&#xff1a;https://maven.apache.org/download.cgi 使用wget将maven包下载到linux环境上&#xff0c;/toos/ 目录下&#xff08;也可用迅雷&#xff09; wget https://dlcdn.apache.org/maven/maven-3/3.9.6/binaries/apache-maven-3.9.6-bin.tar.g…

小游戏贪吃蛇的实现之C语言版

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a;C语言 目录 游戏前期准备&#xff1a; 设置控制台相关的信息 GetStdHandle GetConsoleCursorInfo SetConsoleCursorInfo SetConsoleCu…

VSCode插件开发学习

一、环境准备 0、参考文档&#xff1a;VS Code插件创作中文开发文档 1、大于18版本的nodejs 2、安装Yeoman和VS Code Extension Generator&#xff1a; npm install -g yo generator-code 3、生成脚手架 yo code 选择内容&#xff1a; ? What type of extension do yo…

GPT-3.5 Turbo 的 temperature 设置为 0 就是贪婪解码?

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 将 GPT-3.5 Turbo 的 temperature 设置为 0 通常意味着采用贪婪解码&#xff08;greedy decoding&#xff09;策略。在贪婪解码中&#xff0c;模型在每一步生成文本时选择概率最高的词元&#xff0c;从…

微调Llama3实践并基于Llama3构建心理咨询EmoLLM

Llama3 Xtuner微调Llama3 EmoLLM 心理咨询师

LabVIEW多设备控制与数据采集系统

LabVIEW多设备控制与数据采集系统 随着科技的进步&#xff0c;自动化测试与控制系统在工业、科研等领域的应用越来越广泛。开发了一种基于LabVIEW平台开发的多设备控制与数据采集系统&#xff0c;旨在解决多设备手动设置复杂、多路数据显示不直观、数据存储不便等问题。通过RS…

基于STM32的蓝牙小车的Proteus仿真(虚拟串口模拟)

文章目录 一、前言二、仿真图1.要求2.思路3.画图3.1 电源部分3.2 超声波测距部分3.3 电机驱动部分3.4 按键部分3.5 蓝牙部分3.6 显示屏部分3.7 整体 4.仿真5.软件 三、总结 一、前言 proteus本身并不支持蓝牙仿真&#xff0c;这里我采用虚拟串口的方式来模拟蓝牙控制。 这里给…

42岁TVB男艺人曾靠刘德华贴钱出道,苦熬10年终上位

张颕康在无线&#xff08;TVB&#xff09;电视打滚多年&#xff0c;近年在《逆天奇案》第一、二辑凭扎实演技为人留下印象。他还是圈中出名的「爱妻号」&#xff0c;日前在访问期间&#xff0c;张颕康三句不离多谢太太。 较年长的观众或会记得&#xff0c;张颕康初出道以「刘德…