PyTorch 实战之水果分类

当我们试图提高神经网络的准确性时,经常会遇到过拟合训练数据的情况。当我们运行测试数据的模型时,这会导致一个糟糕的预测。因此,我采取了一个数据集,并应用这些技术,不仅提高准确性,而且还处理过拟合的问题。

在本文中,我们将使用以下技术在不到5分钟的时间内训练一个最先进的模型,以达到从 Fruit 360数据集中分类图像的95% 以上的准确率:

数据增强

数据分析中的数据增强是通过对现有数据或从现有数据中新创建的合成数据进行稍加修改的“副本”来增加数据量的技术。在训练机器学习模型时,它起到调节器的作用,有助于减少过拟合。

批量归一化

批量归一化是一种训练非常深入的神经网络的技术,它标准化每个小批量的输入到一个层。这有助于稳定学习过程,大大减少训练深层网络所需的训练时期。

学习率策略

学习率策略用于寻求调整学习率在训练期间通过降低学习率根据预先确定的调度器。常见的学习率策略包括基于时间的衰减,阶跃衰减和指数衰减。

权重衰减

我们使用权重衰减来保持较小的权重值,避免梯度爆炸。因为权重值将会经过L2标准化后加入到损失中,你的网络的每次迭代除了损失之外都会试图优化/最小化模型权重。这将有助于保持尽可能小的权重值,防止权重增长失控,从而避免梯度爆炸。

梯度裁剪

使用梯度裁剪可以防止梯度在神经网络中爆炸。梯度裁剪限制梯度的大小。计算梯度裁剪的方法有很多种,但一种常见的方法是重新调整梯度。

Adam 优化器

这一系列的优化器被引入来解决梯度下降法的算法问题。它们最重要的特点是不需要调整学习率值。实际上,有些库ーー例如 Keras ーー仍然可以让您手动调整它,以便进行更高级的试验。

关于数据集

水果在当今世界非常普遍ーー尽管有大量的快餐和精制食品,但水果仍然是人们广泛食用的食物。举个例子,在水果的生产过程中,可能需要对它们进行分类。传统上被机械地执行,今天,基于深度学习的技术可以增强甚至接管这个过程。

目录

  1. 引言

  2. 数据预处理

  3. 探索数据集

  4. 应用数据增强

  5. 访问少量样本图片

  6. 访问 GPU

  7. 配置模型

  8. 模型训练及成果

  9. 预测

  10. 摘要

1. 引言

图片总数:90483。

训练集大小:67692图像(每图像一个水果或蔬菜)。

测试集大小:22688图像(每张图像一个水果或蔬菜)。

类别数目:131(水果和蔬菜)。

图像大小:100x100像素。

同一种水果的不同品种(例如苹果)属于不同的类别。

2. 数据预处理

导入所需的库

因为我们使用 PyTorch 来构建神经网络,所以我一次性导入所有相关的库。

import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

 下载数据集

# Upload kaggle.jason
# please follow this link incase not aware: https://www.kaggle.com/general/74235
from google.colab import files
files.upload()

! pip install opendatasets --upgrade
import opendatasets as od

dataset_url = 'https://www.kaggle.com/moltean/fruits'
od.download(dataset_url)

在运行任何探索之前,数据集必须加载到 DataLoader。我们使用 PyTorch 的 ImageFolder 将图像加载到 DataLoader。

3. 探索数据集

问: 训练和测试数据集包含多少图像?

dataset_size = len(train_dataset)
test_dataset_size = len(test_dataset)
print(train_dataset)
print(test_dataset)

输出:

Dataset ImageFolder
    Number of datapoints: 67692
    Root location: /content/fruits/fruits-360/Training
    StandardTransform
Transform: ToTensor()
Dataset ImageFolder
    Number of datapoints: 22688
    Root location: /content/fruits/fruits-360/Test
    StandardTransform
Transform: ToTensor()

问:数据集包含多少个输出类?

# Accesssing the classes
data_dir = '/content/fruits/fruits-360/'

# print(os.listdir(data_dir))
classes = os.listdir(data_dir + "Training")
print(f'Total Number of Classe {len(classes)}')
print(f'Classes Names: {classes}')

输出:

Total Number of Classe 131
Classes Names: ['Apple Braeburn', 'Cherry Wax Red', 'Melon Piel de Sapo', 'Rambutan', 'Tamarillo', 'Pepino', 'Lemon', 'Tomato Cherry Red', 'Apple Golden 1', 'Peach Flat', 'Apple Red Delicious', 'Lemon Meyer', 'Banana Red', 'Orange', 'Peach 2', 'Pepper Red', 'Grape White', 'Kaki', 'Pepper Yellow', 'Salak', 'Potato White', 'Cucumber Ripe 2', 'Apple Golden 2', 'Pitahaya Red', 'Mulberry', 'Carambula', 'Pear Abate', 'Banana', 'Tomato Maroon', 'Pear Red', 'Pear Forelle', 'Pineapple', 'Ginger Root', 'Potato Red', 'Apple Pink Lady', 'Pear Kaiser', 'Mandarine', 'Strawberry', 'Apple Golden 3', 'Nectarine', 'Plum 3', 'Avocado ripe', 'Cantaloupe 2', 'Fig', 'Tomato 1', 'Tomato Heart', 'Passion Fruit', 'Grape Blue', 'Cantaloupe 1', 'Apple Granny Smith', 'Banana Lady Finger', 'Mango Red', 'Cherry Rainier', 'Corn Husk', 'Hazelnut', 'Pear', 'Cauliflower', 'Pear Williams', 'Tangelo', 'Avocado', 'Physalis', 'Chestnut', 'Onion White', 'Granadilla', 'Strawberry Wedge', 'Plum 2', 'Plum', 'Pepper Green', 'Tomato 3', 'Grape White 4', 'Quince', 'Maracuja', 'Apple Red 1', 'Grapefruit White', 'Cherry 1', 'Walnut', 'Grape White 2', 'Cactus fruit', 'Grape Pink', 'Potato Red Washed', 'Apple Red Yellow 1', 'Cherry 2', 'Pear 2', 'Huckleberry', 'Guava', 'Apple Red 2', 'Beetroot', 'Limes', 'Kiwi', 'Tomato 2', 'Pear Stone', 'Grapefruit Pink', 'Peach', 'Mango', 'Nut Forest', 'Cherry Wax Yellow', 'Eggplant', 'Clementine', 'Pear Monster', 'Nectarine Flat', 'Pepper Orange', 'Onion Red Peeled', 'Cocos', 'Grape White 3', 'Redcurrant', 'Dates', 'Tomato Yellow', 'Pomegranate', 'Pineapple Mini', 'Pomelo Sweetie', 'Papaya', 'Corn', 'Cucumber Ripe', 'Onion Red', 'Nut Pecan', 'Potato Sweet', 'Cherry Wax Black', 'Physalis with Husk', 'Mangostan', 'Tomato not Ripened', 'Tomato 4', 'Apricot', 'Kumquats', 'Apple Red Yellow 2', 'Kohlrabi', 'Lychee', 'Apple Crimson Snow', 'Blueberry', 'Raspberry', 'Watermelon', 'Apple Red 3']

问: 数据集中的图像张量的形状是什么?

img, label = train_dataset[0]
img_shape = img.shape
img_shape

输出:

torch.Size([3, 100, 100])

让我们打印一个示例图像及其类和标签。

img, label = train_dataset[0]
plt.imshow(img.permute((1, 2, 0)))
print('Label (numeric):', label)
print('Label (textual):', classes[label])

输出:

图片

问: 你能确定属于每个类别的图像数量吗?

dataset_size = len(train_dataset)
classes = train_dataset.classes
num_classes = len(train_dataset.classes)

img_dict = {}
for i in range(num_classes):
    img_dict[classes[i]] = 0
    
for i in range(dataset_size):
    img, label = train_dataset[i]
    img_dict[classes[label]] += 1


## Plotting classes along with images info

from matplotlib import pyplot as plt

fig, ax = plt.subplots(figsize =(16, 32))

ax.barh(list(img_dict.keys()), list(img_dict.values()))



# Add Plot Title
ax.set_title('Each Class along with their values',
             loc ='left', )

# Add annotation to bars
for i in ax.patches:
    plt.text(i.get_width()+0.2, i.get_y()+0.5,
             str(round((i.get_width()), 2)),
             fontsize = 10, fontweight ='bold',
             color ='grey')

# Add Text watermark
fig.text(0.9, 0.15, 'gurjeet333', fontsize = 12,
         color ='grey', ha ='right', va ='bottom',
         alpha = 0.7)
plt.show()

图片

4. 应用数据增强

构建数据转换

我们将首先编写转换函数,以便实现数据增强。

# Data transforms (data augmentation)
train_tfms = tt.Compose([
                         tt.RandomCrop(100, padding=4, padding_mode='reflect'), 
                         tt.Resize((100,100)),
                         tt.RandomHorizontalFlip(), 
                         # tt.RandomRotate
                         # tt.RandomResizedCrop(256, scale=(0.5,0.9), ratio=(1, 1)), 
                         # tt.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
                         tt.ToTensor(), 

                         ])
valid_tfms = tt.Compose([tt.ToTensor(), tt.Resize((100,100)),
                         ])

请随时尝试其他参数,如 tt.RandomResized,tt.ColorJitter(我在代码中添加了它们)

将转换应用于数据集

我们构建的转换需要应用于训练和测试数据集。注意:我们不在测试数据集中应用数据增强。

dataset = ImageFolder(data_dir + "/Training", transform=train_tfms)

Testing = ImageFolder(data_dir + "/Test", transform=valid_tfms)

分割数据集

我们将使用来自训练集的20%作为验证集。为了确保每次获得相同的验证集,我们将 PyTorch 的随机数生成器设置种子值为43。

torch.manual_seed(43)
val_size = round(len(dataset) * 0.2)
train_size = round(len(dataset) - val_size)

train_ds, val_ds = random_split(dataset, [train_size, val_size])
len(train_ds), len(val_ds)
batch_size=400

train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size*2, num_workers=4, pin_memory=True)

5. 访问样本

让我们使用来自 Torchvision 的 make_grid 函数来可视化一批数据。

def show_batch(dl):
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(12, 12))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))
        break

show_batch(train_loader)

图片

你能通过观察标记所有的图像吗?尝试手动标记一个随机的数据样本是一个很好的方式来估计问题的难度,并识别标记错误

6. 使用GPU

如果你的电脑有连接到 NVIDIA 生产的 GPU 上,你可以使用一个图形处理器图形处理器(GPU)来更快地训练你的模型。按照以下说明在你选择的平台上使用 GPU:

  • Google Colab:使用菜单选项“ Runtime > Change Runtime Type”,从“ Hardware Accelerator”下拉菜单中选择“ GPU”

  • Kaggle:在侧边栏的“设置”部分,从“ Accelerator”下拉菜单中选择“ GPU”,使用右上角的按钮打开侧边栏

  • Binder:运行在Binder上的代码不能使用 GPU

  • Linux:如果您的笔记本/台式机有 NVIDIA GPU (显卡) ,请确保您已经安装了 NVIDIA CUDA 驱动程序

  • Windows:如果你的笔记本/台式机有 NVIDIA GPU (显卡) ,请确保你已经安装了 NVIDIA CUDA 驱动程序。

  • macOS:macOS 与 NVIDIA GPU 不兼容

  • 如果你不能访问 GPU 或者不确定它是什么,不要担心,你可以在没有 GPU 的情况下很好地执行本教程中的所有代码

让我们从安装和导入所需的库开始。

def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

现在我使用 DeviceDataLoader 函数将训练和验证集加载到 GPU 中。

device = get_default_device()
device

train_dl = DeviceDataLoader(train_loader, device)
valid_dl = DeviceDataLoader(val_loader, device)

7. 配置模型

设置精度函数和图像基类

两者都是通用函数,不需要对任何数据集进行任何更改。这些是计算精度的辅助函数,并实现损失函数来计算模型的训练和验证损失。

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['train_loss'], result['val_loss'], result['val_acc']))
class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['train_loss'], result['val_loss'], result['val_acc']))

实现批量标准化和Dropout

我们用 nn.Squential 将多层神经网络链接在一起。我在代码中添加了注释,以便简单地理解。注意——在这里我使用 nn.BatchNorm2d 在每一层的末尾实现批量归一化。

class Fruit360CnnModel(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            
            nn.Conv2d(3, 16, kernel_size=2, padding=1), 
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # 16 X 50 X 50

            nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=1), 
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # 32 X 25 X 25

            nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(5, 5), # 64 X 5 X 5

            nn.Flatten(), 
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(64*5*5, 131))
        
    def forward(self, xb):
      return self.network(xb)

model = Fruit360CnnModel()
model

输出:

Fruit360CnnModel(
  (network): Sequential(
    (0): Conv2d(3, 16, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Dropout(p=0.3, inplace=False)
    (14): ReLU()
    (15): Linear(in_features=1600, out_features=131, bias=True)
  )
)

实现权重衰减,梯度裁剪,Adam 优化


@torch.no_grad()

def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []
    
    # Set up cutom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            
            # Gradient clipping
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history


# Moving the model to GPU
model = to_device(model, device)
model

输出:

Fruit360CnnModel(
  (network): Sequential(
    (0): Conv2d(3, 16, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Dropout(p=0.3, inplace=False)
    (14): ReLU()
    (15): Linear(in_features=1600, out_features=131, bias=True)
  )
)

模型在训练前的表现似乎很差。正如下面观察到的,模型的准确度低于1% ,因为模型试图随机猜测输出。

os.environ['WANDB_CONSOLE'] = 'off'

history = [evaluate(model, valid_dl)]
history

输出:

[{'val_acc': 0.005484417546540499, 'val_loss': 4.877397537231445}]

8. 模型训练和结果分析

在训练前设置参数

epochs = 4
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam

运行4个epochs

%%time

os.environ['WANDB_CONSOLE'] = 'off'

history += fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                             grad_clip=grad_clip, 
                             weight_decay=weight_decay, 
                             opt_func=opt_func)

输出:

Epoch [0], train_loss: 1.2414, val_loss: 0.8754, val_acc: 0.7948
Epoch [1], train_loss: 0.1211, val_loss: 0.0212, val_acc: 0.9931
Epoch [2], train_loss: 0.0101, val_loss: 0.0025, val_acc: 0.9996
Epoch [3], train_loss: 0.0049, val_loss: 0.0011, val_acc: 0.9999
CPU times: user 11.8 s, sys: 7.95 s, total: 19.8 s
Wall time: 3min 7

在4个epoch4分钟以内,我们取得了很好的精度。

Accuracy vs No

def plot_accuracies(history):
    accuracies = [x['val_acc'] for x in history]
    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs');

plot_accuracies(history)

图片

Loss vs epochs

def plot_losses(history):
    train_losses = [x.get('train_loss') for x in history]
    val_losses = [x['val_loss'] for x in history]
    plt.plot(train_losses, '-bx')
    plt.plot(val_losses, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs');

plot_losses(history)

图片

由于训练和验证的损失不是发散的,而是逐渐收敛的,这表明我们没有过度拟合我们的模型。

学习率

def plot_lrs(history):
    lrs = np.concatenate([x.get('lrs', []) for x in history])
    plt.plot(lrs)
    plt.xlabel('Batch no.')
    plt.ylabel('Learning rate')
    plt.title('Learning Rate vs. Batch no.')

图片

正如预期的那样,学习率开始于一个较低的值,并且在30% 的迭代中逐渐增加到最大值0.01,然后逐渐降低到一个非常小的值。

9. 预测

让我们在测试数据集上进行模型预测

test_tfms = tt.Compose([tt.Resize((100, 100)), 
                        tt.ToTensor()])
test_dataset = ImageFolder(data_dir + "Test", transform=test_tfms)
test_loader = DeviceDataLoader(DataLoader(test_dataset, batch_size), device)
result = evaluate(model, test_loader)
result

输出:

{'val_acc': 0.9883334040641785, 'val_loss': 0.08684124052524567}

验证准确率超过98% ,我们编写了一个辅助函数,获取一个图像并将其应用到模型中

def predict_image(img, model):
    # Convert to a batch of 1
    xb = to_device(img.unsqueeze(0), device)
    # Get predictions from model
    yb = model(xb)
    # Pick index with highest probability
    _, preds  = torch.max(yb, dim=1)
    # Retrieve the class label
    return dataset.classes[preds[0].item()]

我们现在在样本图像上测试预测

img, label = test_dataset[0]
plt.imshow(img.permute(1, 2, 0))
print('Label:', dataset.classes[label], 'Predicted:', predict_image(img, model))

输出:

图片

10. 总结

下面是本教程中用于提高模型性能和减少训练时间的不同技术的总结:

  • 数据增强:我们应用随机变换加载图像时,从训练数据集。具体来说,我们将每张图片填充4个像素,然后随机裁剪100 × 100个像素,然后以50% 的概率水平翻转图片

  • 批量归一化:在每个卷积层之后,我们增加了一个批量归一化层,对前一层的输出进行归一化处理。这有点类似于数据规范化,只不过它应用于一个层的输出,而平均值和标准差是学习参数

  • 学习率策略:不再使用固定的学习率,而是使用学习率调度器,每次训练后调整学习率。在训练过程中,有很多策略可以改变学习率,我们采用了“One Cycle Learning Rate Policy”

  • 权重衰减:我们给优化器增加了权重衰减,这是另一种正则化技术,它通过在损失函数中增加一个附加项来防止权重值变得过大

  • 梯度裁剪:我们还增加了梯度裁剪,这有助于限制梯度值在一个小的范围,以防止不良变化的模型参数

  • Adam 优化器:我们使用 Adam 优化器代替 SGD (随机梯度下降) ,该优化器使用momentum 和 自适应学习率等技术进行更快的训练。还有许多其他的优化器可以选择并进行实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/160985.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI技术:分享8个非常实用的AI绘画网站

目录 1、Midjourney 2、Stable Diffusion Omline 3、Microsoft Designer 4、Craiyon 5、NightCafe Studio 6、Wombo 7、Dalle-2 8、Avatar AI 1、Midjourney 特点:业内标杆,效果最强大 Midjourney是基于diffusion的AI图画艺术生成器。生成图片不…

OpenAI GPT-4 Turbo发布:开创AI新时代

🎥 屿小夏 : 个人主页 🔥个人专栏 : IT杂谈 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一. GPT-4 Turbo的突破1.1上下文长度和控制手段的加强:1.2多模态支持&#xff1a…

vb.net U盘或移动硬盘 插入 自动(静默)复制指定格式文件

U盘或移动硬盘 插入 自动复制指定格式文件至系统盘符 开发语言:vb.net 重要申明:该程序只是防止本人不在电脑旁时,别人偷偷copy你的电脑文件,让他偷鸡成了也要蚀把米。严禁从事黑客或违反道德等不良行为,故而不发布程…

通信原理板块——奇偶监督码、方阵码、恒比码、正反码

微信公众号上线,搜索公众号小灰灰的FPGA,关注可获取相关源码,定期更新有关FPGA的项目以及开源项目源码,包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 1、奇偶监督码(parity check) 奇偶…

微信小程序开发---实现文件上传和下载

在开发小程序的过程中,我们难免会遇到使用小程序对后端发送文件;或者接收后端的文件,本文章将手把手带你简单高效实现微信小程序的文件上传下载功能 前期准备 由于目前小程序保护用户个人隐私力度加大 ,因此我们要想实现文件上传…

uniapp优化h5项目-摇树优化,gzip压缩和删除console.log

1.摇树优化 勾选摇树优化,打包删除死代码 2.gzip压缩和删除console.log 安装插件webpack和compression-webpack-plugin webpack插件 npm install webpack4.46.0 --save-devcompression-webpack-plugin插件 npm install compression-webpack-plugin6.1.1 --save-devconst Com…

为什么C++标准库中atomic shared_ptr不是lockfree实现?

为什么C标准库中atomic shared_ptr不是lockfree实现? 把 shared_ptr 做成 lock_free,应该是没有技术上的可行性。shared_ptr 比一个指针要大不少:最近很多小伙伴找我,说想要一些C的资料,然后我根据自己从业十年经验&am…

CVE-2021-42287CVE-2021-42278 域内提权

倘见玉皇先跪奏:他生永不落红尘 本文首发于先知社区,原创作者即是本人 前言 网络安全技术学习,承认⾃⼰的弱点不是丑事。只有对原理了然于⼼,才能突破更多的限制。拥有快速学习能力的白帽子,是不能有短板的&#xf…

四、hdfs文件系统基础操作-保姆级教程

1、启动Hadoop集群 想要使用hdfs文件系统,就先要启动Hadoop集群。 启动集群: start-dfs.sh 关闭集群: stop-dfs.sh 2、文件系统构成 (1)基础介绍 其实hdfs作为分布式存储的文件系统,其构成和Linux文件系统构成差不多一…

monaco-editor 简单使用

一. 文件调用示例 1. 安装package包 官方文档 "monaco-editor": "^0.28.1", "monaco-editor-webpack-plugin": "^4.2.0", Copy 请注意安装包的版本号 monaco-editor-webpack-pluginmonaco-editor7.*.*> 0.31.06.*.*0.30.*5.*.*…

STM32/N32G455国民科技芯片驱动DS1302时钟---笔记

这次来分享一下DS1302时钟IC,之前听说过这个IC,但是一直没搞过,用了半天时间就明白了原理和驱动,说明还是很简单的。 注:首先来区分一下DS1302和RTC时钟有什么不同,为什么不直接用RTC呢? RTC不…

基于社会群体算法优化概率神经网络PNN的分类预测 - 附代码

基于社会群体算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于社会群体算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于社会群体优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

秋招算法高频算法笔试题

自己在秋招过程中遇到的算法笔试题,包含中大厂,都附解析! 汽水瓶 如果汽水瓶数目为1或者0,那么一瓶都喝不到 如果汽水瓶数目为2或者3,那么只能喝到一瓶 如果为2,喝到一瓶后手里一个瓶子都没有了&#xff…

【数据结构】树与二叉树(十九):树的存储结构——左儿子右兄弟链接结构(树、森林与二叉树的转化)

文章目录 5.1 树的基本概念5.1.1 树的定义5.1.2 森林的定义5.1.3 树的术语 5.2 二叉树5.3 树5.3.1 树的存储结构1. 理论基础2. 典型实例3. Father链接结构4. 儿子链表链接结构5. 左儿子右兄弟链接结构a. 定义树节点b. 创建树节点c. 使用左儿子右兄弟链接结构将树转化为二叉树d.…

NET8 BlazorAuto渲染模式

.NET8发布后,Blazor支持四种渲染方式 静态渲染,这种页面只可显示,不提供交互,可用于网页内容展示使用Blazor Server托管的通过Server交互方式使用WebAssembly托管的在浏览器端交互方式使用Auto自动交互方式,最初使用 …

嵌入式系统中相关的高质量开源项目

关于GitHub,可能很多人误以为这是互联网人的专属,其实并不是,那上面嵌入式相关的开源项目是有很多的。现分享一些高星开源项目(像RT-Thread、AWTK等大家都熟知的就不介绍了):Avem 项目链接: ht…

【广州华锐互动VRAR】VR元宇宙技术在气象卫星知识科普中的应用

随着科技的不断发展,虚拟现实(VR)和元宇宙等技术正逐渐走进我们的生活。这些技术为我们提供了一个全新的互动平台,使我们能够以更加直观和生动的方式了解和学习各种知识。在气象天文领域,VR元宇宙技术的应用也日益显现…

TikTok与媒体素养:如何辨别虚假信息?

在当今数字时代,社交媒体平台如TikTok已经成为信息传播和社交互动的主要渠道之一。然而,随之而来的是虚假信息的泛滥,这对用户的媒体素养提出了严峻的挑战。本文将探讨TikTok平台上虚假信息的现象,以及如何提高媒体素养&#xff0…

python3:turtle绘图 .2023-11-18

绘制一个菱形:四边相等且都为200像素;四个内角两边各为60度,上下各为120度 import turtle #导入turtle #画笔默认绘制方向为水平向右 turtle.right(-30) #画笔绘制方向向左(逆时针)旋转30度. turtle.fd(200) #画笔沿绘制方向绘制200像素长度 turtle.right(60) #画笔绘制方向在…

asp.net心理健康管理系统VS开发sqlserver数据库web结构c#编程计算机网页项目

一、源码特点 asp.net 心理健康管理系统 是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 系统视频链接 https://www.bilibili.com/video/BV19w411H7P4/ 二、功能介绍 本系统使用Microsoft Visual Studio…