Reproducing ResNet in PyTorch
The data used is Object_102_CaDataset (the Caltech 101_ObjectCategories images); it can be downloaded online, or you can ask for it in the comments.
Highlights of the ResNet model
1. It introduces the residual module.
2. It uses Batch Normalization to speed up training.
3. The residual network is easy to train to convergence, largely solves the degradation problem, can be made very deep, and greatly improves accuracy.
The residual structure is as follows:
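A minimal sketch of the idea in code (my illustration only; TinyResidual and F are placeholder names, not the block used later): instead of learning a mapping H(x) directly, the block learns the residual F(x) = H(x) - x and outputs F(x) + x through an identity shortcut.
import torch
from torch import nn

class TinyResidual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # F: a small sub-network whose output keeps the same shape as its input
        self.F = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # identity shortcut: output = F(x) + x
        return self.F(x) + x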
First, the model-building part:
import torch
from torch import nn
from torch.nn import (Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d,
                      AdaptiveAvgPool2d, Flatten, Linear)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride_1=1, stride_2=1, padding=1,
                 kernel_size=(3, 3), short_cut=None):
        super(ResBlock, self).__init__()
        self.short_cut = short_cut
        # main path: two 3x3 convolutions, each followed by BatchNorm
        self.model = Sequential(
            Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                   stride=stride_1, padding=padding),
            BatchNorm2d(out_channels),
            ReLU(),
            Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                   stride=stride_2, padding=padding),
            BatchNorm2d(out_channels),
            ReLU(),
        )
        # shortcut path: 1x1 convolution with stride 2, used only in the downsampling blocks
        self.short_layer = Sequential(
            Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1),
                   stride=2, padding=0),
            BatchNorm2d(out_channels),
            ReLU(),
        )
        self.R = ReLU()

    def forward(self, x):
        f1 = x
        if self.short_cut is not None:
            # project the input so its shape matches the main-path output
            f1 = self.short_layer(x)
        out = self.model(x)
        out = self.R(f1 + out)
        return out
This is the residual block of the model: two 3x3 convolutions, each followed by Batch Normalization, plus an optional 1x1 shortcut branch for the downsampling blocks.
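As a quick sanity check (my addition, not from the original post), the block can be run on a dummy feature map; with short_cut=True the spatial size is halved and the channel count changes:
x = torch.randn(1, 64, 56, 56)                                     # dummy feature map
plain = ResBlock(64, 64)                                           # shape-preserving block
down = ResBlock(64, 128, stride_1=2, stride_2=1, short_cut=True)   # downsampling block
print(plain(x).shape)   # torch.Size([1, 64, 56, 56])
print(down(x).shape)    # torch.Size([1, 128, 28, 28])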
Next, the construction of the whole model:
class Resnet_easier(nn.Module):
    def __init__(self, num_classes):
        super(Resnet_easier, self).__init__()
        # stem: 3 input channels, 64 output channels, 7x7 kernel, stride 2, padding 3
        self.model0 = Sequential(
            Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=2, padding=3),
            BatchNorm2d(64),
            ReLU(),
            MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),
        )
        self.model1 = ResBlock(64, 64)
        self.model2 = ResBlock(64, 64)
        self.model3 = ResBlock(64, 128, stride_1=2, stride_2=1, short_cut=True)
        self.model4 = ResBlock(128, 128)
        self.model5 = ResBlock(128, 256, stride_1=2, stride_2=1, short_cut=True)
        self.model6 = ResBlock(256, 256)
        self.model7 = ResBlock(256, 512, stride_1=2, stride_2=1, short_cut=True)
        self.model8 = ResBlock(512, 512)
        # adaptive average pooling down to 1x1
        self.aap = AdaptiveAvgPool2d((1, 1))
        # flatten everything except the batch dimension
        self.flatten = Flatten(start_dim=1)
        # fully connected classification layer
        self.fc = Linear(512, num_classes)

    def forward(self, x):
        x = x.to(torch.float32)
        x = self.model0(x)
        x = self.model1(x)
        x = self.model2(x)
        x = self.model3(x)
        x = self.model4(x)
        x = self.model5(x)
        x = self.model6(x)
        x = self.model7(x)
        x = self.model8(x)
        # final pooling, flattening and classification
        x = self.aap(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
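Before wiring in real data, the whole network can be checked with a random batch. This is just a sanity check I added, using 102 classes as in the training script further down:
net = Resnet_easier(num_classes=102)
dummy = torch.randn(2, 3, 224, 224)   # a batch of two 224x224 RGB images
print(net(dummy).shape)               # torch.Size([2, 102])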
Next is the data-loading module.
import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


class Object_102_CaDataset(Dataset):
    def __init__(self, folder):
        # ImageNet normalization statistics
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.file_list = []
        # each sub-directory of `folder` is one class label
        label_names = [item for item in os.listdir(folder) if os.path.isdir(os.path.join(folder, item))]
        # map label names to integer indices
        label_to_index = dict((label, index) for index, label in enumerate(label_names))
        # collect the paths of all images
        self.all_picture_paths = self.get_all_picture(folder)
        # the label of an image is the name of its parent directory
        self.all_picture_labels = [label_to_index[os.path.split(os.path.dirname(os.path.abspath(path)))[1]]
                                   for path in self.file_list]
        self.mean = np.array(mean).reshape((1, 1, 3))
        self.std = np.array(std).reshape((1, 1, 3))

    def __getitem__(self, index):
        img = cv2.imread(self.all_picture_paths[index])
        if img is None:
            print(os.path.join("image", self.all_picture_paths[index]))
        img = cv2.resize(img, (224, 224))   # resize every image to the same size
        img = img / 255
        img = (img - self.mean) / self.std
        img = np.transpose(img, [2, 0, 1])  # HWC -> CHW
        label = self.all_picture_labels[index]
        img = torch.tensor(img)
        label = torch.tensor(label)
        return img, label

    def __len__(self):
        return len(self.all_picture_paths)

    def get_all_picture(self, folder):
        # recursively collect all file paths under `folder`
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            if os.path.isfile(file_path):
                self.file_list.append(file_path)
            elif os.path.isdir(file_path):
                self.file_list = self.get_all_picture(file_path)
        return self.file_list
With the Dataset above, the data can be read conveniently through a DataLoader.
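For example, a quick check that the Dataset yields usable batches; the folder path here is a placeholder for wherever 101_ObjectCategories is stored:
from torch.utils.data import DataLoader

dataset = Object_102_CaDataset('path/to/101_ObjectCategories')  # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, images.dtype)  # torch.Size([4, 3, 224, 224]) torch.float64
print(labels)                      # four integer class indices
Note that the images come out as float64; the model casts them to float32 in its forward pass.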
Next is the complete training module.
import torch
from torch import nn
from torch.utils.data import DataLoader
from ResNet.ResNet18_easier import Resnet_easier
from ResNet.dataset import Object_102_CaDataset
# Optional imports (other model variants from my repo, and torchsummary for the summary call below):
# from ResNet.ResNet18 import Resnet18
# from ResNet.res_net import ResNet, ResBlock
# from torchsummary import summary

data_dir = r'E:\PostGraduate\Paper_review\computer_view_model/ResNet/data/101_ObjectCategories'
Object_102 = Object_102_CaDataset(data_dir)
# 70% / 30% train / test split
train_size = int(len(Object_102) * 0.7)
# print(train_size)
test_size = len(Object_102) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(Object_102, [train_size, test_size])
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Display data: uncommenting this block shows some of the loaded images
# import random
# from matplotlib import pyplot as plt
# import matplotlib
# matplotlib.use('TkAgg')
# def denorm(img):
#     for i in range(img.shape[0]):
#         img[i] = img[i] * std[i] + mean[i]
#     img = torch.clamp(img, 0., 1.)
#     return img
# plt.figure(figsize=(8, 8))
# for i in range(9):
#     img, label = train_dataset[random.randint(0, len(train_dataset) - 1)]
#     img = denorm(img)
#     img = img.permute(1, 2, 0)
#     ax = plt.subplot(3, 3, i + 1)
#     ax.imshow(img.numpy()[:, :, ::-1])
#     ax.set_title("label = %d" % label)
#     ax.set_xticks([])
#     ax.set_yticks([])
# plt.show()

train_iter = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=64)

model = Resnet_easier(102)
# print(summary(model, (3, 224, 224)))
epoch = 50  # number of training epochs
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()  # .cuda()  # cross-entropy loss
log_interval = 10
train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_iter.dataset) for i in range(epoch + 1)]
# test_loop(model,'cpu',test_iter)
def train_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        model.train()
        for i, data in enumerate(train_loader):
            (images, label) = data
            images = images  # .cuda()
            label = label    # .cuda()
            output = model(images)
            loss = loss_fn(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # batch accuracy: the predicted class is the index of the largest logit
            pred = output.data.max(1, keepdim=True)[1]
            correct = pred.eq(label.view_as(pred)).sum().item()
            if i % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t accuracy:[{}/{} ({:.0f}%)] \tLoss: {:.6f}'.format(
                    epoch, i * len(images), len(train_loader.dataset),
                    100. * i / len(train_loader), correct, len(pred), 100. * correct / len(pred), loss.item()))
                train_losses.append(loss.item())
                train_counter.append((i * 64) + ((epoch - 1) * len(train_loader.dataset)))
                torch.save(model.state_dict(), 'model_paramter/test/model.pth')
                torch.save(optimizer.state_dict(), 'model_paramter/test/optimizer.pth')
# test_loop(model, 'cpu', test_iter)
# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
train_loop(epoch, optimizer, model, loss_fn, train_iter)
# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
# test_loop(model, 'cpu', test_iter)
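The .cuda() calls above are commented out, so everything runs on the CPU. A minimal sketch of how the same script could be moved to the GPU when one is available (my addition, not part of the original code):
# Pick the GPU if CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# and inside the batch loop:
#     images, label = images.to(device), label.to(device)
#     output = model(images)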
To evaluate accuracy on the test data, you can refer to my earlier post on sentiment analysis with LSTM and adapt its test section, or use the code below:
PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
dictionary = torch.load(PATH)
model.load_state_dict(dictionary)

def test_loop(model, device, test_iter):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_iter:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # accumulate the cross-entropy loss over the test set
            test_loss += loss_fn(output, target).item() * data.size(0)
            # the predicted class is the index of the largest logit
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_iter.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_iter.dataset),
        100. * correct / len(test_iter.dataset)))

test_loop(model, 'cpu', test_iter)
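The train_losses / train_counter and test_losses lists filled in above can then be plotted, for example with matplotlib (a sketch, assuming matplotlib is installed):
from matplotlib import pyplot as plt

plt.figure()
plt.plot(train_counter, train_losses, color='blue', label='train loss')
plt.scatter(test_counter[:len(test_losses)], test_losses, color='red', label='test loss')
plt.xlabel('number of training examples seen')
plt.ylabel('cross-entropy loss')
plt.legend()
plt.show()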