政安晨:【Keras机器学习实践要点】(二十二)—— 基于 TPU 的肺炎分类

目录

简述

介绍 / 布置

加载数据

可视化数据集

建立 CNN

纠正数据失衡

训练模型

拟合模型

可视化模型性能

​编辑预测和评估结果


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:基于 TPU 的医学图像分类。

简述

Keras是一个高级神经网络库,可以用于实现医学图像分类任务。医学图像分类是指将医学图像分为不同的类别,例如正常和异常,不同病种等。

在Keras中,可以使用卷积神经网络(CNN)来进行医学图像分类。CNN是一种特别适用于图像分类任务的神经网络架构,它能够有效地提取图像中的特征。

首先,需要准备医学图像数据集。可以使用公开的医学图像数据集,例如MNIST(手写数字图像)或者ImageNet(包含各种物体的图像)。另外,还可以使用自己收集的医学图像数据集。

接下来,需要定义CNN模型。可以使用Keras提供的各种层(例如卷积层、池化层、全连接层)来构建模型。卷积层可以有效地提取图像中的局部特征,池化层可以降低特征图的尺寸,全连接层可以用于最终的分类。

然后,需要编译模型。可以选择合适的损失函数和优化器来训练模型。对于多分类问题,可以使用交叉熵损失函数,常见的优化器包括Adam、SGD等。

接下来,需要进行模型训练。可以使用数据集中的图像进行训练,同时还需要准备对应的标签(即图像所属的类别)。可以使用Keras提供的fit函数来进行训练。

最后,可以使用模型进行预测。可以输入新的医学图像数据,使用模型预测其所属的类别。可以使用Keras提供的predict函数来进行预测。

总的来说,Keras提供了一个简单且强大的工具,可以用于医学图像分类任务。通过定义、编译、训练和预测模型,可以提高医学图像分类的准确性和效率,为医学诊断和研究提供有力支持。

介绍 / 布置

本文将介绍如何建立 X 光图像分类模型,以预测 X 光扫描是否显示存在肺炎。

import re
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

演绎:

Device: grpc://10.0.27.122:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470

INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470

INFO:tensorflow:Clearing out eager caches

INFO:tensorflow:Clearing out eager caches

INFO:tensorflow:Finished initializing TPU system.

INFO:tensorflow:Finished initializing TPU system.
WARNING:absl:[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is deprecated, please use  the non experimental symbol [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) instead.

INFO:tensorflow:Found TPU system:

INFO:tensorflow:Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 8

INFO:tensorflow:*** Num TPU Cores: 8

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 8

INFO:tensorflow:*** Num TPU Cores Per Worker: 8

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Number of replicas: 8

我们需要谷歌云链接到我们的数据,以便使用 TPU 加载数据。下面,我们将定义本示例中使用的关键配置参数。要在 TPU 上运行,本示例必须在 Colab 上选择 TPU 运行时。

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 25 * strategy.num_replicas_in_sync
IMAGE_SIZE = [180, 180]
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

加载数据

我们使用的 Cell 胸部 X 光数据将数据分为训练文件和测试文件。首先加载训练 TFRecords。

train_images = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/train/images.tfrec"
)
train_paths = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/train/paths.tfrec"
)

ds = tf.data.Dataset.zip((train_images, train_paths))

让我们数一数有多少张健康/正常的胸部 X 光片,有多少张肺炎胸部 X 光片:

COUNT_NORMAL = len(
    [
        filename
        for filename in train_paths
        if "NORMAL" in filename.numpy().decode("utf-8")
    ]
)
print("Normal images count in training set: " + str(COUNT_NORMAL))

COUNT_PNEUMONIA = len(
    [
        filename
        for filename in train_paths
        if "PNEUMONIA" in filename.numpy().decode("utf-8")
    ]
)
print("Pneumonia images count in training set: " + str(COUNT_PNEUMONIA))

结果:

Normal images count in training set: 1349
Pneumonia images count in training set: 3883

请注意,被归类为肺炎的图像要比正常图像多得多。这说明我们的数据不平衡。我们稍后将在笔记本中纠正这种不平衡。
我们要将每个文件名映射到相应的(图像、标签)对。以下方法将帮助我们实现这一目标。
由于只有两个标签,我们将对标签进行编码,使 1 或 True 表示肺炎,0 或 False 表示正常。

def get_label(file_path):
    # convert the path to a list of path components
    parts = tf.strings.split(file_path, "/")
    # The second to last is the class-directory
    if parts[-2] == "PNEUMONIA":
        return 1
    else:
        return 0


def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # resize the image to the desired size.
    return tf.image.resize(img, IMAGE_SIZE)


def process_path(image, path):
    label = get_label(path)
    # load the raw data from the file as a string
    img = decode_img(image)
    return img, label


ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)

让我们把数据分成训练数据集和验证数据集。

ds = ds.shuffle(10000)
train_ds = ds.take(4200)
val_ds = ds.skip(4200)

让我们来想象一下(图像、标签)对的形状。

for image, label in train_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())

结果:

Image shape:  (180, 180, 3)
Label:  False

同时加载并格式化测试数据。

test_images = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/test/images.tfrec"
)
test_paths = tf.data.TFRecordDataset(
    "gs://download.tensorflow.org/data/ChestXRay2017/test/paths.tfrec"
)
test_ds = tf.data.Dataset.zip((test_images, test_paths))

test_ds = test_ds.map(process_path, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)

可视化数据集

首先,让我们使用缓冲预取,这样就能从磁盘获取数据,而不会出现 I/O 阻塞。
请注意,大型图像数据集不应缓存在内存中。我们在这里这样做是因为数据集不是很大,而且我们想在 TPU 上进行训练。

def prepare_for_training(ds, cache=True):
    # This is a small dataset, only load it once, and keep it in memory.
    # use `.cache(filename)` to cache preprocessing work for datasets that don't
    # fit in memory.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    ds = ds.batch(BATCH_SIZE)

    # `prefetch` lets the dataset fetch batches in the background while the model
    # is training.
    ds = ds.prefetch(buffer_size=AUTOTUNE)

    return ds

调用下一批迭代训练数据。

train_ds = prepare_for_training(train_ds)
val_ds = prepare_for_training(val_ds)

image_batch, label_batch = next(iter(train_ds))

定义在批次中显示图像的方法。

def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255)
        if label_batch[n]:
            plt.title("PNEUMONIA")
        else:
            plt.title("NORMAL")
        plt.axis("off")

由于该方法将 NumPy 数组作为参数,因此在批次上调用 numpy 函数,以 NumPy 数组形式返回张量。

show_batch(image_batch.numpy(), label_batch.numpy())

建立 CNN

为了使我们的模型更加模块化和易于理解,让我们定义一些模块。在构建卷积神经网络时,我们将创建一个卷积块和一个密集层块。

import os 
os.environ['KERAS_BACKEND'] = 'tensorflow'

import keras
from keras import layers

def conv_block(filters, inputs):
    x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(inputs)
    x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    outputs = layers.MaxPool2D()(x)

    return outputs


def dense_block(units, dropout_rate, inputs):
    x = layers.Dense(units, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)
    outputs = layers.Dropout(dropout_rate)(x)

    return outputs

下面的方法将定义为我们建立模型的函数。

图像的原始值范围为 [0,255]。CNN 在使用较小的数值时效果更好,因此我们将缩小输入范围。

剔除层非常重要,因为它们可以降低模型过拟合的可能性。我们希望以一个节点的密集层结束模型,因为这将是决定 X 光片是否显示肺炎的二进制输出。

def build_model():
    inputs = keras.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
    x = layers.MaxPool2D()(x)

    x = conv_block(32, x)
    x = conv_block(64, x)

    x = conv_block(128, x)
    x = layers.Dropout(0.2)(x)

    x = conv_block(256, x)
    x = layers.Dropout(0.2)(x)

    x = layers.Flatten()(x)
    x = dense_block(512, 0.7, x)
    x = dense_block(128, 0.5, x)
    x = dense_block(64, 0.3, x)

    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

纠正数据失衡

在本示例的前面部分,我们看到数据是不平衡的,被归类为肺炎的图像多于正常图像。

我们将通过类别加权来纠正这种情况:

initial_bias = np.log([COUNT_PNEUMONIA / COUNT_NORMAL])
print("Initial bias: {:.5f}".format(initial_bias[0]))

TRAIN_IMG_COUNT = COUNT_NORMAL + COUNT_PNEUMONIA
weight_for_0 = (1 / COUNT_NORMAL) * (TRAIN_IMG_COUNT) / 2.0
weight_for_1 = (1 / COUNT_PNEUMONIA) * (TRAIN_IMG_COUNT) / 2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print("Weight for class 0: {:.2f}".format(weight_for_0))
print("Weight for class 1: {:.2f}".format(weight_for_1))

结果如下:

Initial bias: 1.05724
Weight for class 0: 1.94
Weight for class 1: 0.67

类别 0(正常)的权重远远高于类别 1(肺炎)的权重。因为正常图像的数量较少,所以每个正常图像的权重会更高,以平衡数据,因为 CNN 在训练数据平衡的情况下工作效果最佳。

训练模型

定义回调
检查点回调会保存模型的最佳权重,这样下次我们想使用模型时,就不必再花时间训练它了。早期停止回调会在模型开始停滞不前,或者更糟糕的是模型开始过度拟合时停止训练过程。

checkpoint_cb = keras.callbacks.ModelCheckpoint("xray_model.keras", save_best_only=True)

early_stopping_cb = keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)

我们还需要调整学习率。学习率太高会导致模型发散。学习率太小会导致模型太慢。我们将在下文中采用指数学习率调度方法。

initial_learning_rate = 0.015
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

拟合模型


对于我们的指标,我们希望包括精确度和召回率,因为它们能让我们更清楚地了解我们的模型有多好。精确度告诉我们有多少标签是正确的。由于我们的数据并不均衡,因此准确率可能会对一个好的模型产生偏差(例如,一个总是预测 PNEUMONIA 的模型准确率为 74%,但并不是一个好的模型)。

精确度是真阳性(TP)与假阳性(FP)之和的比值。它显示了标签阳性结果中真正正确的比例。

Recall 是 TP 与假阴性(FN)之和的 TP 数。它显示了实际阳性结果中正确率的百分比。

由于图像只有两种可能的标签,我们将使用二元交叉熵损失。在拟合模型时,请记住指定我们之前定义的类权重。因为我们使用的是 TPU,所以训练时间会很快,不到 2 分钟。

with strategy.scope():
    model = build_model()

    METRICS = [
        keras.metrics.BinaryAccuracy(),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
    ]
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="binary_crossentropy",
        metrics=METRICS,
    )

history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    class_weight=class_weight,
    callbacks=[checkpoint_cb, early_stopping_cb],
)
Epoch 1/100
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

21/21 [==============================] - 12s 568ms/step - loss: 0.5857 - binary_accuracy: 0.6960 - precision: 0.8887 - recall: 0.6733 - val_loss: 34.0149 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 2/100
21/21 [==============================] - 3s 128ms/step - loss: 0.2916 - binary_accuracy: 0.8755 - precision: 0.9540 - recall: 0.8738 - val_loss: 97.5194 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 3/100
21/21 [==============================] - 4s 167ms/step - loss: 0.2384 - binary_accuracy: 0.9002 - precision: 0.9663 - recall: 0.8964 - val_loss: 27.7902 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 4/100
21/21 [==============================] - 4s 173ms/step - loss: 0.2046 - binary_accuracy: 0.9145 - precision: 0.9725 - recall: 0.9102 - val_loss: 10.8302 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 5/100
21/21 [==============================] - 4s 174ms/step - loss: 0.1841 - binary_accuracy: 0.9279 - precision: 0.9733 - recall: 0.9279 - val_loss: 3.5860 - val_binary_accuracy: 0.7103 - val_precision: 0.7162 - val_recall: 0.9879
Epoch 6/100
21/21 [==============================] - 4s 185ms/step - loss: 0.1600 - binary_accuracy: 0.9362 - precision: 0.9791 - recall: 0.9337 - val_loss: 0.3014 - val_binary_accuracy: 0.8895 - val_precision: 0.8973 - val_recall: 0.9555
Epoch 7/100
21/21 [==============================] - 3s 130ms/step - loss: 0.1567 - binary_accuracy: 0.9393 - precision: 0.9798 - recall: 0.9372 - val_loss: 0.6763 - val_binary_accuracy: 0.7810 - val_precision: 0.7760 - val_recall: 0.9771
Epoch 8/100
21/21 [==============================] - 3s 131ms/step - loss: 0.1532 - binary_accuracy: 0.9421 - precision: 0.9825 - recall: 0.9385 - val_loss: 0.3169 - val_binary_accuracy: 0.8895 - val_precision: 0.8684 - val_recall: 0.9973
Epoch 9/100
21/21 [==============================] - 4s 184ms/step - loss: 0.1457 - binary_accuracy: 0.9431 - precision: 0.9822 - recall: 0.9401 - val_loss: 0.2064 - val_binary_accuracy: 0.9273 - val_precision: 0.9840 - val_recall: 0.9136
Epoch 10/100
21/21 [==============================] - 3s 132ms/step - loss: 0.1201 - binary_accuracy: 0.9521 - precision: 0.9869 - recall: 0.9479 - val_loss: 0.4364 - val_binary_accuracy: 0.8605 - val_precision: 0.8443 - val_recall: 0.9879
Epoch 11/100
21/21 [==============================] - 3s 127ms/step - loss: 0.1200 - binary_accuracy: 0.9510 - precision: 0.9863 - recall: 0.9469 - val_loss: 0.5197 - val_binary_accuracy: 0.8508 - val_precision: 1.0000 - val_recall: 0.7922
Epoch 12/100
21/21 [==============================] - 4s 186ms/step - loss: 0.1077 - binary_accuracy: 0.9581 - precision: 0.9870 - recall: 0.9559 - val_loss: 0.1349 - val_binary_accuracy: 0.9486 - val_precision: 0.9587 - val_recall: 0.9703
Epoch 13/100
21/21 [==============================] - 4s 173ms/step - loss: 0.0918 - binary_accuracy: 0.9650 - precision: 0.9914 - recall: 0.9611 - val_loss: 0.0926 - val_binary_accuracy: 0.9700 - val_precision: 0.9837 - val_recall: 0.9744
Epoch 14/100
21/21 [==============================] - 3s 130ms/step - loss: 0.0996 - binary_accuracy: 0.9612 - precision: 0.9913 - recall: 0.9559 - val_loss: 0.1811 - val_binary_accuracy: 0.9419 - val_precision: 0.9956 - val_recall: 0.9231
Epoch 15/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0898 - binary_accuracy: 0.9643 - precision: 0.9901 - recall: 0.9614 - val_loss: 0.1525 - val_binary_accuracy: 0.9486 - val_precision: 0.9986 - val_recall: 0.9298
Epoch 16/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0941 - binary_accuracy: 0.9621 - precision: 0.9904 - recall: 0.9582 - val_loss: 0.5101 - val_binary_accuracy: 0.8527 - val_precision: 1.0000 - val_recall: 0.7949
Epoch 17/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0798 - binary_accuracy: 0.9636 - precision: 0.9897 - recall: 0.9607 - val_loss: 0.1239 - val_binary_accuracy: 0.9622 - val_precision: 0.9875 - val_recall: 0.9595
Epoch 18/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0821 - binary_accuracy: 0.9657 - precision: 0.9911 - recall: 0.9623 - val_loss: 0.1597 - val_binary_accuracy: 0.9322 - val_precision: 0.9956 - val_recall: 0.9096
Epoch 19/100
21/21 [==============================] - 3s 143ms/step - loss: 0.0800 - binary_accuracy: 0.9657 - precision: 0.9917 - recall: 0.9617 - val_loss: 0.2538 - val_binary_accuracy: 0.9109 - val_precision: 1.0000 - val_recall: 0.8758
Epoch 20/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0605 - binary_accuracy: 0.9738 - precision: 0.9950 - recall: 0.9694 - val_loss: 0.6594 - val_binary_accuracy: 0.8566 - val_precision: 1.0000 - val_recall: 0.8003
Epoch 21/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0726 - binary_accuracy: 0.9733 - precision: 0.9937 - recall: 0.9701 - val_loss: 0.0593 - val_binary_accuracy: 0.9816 - val_precision: 0.9945 - val_recall: 0.9798
Epoch 22/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0577 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 0.1087 - val_binary_accuracy: 0.9729 - val_precision: 0.9931 - val_recall: 0.9690
Epoch 23/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0652 - binary_accuracy: 0.9729 - precision: 0.9924 - recall: 0.9707 - val_loss: 1.8465 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 24/100
21/21 [==============================] - 3s 124ms/step - loss: 0.0538 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 1.5769 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 25/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0549 - binary_accuracy: 0.9776 - precision: 0.9954 - recall: 0.9743 - val_loss: 0.0590 - val_binary_accuracy: 0.9777 - val_precision: 0.9904 - val_recall: 0.9784
Epoch 26/100
21/21 [==============================] - 3s 131ms/step - loss: 0.0677 - binary_accuracy: 0.9719 - precision: 0.9924 - recall: 0.9694 - val_loss: 2.6008 - val_binary_accuracy: 0.6928 - val_precision: 0.9977 - val_recall: 0.5735
Epoch 27/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0469 - binary_accuracy: 0.9833 - precision: 0.9971 - recall: 0.9804 - val_loss: 1.0184 - val_binary_accuracy: 0.8605 - val_precision: 0.9983 - val_recall: 0.8070
Epoch 28/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0501 - binary_accuracy: 0.9790 - precision: 0.9961 - recall: 0.9755 - val_loss: 0.3737 - val_binary_accuracy: 0.9089 - val_precision: 0.9954 - val_recall: 0.8772
Epoch 29/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0548 - binary_accuracy: 0.9798 - precision: 0.9941 - recall: 0.9784 - val_loss: 1.2928 - val_binary_accuracy: 0.7907 - val_precision: 1.0000 - val_recall: 0.7085
Epoch 30/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0370 - binary_accuracy: 0.9860 - precision: 0.9980 - recall: 0.9829 - val_loss: 0.1370 - val_binary_accuracy: 0.9612 - val_precision: 0.9972 - val_recall: 0.9487
Epoch 31/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0585 - binary_accuracy: 0.9819 - precision: 0.9951 - recall: 0.9804 - val_loss: 1.1955 - val_binary_accuracy: 0.6870 - val_precision: 0.9976 - val_recall: 0.5655
Epoch 32/100
21/21 [==============================] - 3s 140ms/step - loss: 0.0813 - binary_accuracy: 0.9695 - precision: 0.9934 - recall: 0.9652 - val_loss: 1.0394 - val_binary_accuracy: 0.8576 - val_precision: 0.9853 - val_recall: 0.8138
Epoch 33/100
21/21 [==============================] - 3s 128ms/step - loss: 0.1111 - binary_accuracy: 0.9555 - precision: 0.9870 - recall: 0.9524 - val_loss: 4.9438 - val_binary_accuracy: 0.5911 - val_precision: 1.0000 - val_recall: 0.4305
Epoch 34/100
21/21 [==============================] - 3s 130ms/step - loss: 0.0680 - binary_accuracy: 0.9726 - precision: 0.9921 - recall: 0.9707 - val_loss: 2.8822 - val_binary_accuracy: 0.7267 - val_precision: 0.9978 - val_recall: 0.6208
Epoch 35/100
21/21 [==============================] - 4s 187ms/step - loss: 0.0784 - binary_accuracy: 0.9712 - precision: 0.9892 - recall: 0.9717 - val_loss: 0.3940 - val_binary_accuracy: 0.9390 - val_precision: 0.9942 - val_recall: 0.9204

可视化模型性能

让我们绘制训练集和验证集的模型准确率和损失图。请注意,本笔记本没有指定随机种子。对于您的笔记本电脑,可能会有轻微差异。

fig, ax = plt.subplots(1, 4, figsize=(20, 3))
ax = ax.ravel()

for i, met in enumerate(["precision", "recall", "binary_accuracy", "loss"]):
    ax[i].plot(history.history[met])
    ax[i].plot(history.history["val_" + met])
    ax[i].set_title("Model {}".format(met))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(met)
    ax[i].legend(["train", "val"])

我们看到,我们模型的准确率约为 95%。

预测和评估结果

让我们在测试数据上对模型进行评估!

model.evaluate(test_ds, return_dict=True)
4/4 [==============================] - 3s 708ms/step - loss: 0.9718 - binary_accuracy: 0.7901 - precision: 0.7524 - recall: 0.9897

{'binary_accuracy': 0.7900640964508057,
 'loss': 0.9717951416969299,
 'precision': 0.752436637878418,
 'recall': 0.9897436499595642}

我们发现,测试数据的准确率低于验证集的准确率。这可能表明拟合过度。

我们的召回率大于精确率,这表明几乎所有肺炎图像都被正确识别,但一些正常图像被错误识别。我们应该努力提高精确度。

for image, label in test_ds.take(1):
    plt.imshow(image[0] / 255.0)
    plt.title(CLASS_NAMES[label[0].numpy()])

prediction = model.predict(test_ds.take(1))[0]
scores = [1 - prediction, prediction]

for score, name in zip(scores, CLASS_NAMES):
    print("This image is %.2f percent %s" % ((100 * score), name))

执行如下:

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index
  This is separate from the ipykernel package so we can avoid doing imports until

This image is 47.19 percent NORMAL
This image is 52.81 percent PNEUMONIA


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/527598.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

达梦(DM)报错[-3209]: 无效的存储参数

[TOC](达梦(DM)报错[-3209]: 无效的存储参数) 最近有一个项目,一直使用的是达梦数据库,今天遇到了一个问题,就是将测试环境新增加的表导入线上时报错 [-3209]: 无效的存储参数,这里我用我本地的达梦数据库复现一下这个问题&#…

【HTML】简单制作一个动态变色光束花

目录 前言 开始 HTML部分 效果图 ​编辑​编辑​编辑​编辑总结 前言 无需多言,本文将详细介绍一段代码,具体内容如下: 开始 首先新建文件夹,创建一个文本文档,其中HTML的文件名改为[index.html]&a…

python-flask后端知识点

anki 简单介绍: 在当今信息爆炸的时代,学习已经不再仅仅是获取知识,更是一项关于有效性和持续性的挑战。幸运的是,我们有幸生活在一个科技日新月异的时代,而ANKI(Anki)正是一款旗舰级的学习工具…

【深度学习】环境搭建ubuntu22.04

清华官网的conda源 https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/ 安装torch conda install pytorch torchvision torchaudio pytorch-cuda12.1 -c pytorch -c nvidia 2.2.2 conda install 指引看这里: ref:https://docs.nvidia.com/cuda/cuda-installatio…

高创新 | Matlab实现OOA-CNN-GRU-Attention鱼鹰算法优化卷积门控循环单元注意力机制多变量回归预测

高创新 | Matlab实现OOA-CNN-GRU-Attention鱼鹰算法优化卷积门控循环单元注意力机制多变量回归预测 目录 高创新 | Matlab实现OOA-CNN-GRU-Attention鱼鹰算法优化卷积门控循环单元注意力机制多变量回归预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现OOA…

css实现各级标题自动编号

本文在博客同步发布,您也可以在这里看到最新的文章 Markdown编辑器大多不会提供分级标题的自动编号功能,但我们可以通过简单的css样式设置实现。 本文介绍了使用css实现各级标题自动编号的方法,本方法同样适用于typora编辑器和wordpress主题…

Qt案例 通过调用Setupapi.h库实现对设备管理器中设备默认驱动的备份

参考腾讯电脑管家-软件市场中的驱动备份专家写的一个驱动备份软件案例,学习Setupapi.h库中的函数使用.通过Setupapi.h库读取设备管理器中安装的设备获取安装的驱动列表,通过bit7z库备份驱动目录下的所有文件. 目录导读 实现效果相关内容示例获取SP_DRVIN…

计算机网络-运输层

运输层 湖科大计算机网络 参考笔记,如有侵权联系删除 概述 运输层的任务:如何为运行在不同主机上的应用进程提供直接的通信服务 运输层协议又称端到端协议 运输层使应用进程看见的好像是在两个运输层实体之间有一条端到端的逻辑通信信道 运输层为应…

Github上传大文件(>25MB)教程

0.在github中创建新的项目(已创建可忽略这一步) 如上图所示,点击New repository 进入如下页面: 1.下载Git LFS 下载git 2.打开gitbash 3.上传文件,代码如下: cd upload #进入名为upload的文件夹,提前…

k8s集群node节点状态为Not Ready

目录 一、Node节点Not Ready状态的可能原因 二、排查node节点状态为Not Ready的原因 一、Node节点Not Ready状态的可能原因 node节点状态为Not Ready可能的原因有: 1.网络插件出问题 有过安装经验的小伙伴应该很熟悉未安装网络插件的情况下node节点在集群中的状…

【MacOs】proxychains配置使用

一、开始 1. 安装proxychains 使用brew进行安装 brew install proxychains-ng没有homebrew的,可以使用该命令安装 /usr/bin/ruby -e "$(curl -fsSL https://cdn.jsdelivr.net/gh/ineo6/homebrew-install/install)"2. 配置代理配置文件 cd /opt/homeb…

AUTOSAR配置工具开发教程 - 开篇

简介 本系列的教程,主要讲述如何自己开发一套简单的AUTOSAR ECU配置工具。适用于有C# WPF基础的人员。 简易介绍见:如何打造AUTOSAR工具_autosar_mod_ecuconfigurationparameters-CSDN博客 实现版本 AUTOSAR 4.0.3AUTOSAR 4.2.2AUTOSAR 4.4.0 效果 …

麻雀优化算法(Sparrow Search Algorithm)

注意:本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 ([www.aideeplearning.cn]) 算法背景 麻雀算法(Sparrow Search Algorithm, SSA)是一种受自然界麻雀群体行为启发的优化算法。想象一下,一…

Linux学习-网络UDP

网络 数据传输,数据共享 网络协议模型 OSI协议模型 应用层 实际发送的数据 表示层 发送的数据是否加密 会话层 是否建立会话连接 传输层 数据传输的方式(数据报、流式&#…

esp32上PWM呼吸灯

1、什么是pwm PWM(Pulse Width Modulation)简称脉宽调制,是利用微处理器的数字输出来对模拟电路进行控制的一种非常有效的技术,广泛应用在测量、通信、工控等方面。 1.1频率 单位时间内PWM方波重复的次数 1.2占空比 一个周期内…

HarmonyOS 应用开发-根据icon自适应背景颜色

介绍 本示例将介绍如何根据图片设置自适应的背景色。 效果图预览 使用说明 转换图片为PixelMap,取出所有像素值遍历所有像素值,查找到出现次数最多的像素,即为图片的主要颜色适当修改图片的主要颜色,作为自适应的背景色 实现思…

云岚到家项目

一.项目介绍 云岚到家项目是一个家政服务o2o平台,互联网家政是继打车、外卖后的又一个风口,创业者众多,比如:58到家,天鹅到家等,o2o(Online To Offline)是将线下商务的机会与互联网…

负荷预测 | Matlab基于TCN-BiGRU-Attention单输入单输出时间序列多步预测

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab基于TCN-BiGRU-Attention单输入单输出时间序列多步预测; 2.单变量时间序列数据集,采用前12个时刻预测未来96个时刻的数据; 3.excel数据方便替换,运行环境matlab…

高端大气自适应全屏酷炫渐变卡片html源码图片切换特效html5源码导航引导网站源码

源码特点: 1:手工书写DIVCSS、代码精简无冗余。 2:自适应结构,全球先进技术,高端视觉体验。 3:SEO框架布局,栏目及文章页均可独立设置标题/关键词/描述。 4:附带测试数据、安装教程、…

说说对WebSocket的理解?应用场景?

一、是什么 WebSocket,是一种网络传输协议,位于OSI模型的应用层。可在单个TCP连接上进行全双工通信,能更好的节省服务器资源和带宽并达到实时通迅 客户端和服务器只需要完成一次握手,两者之间就可以创建持久性的连接&#xff0c…