一、matplotlib库
在我们自己训练模型时,常常会使用matplotlib库来绘制oss和accuracy的曲线图,帮助我们分析模型的训练表现。
matplotlib库安装:pip install matplotlib
二、代码
import matplotlib.pyplot as plt
import torch
import torch.optim as optim # 导入优化器模块
#------------------------------------------------------------------#
# 定义损失函数
#------------------------------------------------------------------#
def loss_fn(y_true, y_pred):
return torch.mean((y_true - y_pred)**2)
#------------------------------------------------------------------#
# 定义模型
#------------------------------------------------------------------#
model = torch.nn.Linear(10, 1)
#------------------------------------------------------------------#
# 定义训练,验证数据
#------------------------------------------------------------------#
x_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
x_val = torch.randn(1000, 10)
y_val = torch.randn(1000, 1)
#------------------------------------------------------------------#
# 定义优化器
#------------------------------------------------------------------#
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用 Adam 优化器,学习率为 0.001
#------------------------------------------------------------------#
# 定义损失函数
#------------------------------------------------------------------#
train_loss_list = []
val_loss_list = []
#------------------------------------------------------------------#
# 开始训练
#------------------------------------------------------------------#
for epoch in range(10000):
# ------------------------------------------------------------------#
# 训练
# ------------------------------------------------------------------#
# ------------------------------------------------------------------#
# 前向传播
# ------------------------------------------------------------------#
y_pred = model(x_train)
# ------------------------------------------------------------------#
# 计算损失
# ------------------------------------------------------------------#
training_loss = loss_fn(y_train, y_pred)
train_loss_list.append(training_loss.item())
# ------------------------------------------------------------------#
# 反向传播
# ------------------------------------------------------------------#
training_loss.backward()
# ------------------------------------------------------------------#
# 更新参数
# ------------------------------------------------------------------#
optimizer.step()
# ------------------------------------------------------------------#
# 展示训练损失
# ------------------------------------------------------------------#
if epoch % 10 == 0:
print(f"epoch {epoch}:training loss {training_loss.item()}")
# ------------------------------------------------------------------#
# 验证
# ------------------------------------------------------------------#
# ------------------------------------------------------------------#
# 前向传播
# ------------------------------------------------------------------#
y_pred = model(x_val)
# ------------------------------------------------------------------#
# 计算损失
# ------------------------------------------------------------------#
val_loss = loss_fn(y_val, y_pred)
val_loss_list.append(val_loss.item())
# ------------------------------------------------------------------#
# 展示验证损失
# ------------------------------------------------------------------#
if epoch % 10 == 0:
print(f"epoch {epoch}:validate loss {val_loss.item()}")
# ------------------------------------------------------------------#
# 记录训练,验证损失
# ------------------------------------------------------------------#
plt.plot(train_loss_list,color="red",label="training_loss")
plt.plot(val_loss_list,color="blue",label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(loc='lower right')
plt.show()
运行结果
查看