Colab/PyTorch - 003 Transfer Learning For Image Classification

Colab/PyTorch - 003 Transfer Learning For Image Classification

  • 1. 源由
  • 2. 迁移学习(ResNet50)
    • 2.1 数据集准备
    • 2.2 数据增强
    • 2.3 数据加载
    • 2.4 迁移学习
    • 2.5 数据集训练&验证
    • 2.6 模型推理
  • 3. 总结
  • 4. 参考资料

1. 源由

迁移学习已经彻底改变了 PyTorch 中处理图像分类的方式。

最近,PyTorch 因其易用性和学习性而受到了广泛关注。特斯拉人工智能高级总监安德烈·卡帕西在他的推特中说了以下的话。
关于 PyTorch 的安德烈·卡帕西 - 转移学习
PyTorch 非常透明,可以帮助研究人员和数据科学家实现高生产力和可靠的结果。

其实,这种方法,有点类似螺旋式上升的旋转阶梯。大量的学习让我们积累知识螺旋式上升到某个层面。当这个层面遇到新的类似问题时,只要经过适当的调整,就能更上一层楼。

而这个适当的调整并不需要大量的推倒重来,更多的是在原有的基础上进行适配调整。

在《Colab/PyTorch - 002 Pre Trained Models for Image Classification》中,放入一个没有训练过的图片,可以看到机器学习给出了一个相近的分类,俗语说“难字读半边”,大体就是这个意思。

在《Jammy@Jetson Orin - Tensorflow & Keras Get Started: Transfer Learning & Fine Tuning》,也讨论过关于Transfer Learning的概念。

2. 迁移学习(ResNet50)

基于ImageNet用数百万张图像进行了训练的ResNet50模型,对CalTech256数据集的子集来对10种动物的图像进行分类就是本次PyTorch的一个例子。

通过据集准备、数据增强、构建分类器、使用迁移学习来利用低级图像特征,比如边缘、纹理等模型参数特征(这些特征是预训练模型ResNet50继承得来的)。最后,通过训练集来进一步训练分类器,从而学习数据集图像中的更高级别的细节,比如眼睛、腿等。

2.1 数据集准备

CalTech256数据集包含30,607张图像,分为256个不同的标签类别,还有一个“杂乱”类别。对整个数据集进行训练需要数小时。

因此,我们将使用数据集的一个子集,其中包含10种动物:熊、黑猩猩、长颈鹿、大猩猩、美洲驼、鸵鸟、豪猪、臭鼬、三角龙和斑马。这些文件夹中的图像数量从81张(对于臭鼬)到212张(对于大猩猩)不等。

我们对这些图片进行如下分类:

  • 每个类别中的前60张图像进行训练
  • 接下来的10张图像用于验证
  • 其余的图像用于下面实验中的测试

因此,最终,我们有600张训练图像、100张验证图像、409张测试图像和10个动物类别。

操作步骤如下:

  1. 下载 CalTech256数据集
  2. 创建名为 train、valid 和 test 的三个目录。
  3. 在 train/valid/test 目录中分别创建 10 个子目录。这些子目录应该命名为 bear、chimp、giraffe、gorilla、llama、ostrich、porcupine、skunk、triceratops 和 zebra。
  4. 将 Caltech256 数据集中的前60张熊的图像移动到目录 train/bear。对每种动物重复此步骤。
  5. 将 Caltech256 数据集中的下一组10张熊的图像移动到目录 valid/bear。对每种动物重复此步骤。
  6. 将熊的剩余图像(即未包含在 train 或 valid 文件夹中的图像)复制到目录 test/bear。对每种动物重复此步骤。

2.2 数据增强

在每个 epoch 中,每个输入图像会经过多次变换,通过在变换中引入一些随机性来插入一些变化。

当训练多个 epoch 时,模型会看到输入图像都是新的随机变换。这导致了数据增强,然后模型试图更加泛化。

下面看到了一张三角龙图像的变换版本的示例:
在这里插入图片描述

# Applying Transforms to the Data
image_transforms = { 
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}
  • transform.RandomResizedCrop 通过随机大小裁剪输入图像(在原始大小的0.8到1.0的范围内,并在默认范围内的0.75到1.33的随机宽高比)。然后将裁剪后的图像调整大小为256×256。
  • transform.RandomRotation 将图像随机旋转一个角度,角度范围为-15度到15度。
  • transform.RandomHorizontalFlip 以默认概率50%随机水平翻转图像。
  • transform.CenterCrop 从中心裁剪出一个224×224的图像。
  • transform.ToTensor 将 PIL 图像转换为浮点数张量,并将值归一化到0-1的范围,方法是将其除以255。
  • transform.Normalize 接受一个3通道张量,并按照该通道的输入均值和标准差对每个通道进行归一化。均值和标准差向量作为3元素向量输入。张量中的每个通道被规范化为 T = (T - 均值) / 标准差。

所有上述转换都使用train的transforms.Compose连接在一起。

对于验证和测试数据,我们不执行RandomResizedCrop、RandomRotation和RandomHorizontalFlip转换。相反,我们将验证图像调整大小为256×256,并裁剪出中心224×224的部分,以便能够将它们与预训练模型一起使用。最后,图像被转换为张量,并按照ImageNet中所有图像的均值和标准差进行归一化处理。

2.3 数据加载

Colab上运行,需要将制作好的数据集上传Google云存储。
在这里插入图片描述

数据存放目录如下所示:

├── caltech_10
│   ├── test
│   │   ├── bear
│   │   ├── chimp
│   │   ├── giraffe
│   │   ├── gorilla
│   │   ├── llama
│   │   ├── ostrich
│   │   ├── porcupine
│   │   ├── skunk
│   │   ├── triceratops
│   │   └── zebra
│   ├── train
│   │   ├── bear
│   │   ├── chimp
│   │   ├── giraffe
│   │   ├── gorilla
│   │   ├── llama
│   │   ├── ostrich
│   │   ├── porcupine
│   │   ├── skunk
│   │   ├── triceratops
│   │   └── zebra
│   └── valid
│       ├── bear
│       ├── chimp
│       ├── giraffe
│       ├── gorilla
│       ├── llama
│       ├── ostrich
│       ├── porcupine
│       ├── skunk
│       ├── triceratops
│       └── zebra
└── image_classification_using_transfer_learning_in_pytorch.ipynb

使用DataLoader加载用于训练的数据:

# Test on Google Drive

from google.colab import drive
drive.mount('/content/drive')

# Load the Data

# Set train and valid directory paths

dataset = '/content/drive/MyDrive/caltech_10'

train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'valid')
test_directory = os.path.join(dataset, 'test')

# Batch size
bs = 32

# Number of classes
num_classes = len(os.listdir(valid_directory))  #10#2#257
print(num_classes)

# Load Data from folders
data = {
    'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
    'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid']),
    'test': datasets.ImageFolder(root=test_directory, transform=image_transforms['test'])
}

# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in data['train'].class_to_idx.items()}
print(idx_to_class)

# Size of Data, to be used for calculating Average Loss and Accuracy
train_data_size = len(data['train'])
valid_data_size = len(data['valid'])
test_data_size = len(data['test'])

# Create iterators for the Data loaded using DataLoader module
train_data_loader = DataLoader(data['train'], batch_size=bs, shuffle=True)
valid_data_loader = DataLoader(data['valid'], batch_size=bs, shuffle=True)
test_data_loader = DataLoader(data['test'], batch_size=bs, shuffle=True)

在这里插入图片描述

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(train_data_size, valid_data_size, test_data_size)

在这里插入图片描述

2.4 迁移学习

从头开始训练分类器是非常困难且耗时的。

因此,使用一个预训练的模型作为基础,并改变最后几层以根据我们想要的类别对图像进行分类。这有助于我们在一个小数据集上获得良好的结果,因为预训练模型已经从像ImageNet这样的更大数据集中学习了基本的图像特征。

在这里插入图片描述
正如上面图像中所看到的,内部层保持与预训练ResNet50模型相同,只有最后几层被更改以适应我们的类别数量。

# Load pretrained ResNet50 Model
resnet50 = models.resnet50(pretrained=True)
resnet50 = resnet50.to(device)

Canziani等人列出了许多预训练模型,用于各种实际应用,分析了每个模型的准确性和推断所需的时间。

ResNet50是一个在准确性和推断时间之间有良好折衷的模型之一。当在PyTorch中加载模型时,默认情况下,所有参数的’requires_grad’字段都设置为true。这意味着对参数值的每一次更改都将被存储以用于训练时使用的反向传播图中。这增加了内存需求。由于我们预训练模型中的大多数参数已经训练过,我们将’requires_grad’字段重置为false。

# Freeze model parameters
for param in resnet50.parameters():
    param.requires_grad = False

接下来,我们将ResNet50模型的最后一层替换为一小组Sequential层。ResNet50的最后一个全连接层的输入被馈送到一个线性层。它有256个输出,然后被馈送到ReLU和Dropout层。然后是一个256×10的线性层,它有10个输出,对应于我们CalTech子集中的10个类别。

# Change the final layer of ResNet50 Model for Transfer Learning
fc_inputs = resnet50.fc.in_features

resnet50.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes), # Since 10 possible outputs
    nn.LogSoftmax(dim=1) # For using NLLLoss()
)

由于我们将在GPU上进行训练,我们准备好将模型移到GPU上。

# Convert model to be used on GPU
resnet50 = resnet50.to(device)

接下来,我们定义用于训练的损失函数和优化器。PyTorch提供了多种损失函数。我们使用负对数似然损失函数,因为它对于多类别分类很有用。PyTorch还支持多种优化器。我们使用Adam优化器。Adam是最流行的优化器之一,因为它可以针对每个参数单独调整学习率。

# Define Optimizer and Loss Function
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())

2.5 数据集训练&验证

训练是在固定的epoch数量内进行的,每个图像在单个epoch中处理一次。训练数据加载器以批量方式加载数据。我们给定了一个批量大小为32。这意味着每个批次最多可以有32张图像。

对于每个批次,输入图像被传递到模型中,即前向传播,以获取输出。然后使用提供的损失函数或成本函数使用真实值和计算得到的输出来计算损失。使用反向函数计算损失相对于可训练参数的梯度。

注:使用迁移学习时,只需要计算属于模型末尾的少数新添加层的一小部分参数的梯度。对模型的摘要函数调用可以显示实际参数数量和可训练参数的数量。这种方法中的优势是只需要训练总模型参数的大约十分之一。

梯度计算是使用自动求导和反向传播完成的,在图中使用链式法则进行微分。PyTorch在反向传播过程中累积所有梯度。因此,在训练循环的开始时将它们清零是至关重要的。这可以通过优化器的zero_grad函数实现。最后,在反向传播过程中计算了梯度后,使用优化器的step函数更新参数。

对整个批次计算总损失和准确度,然后对所有批次进行平均,以获得整个epoch的损失和准确度值。

随着训练进行的增加,模型倾向于过度拟合数据,导致其在新的测试数据上表现不佳。保持一个单独的验证集很重要,这样我们就可以在合适的时机停止训练,并防止过拟合。在每个epoch的训练循环之后立即进行验证。由于在验证过程中我们不需要进行任何梯度计算,所以它是在一个torch.no_grad()块内完成的。

对于每个验证批次,输入和标签被传输到GPU(如果cuda可用,否则它们被传输到CPU)。输入经过前向传播,然后对批次和循环结束时的整个epoch进行损失和准确度计算。

def train_and_validate(model, loss_criterion, optimizer, epochs=25):
    '''
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param loss_criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)
  
    Returns
        model: Trained Model with best validation accuracy
        history: (dict object): Having training loss, accuracy and validation loss, accuracy
    '''
    
    start = time.time()
    history = []
    best_loss = 100000.0
    best_epoch = None

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))
        
        # Set to training mode
        model.train()
        
        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        
        valid_loss = 0.0
        valid_acc = 0.0
        
        for i, (inputs, labels) in enumerate(train_data_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Clean existing gradients
            optimizer.zero_grad()
            
            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)
            
            # Compute loss
            loss = loss_criterion(outputs, labels)
            
            # Backpropagate the gradients
            loss.backward()
            
            # Update the parameters
            optimizer.step()
            
            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            
            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            
            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            
            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)
            
            #print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))

        
        # Validation - No gradient tracking needed
        with torch.no_grad():

            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                #print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch

        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size 
        avg_train_acc = train_acc/train_data_size

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss/valid_data_size 
        avg_valid_acc = valid_acc/valid_data_size

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])
                
        epoch_end = time.time()
    
        print("Epoch : {:03d}, Training: Loss - {:.4f}, Accuracy - {:.4f}%, \n\t\tValidation : Loss - {:.4f}, Accuracy - {:.4f}%, Time: {:.4f}s".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
        
        # Save if the model has best accuracy till now
        torch.save(model, dataset+'_model_'+str(epoch)+'.pt')
            
    return model, history, best_epoch
Epoch: 1/30
Epoch : 000, Training: Loss - 1.6488, Accuracy - 46.3333%, 
		Validation : Loss - 0.6532, Accuracy - 94.0000%, Time: 69.4381s
Epoch: 2/30
Epoch : 001, Training: Loss - 0.5886, Accuracy - 87.1667%, 
		Validation : Loss - 0.3524, Accuracy - 89.0000%, Time: 8.7085s
Epoch: 3/30
Epoch : 002, Training: Loss - 0.3438, Accuracy - 92.1667%, 
		Validation : Loss - 0.2165, Accuracy - 97.0000%, Time: 9.7899s
Epoch: 4/30
Epoch : 003, Training: Loss - 0.2306, Accuracy - 94.8333%, 
		Validation : Loss - 0.2157, Accuracy - 93.0000%, Time: 8.2054s
Epoch: 5/30
Epoch : 004, Training: Loss - 0.1882, Accuracy - 95.5000%, 
		Validation : Loss - 0.1668, Accuracy - 96.0000%, Time: 9.5612s
Epoch: 6/30
Epoch : 005, Training: Loss - 0.1682, Accuracy - 95.3333%, 
		Validation : Loss - 0.1474, Accuracy - 97.0000%, Time: 8.0876s
Epoch: 7/30
Epoch : 006, Training: Loss - 0.1850, Accuracy - 94.8333%, 
		Validation : Loss - 0.1665, Accuracy - 96.0000%, Time: 8.8183s
Epoch: 8/30
Epoch : 007, Training: Loss - 0.1645, Accuracy - 96.5000%, 
		Validation : Loss - 0.1560, Accuracy - 94.0000%, Time: 8.5604s
Epoch: 9/30
Epoch : 008, Training: Loss - 0.1490, Accuracy - 94.3333%, 
		Validation : Loss - 0.1529, Accuracy - 95.0000%, Time: 8.7029s
Epoch: 10/30
Epoch : 009, Training: Loss - 0.1236, Accuracy - 95.8333%, 
		Validation : Loss - 0.1578, Accuracy - 96.0000%, Time: 8.4867s
Epoch: 11/30
Epoch : 010, Training: Loss - 0.0958, Accuracy - 97.8333%, 
		Validation : Loss - 0.2119, Accuracy - 92.0000%, Time: 9.2300s
Epoch: 12/30
Epoch : 011, Training: Loss - 0.1213, Accuracy - 96.5000%, 
		Validation : Loss - 0.1623, Accuracy - 95.0000%, Time: 8.8091s
Epoch: 13/30
Epoch : 012, Training: Loss - 0.1710, Accuracy - 94.0000%, 
		Validation : Loss - 0.2524, Accuracy - 89.0000%, Time: 9.2230s
Epoch: 14/30
Epoch : 013, Training: Loss - 0.1252, Accuracy - 95.6667%, 
		Validation : Loss - 0.1514, Accuracy - 96.0000%, Time: 8.6822s
Epoch: 15/30
Epoch : 014, Training: Loss - 0.0881, Accuracy - 97.6667%, 
		Validation : Loss - 0.1434, Accuracy - 95.0000%, Time: 8.8090s
Epoch: 16/30
Epoch : 015, Training: Loss - 0.0828, Accuracy - 96.8333%, 
		Validation : Loss - 0.1271, Accuracy - 96.0000%, Time: 8.4651s
Epoch: 17/30
Epoch : 016, Training: Loss - 0.0705, Accuracy - 97.8333%, 
		Validation : Loss - 0.1536, Accuracy - 96.0000%, Time: 9.4177s
Epoch: 18/30
Epoch : 017, Training: Loss - 0.0834, Accuracy - 97.6667%, 
		Validation : Loss - 0.1726, Accuracy - 94.0000%, Time: 8.0583s
Epoch: 19/30
Epoch : 018, Training: Loss - 0.0740, Accuracy - 98.1667%, 
		Validation : Loss - 0.1585, Accuracy - 95.0000%, Time: 9.0694s
Epoch: 20/30
Epoch : 019, Training: Loss - 0.0659, Accuracy - 98.3333%, 
		Validation : Loss - 0.2178, Accuracy - 93.0000%, Time: 8.6098s
Epoch: 21/30
Epoch : 020, Training: Loss - 0.0819, Accuracy - 97.8333%, 
		Validation : Loss - 0.1866, Accuracy - 95.0000%, Time: 8.9363s
Epoch: 22/30
Epoch : 021, Training: Loss - 0.0921, Accuracy - 96.5000%, 
		Validation : Loss - 0.1907, Accuracy - 95.0000%, Time: 8.2608s
Epoch: 23/30
Epoch : 022, Training: Loss - 0.0662, Accuracy - 98.0000%, 
		Validation : Loss - 0.1494, Accuracy - 95.0000%, Time: 11.4841s
Epoch: 24/30
Epoch : 023, Training: Loss - 0.0618, Accuracy - 98.0000%, 
		Validation : Loss - 0.1616, Accuracy - 94.0000%, Time: 9.5806s
Epoch: 25/30
Epoch : 024, Training: Loss - 0.0767, Accuracy - 97.0000%, 
		Validation : Loss - 0.1904, Accuracy - 94.0000%, Time: 8.1954s
Epoch: 26/30
Epoch : 025, Training: Loss - 0.0476, Accuracy - 98.8333%, 
		Validation : Loss - 0.1830, Accuracy - 94.0000%, Time: 9.1184s
Epoch: 27/30
Epoch : 026, Training: Loss - 0.0575, Accuracy - 98.3333%, 
		Validation : Loss - 0.2433, Accuracy - 94.0000%, Time: 8.7180s
Epoch: 28/30
Epoch : 027, Training: Loss - 0.0831, Accuracy - 97.1667%, 
		Validation : Loss - 0.2413, Accuracy - 93.0000%, Time: 8.5477s
Epoch: 29/30
Epoch : 028, Training: Loss - 0.0649, Accuracy - 98.0000%, 
		Validation : Loss - 0.2413, Accuracy - 93.0000%, Time: 9.0394s
Epoch: 30/30
Epoch : 029, Training: Loss - 0.0647, Accuracy - 97.5000%, 
		Validation : Loss - 0.1586, Accuracy - 94.0000%, Time: 9.5677s

在这里插入图片描述在这里插入图片描述
如上图所示,对于这个数据集,验证和训练损失都相当快地稳定下来。准确度也迅速提高到了0.9左右的范围。随着epoch数量的增加,训练损失进一步减小,导致过拟合,但验证结果并没有显著改善。因此,我们选择了具有更高准确度和较低损失的epoch的模型。最好是在早期停止以防止过拟合训练数据。在我们的情况下,我们选择了第8个epoch,其验证准确度为96%。

2.6 模型推理

一旦我们有了模型,我们就可以对单个测试图像或整个测试数据集进行推断,以获取测试准确度。测试集准确度的计算与验证代码类似,只是它是在测试数据集上进行的。我们已经在Python笔记本中包含了computeTestSetAccuracy函数来完成相同的工作。让我们在下面讨论如何为给定的测试图像找到输出类别。

首先,输入图像经历用于验证/测试数据的所有转换。然后将结果张量转换为四维张量,并通过模型传递,该模型输出不同类别的对数概率。对模型输出的指数给出了类别概率。然后我们选择具有最高概率的类作为我们的输出类别。选择具有最高概率的类作为我们的输出类别。

def predict(model, test_image_name):
    '''
    Function to predict the class of a single test image
    Parameters
        :param model: Model to test
        :param test_image_name: Test image

    '''
    
    transform = image_transforms['test']


    test_image = Image.open(test_image_name)
    plt.imshow(test_image)
    
    test_image_tensor = transform(test_image)
    if torch.cuda.is_available():
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
    else:
        test_image_tensor = test_image_tensor.view(1, 3, 224, 224)
    
    with torch.no_grad():
        model.eval()
        # Model outputs log probabilities
        out = model(test_image_tensor)
        ps = torch.exp(out)

        topk, topclass = ps.topk(3, dim=1)
        cls = idx_to_class[topclass.cpu().numpy()[0][0]]
        score = topk.cpu().numpy()[0][0]

        for i in range(3):
            print("Predcition", i+1, ":", idx_to_class[topclass.cpu().numpy()[0][i]], ", Score: ", topk.cpu().numpy()[0][i])
# Test a particular model on a test image
! wget https://cdn.pixabay.com/photo/2018/10/01/12/28/skunk-3716043_1280.jpg -O skunk.jpg
dataset = '/content/drive/MyDrive/caltech_10'
model = torch.load("{}_model_{}.pt".format(dataset, best_epoch))
predict(model, 'skunk.jpg')

# Load Data from folders
#computeTestSetAccuracy(model, loss_func)

在这里插入图片描述判断的准确率99.97%。

3. 总结

使用一个在ImageNet的1000个类别上预训练的模型,非常有效地应用了ResNet50模型,通过小数据集的训练,对感兴趣的10个不同类别的图像进行了分类。

测试代码:003 Image Classification using Transfer Learning in Pytorch

4. 参考资料

【1】Colab/PyTorch - Getting Started with PyTorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/614588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

搜歌网搜索各种类型音乐,统统歌曲转换格式mp3,轻松实现音乐自由!

在互联网的广阔天地中&#xff0c;音乐爱好者们总能找到满足自己需求的平台。其中&#xff0c;支持全网搜歌的网站无疑是一个值得推荐的音乐探索乐园。无论是寻找经典老歌&#xff0c;还是发掘新兴音乐&#xff0c;搜他们都能为音乐爱好者提供一站式的服务。 一般支持全网搜索…

微信投票源码系统至尊版 吸粉变现功能二合一

源码简介 微信投票系统在营销和社交互动中发挥着多方面的作用&#xff0c;它能够提升用户的参与度和品牌曝光度&#xff0c;还是一种有效的数据收集、营销推广和民主决策工具。 分享一款微信投票源码系统至尊版&#xff0c;集吸粉变现功能二合一&#xff0c;全网独家支持礼物…

py黑帽子学习笔记_网络编程工具

tcp客户端 socket.AF_INET表示使用标准IPV4地址和主机名 SOCK_STREAM表示这是一个TCP客户端 udp客户端 udp无需连接&#xff0c;因此不需要client.connect这种代码 socket.SOCK_DGRAM是udp的 tcp服务端 server.listen(5)表示设置最大连接数为5 发现kill server后端口仍占用…

【论文阅读笔记】jTrans(ISSTA 22)

个人博客地址 [ISSTA 22] jTrans&#xff08;个人阅读笔记&#xff09; 论文&#xff1a;《jTrans: Jump-Aware Transformer for Binary Code Similarity》 仓库&#xff1a;https://github.com/vul337/jTrans 提出的问题 二进制代码相似性检测&#xff08;BCSD&#xff0…

Stable Diffusion的技术原理

一、Stable Diffusion的技术原理 Stable Diffusion是一种基于Latent Diffusion Models&#xff08;LDMs&#xff09;实现的文本到图像&#xff08;text-to-image&#xff09;生成模型。其工作原理主要基于扩散过程&#xff0c;通过模拟数据的随机演化行为&#xff0c;实现数据…

壹资源知识付费系统源码-小程序端+pc端

最新整理优化&#xff0c;含微信小程序和pc网页。内置几款主题&#xff0c;并且可以自己更改主题样式&#xff0c;各区块颜色&#xff0c;文字按钮等。 适用于知识付费类资源类行业。如&#xff1a;项目类&#xff0c;小吃技术类&#xff0c;图书类&#xff0c;考研资料类&…

修改表空间对应数据文件的大小

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 表空间与数据文件紧密相连&#xff0c;相互依存&#xff0c;创建表空间的时候需设置数据文件大小。 在后期实际应用中&#xff0c;如果实际存储的数据量超出事先设置的数据…

本地运行.net项目

有时候需要我们自己做一个.net的课设项目&#xff0c;但是我们有了代码后却不知道怎么运行。我们0基础来学习一下如何运行一个.net项目 1.安装visual studio 2022 不用安装老版本&#xff0c;新版就可以。安装好了2022版本&#xff0c;这是一个支持web的IDE&#xff0c;我们可…

Java----数组的定义和使用

1.数组的定义 在Java中&#xff0c;数组是一种相同数据类型的集合。数组在内存中是一段连续的空间。 2.数组的创建和初始化 2.1数组的创建 在Java中&#xff0c;数组创建的形式与C语言又所不同。 Java中数组创建的形式 T[] 数组名 new T[N]; 1.T表示数组存放的数据类型…

ARM单片机实现流水灯(GD32)

根据上图可知使用的引脚分别是PA8,PE6,PF6流水灯功能的实现要分别初始化这几个引脚 流水灯实现 编写流水灯代码 LED.C #include "gd32f30x.h" // Device header #include "Delay.h" // 初始化LED灯 void LED_Init(void){// 使能RCU时钟…

词令蚂蚁庄园今日答案如何在微信小程序查看蚂蚁庄园今天问题的正确答案?

词令蚂蚁庄园今日答案如何在微信小程序查看蚂蚁庄园今天问题的正确答案&#xff1f; 1、打开微信&#xff0c;点击搜索框&#xff1b; 2、打开搜索页面&#xff0c;选择小程序搜索&#xff1b; 3、在搜索框&#xff0c;输入词令搜索点击进入词令微信小程序&#xff1b; 4、打开…

1290.二进制链表转整数

给你一个单链表的引用结点 head。链表中每个结点的值不是 0 就是 1。已知此链表是一个整数数字的二进制表示形式。 请你返回该链表所表示数字的 十进制值 。 示例 1&#xff1a; 输入&#xff1a;head [1,0,1] 输出&#xff1a;5 解释&#xff1a;二进制数 (101) 转化为十进制…

三大消息传递机制区别与联系

目录 总结放开头 1、定义区别&#xff1a; EventBus Broadcast Receiver Notification 2、使用区别: EventBus Broadcast Receiver Notification 3、补充通知渠道&#xff1a; 通知渠道重要程度 总结放开头 BroadCast Receiver:属于安卓全局监听机制&#xff0c;接收…

软件开发项目实施方案-精华资料(Word原件)

依据项目建设要求&#xff0c;对平台进行整体规划设计更新维护&#xff0c;对系统运行的安全性、可靠性、易用性以及稳健性进行全新设计&#xff0c;并将所有的应用系统进行部署实施和软件使用培训以及技术支持。 根据施工总进度规划&#xff0c;编制本项目施工进度计划表。依据…

卷积特征图与感受野

特征图尺寸和感受野是卷积神经网络中非常重要的两个概念&#xff0c;今天来看一下&#xff0c;如何计算特征尺寸和感受野。 特征图尺寸 卷积特征图&#xff0c;是图片经过卷积核处理之后的尺寸。计算输出特征的尺寸&#xff0c;需要给出卷积核的相关参数包括&#xff1a; 输…

【C++泛型编程】(二)标准模板库 STL

文章目录 标准模板库 STL容器算法迭代器仿函数/函数对象适配器分配器示例 标准模板库 STL C 的标准模板库&#xff08;Standard Template Library&#xff0c;STL&#xff09;旨在通过模板化的设计&#xff0c;提供一种通用的编程模式&#xff0c;使程序员能方便地实现和扩展各…

栈实现队列

一、分析 栈的特点是先出再入&#xff0c;而队列的特点为先入先出&#xff0c;所以我们创造两个栈&#xff0c;一个用来存放数据&#xff0c;一个用来实现其它功能此时栈顶为队尾&#xff1b;当要找队头数据时将前n-1个数据移入到另一个栈中&#xff0c;此时剩余那个数据为队头…

通义灵码企业版正式发布,满足企业私域知识检索、数据合规、统一管理等需求

5 月 9 日阿里云 AI 峰会&#xff0c;阿里云智能集团首席技术官周靖人宣布&#xff0c;通义灵码企业版正式发布&#xff0c;满足企业用户的定制化需求&#xff0c;帮助企业提升研发效率。 通义灵码是国内用户规模第一的智能编码助手&#xff0c;基于 SOTA 水准的通义千问代码模…

永久免费的多域名通配符SSL证书申请流程

如果拥有多个域名&#xff0c;且有部分域名拥有子域名&#xff0c;那么多域名通配符证书是非常合适的选择。预算有限或者前期测试可以考虑免费版本的&#xff0c;国产证书厂商JoySSL则提供免费的多域名通配符证书 。 具体流程如下 1创建管理账号 登录JoySSL官网&#xff0c;创…

小程序地理位置接口申请教程来啦4步学会

小程序地理位置接口有什么功能&#xff1f; 如果我们提审后驳回理由写了“当前提审小程序代码包中地理位置相关接口( chooseAddress、getLocation )暂未开通&#xff0c;建议完成接口开通后或移除接口相关内容后再进行后续版本提审”&#xff0c;如果你也碰到类似问题&#xff…