1. Steps
First define a dict in which the value for every key is a list.
loss_history = {k: [] for k in ["epoch", "train_loss", "val_loss"]}
The loss for each epoch (or each iteration) is then recorded by appending it to the corresponding list.
loss_history["epoch"].append(1)
loss_history["train_loss"].append(0.1)
loss_history["val_loss"].append(0.1)
loss_history["epoch"].append(2)
loss_history["train_loss"].append(0.05)
loss_history["val_loss"].append(0.05)
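In a real training run the appends above would normally sit inside the epoch loop. The following is a minimal sketch of that pattern; `run_epoch` is a hypothetical stand-in that returns made-up loss values, and would be replaced by the actual training and validation code.

```python
def run_epoch(epoch):
    """Hypothetical stand-in for one epoch of training plus validation.

    Returns made-up, steadily decreasing losses; not real training code.
    """
    train_loss = 1.0 / epoch
    val_loss = 1.2 / epoch
    return train_loss, val_loss


def record_losses(num_epochs):
    """Run num_epochs epochs and collect the losses in a dict of lists."""
    loss_history = {k: [] for k in ["epoch", "train_loss", "val_loss"]}
    for epoch in range(1, num_epochs + 1):
        train_loss, val_loss = run_epoch(epoch)
        # One append per list per epoch keeps the three lists aligned.
        loss_history["epoch"].append(epoch)
        loss_history["train_loss"].append(train_loss)
        loss_history["val_loss"].append(val_loss)
    return loss_history


loss_history = record_losses(num_epochs=5)
```

Because every epoch appends exactly one value to each list, the lists always have the same length and can be passed directly to a plotting call as x and y values.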
Use matplotlib.pyplot to draw the curves.
import matplotlib.pyplot as plt
plt.figure()
plt.title('loss during training')  # title
plt.plot(loss_history["epoch"], loss_history["train_loss"], label="train_loss")
plt.plot(loss_history["epoch"], loss_history["val_loss"], label="valid_loss")
plt.legend()
plt.grid()
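# Save before show(): with interactive backends the figure is discarded once
# the show() window is closed, so a later savefig() would write a blank image.
plt.savefig('loss figure')
plt.show()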
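When the figure only needs to be written to disk, for example on a training server without a display, the same plot can be produced through matplotlib's object-oriented interface. This is a minimal sketch under some assumptions: the helper name `save_loss_plot`, the axis labels, the Agg backend choice, and the output path are all illustrative and not part of the original code.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: no display available, so use a non-interactive backend
import matplotlib.pyplot as plt


def save_loss_plot(loss_history, path):
    """Plot train/val loss against epoch and write the figure to `path`."""
    fig, ax = plt.subplots()
    ax.set_title("loss during training")
    ax.plot(loss_history["epoch"], loss_history["train_loss"], label="train_loss")
    ax.plot(loss_history["epoch"], loss_history["val_loss"], label="val_loss")
    ax.set_xlabel("epoch")  # axis labels are an addition for readability
    ax.set_ylabel("loss")
    ax.legend()
    ax.grid()
    fig.savefig(path)
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory
    return path


# Same example values as in the steps above.
history = {"epoch": [1, 2], "train_loss": [0.1, 0.05], "val_loss": [0.1, 0.05]}
out_path = save_loss_plot(history, os.path.join(tempfile.gettempdir(), "loss_figure.png"))
```

Holding explicit `fig` and `ax` objects avoids relying on pyplot's implicit "current figure", which makes the helper safe to call several times in one script.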
2. Complete code
import matplotlib.pyplot as plt
# dict + list
loss_history = {k: [] for k in ["epoch", "train_loss", "val_loss"]}
loss_history["epoch"].append(1)
loss_history["train_loss"].append(0.1)
loss_history["val_loss"].append(0.1)
loss_history["epoch"].append(2)
loss_history["train_loss"].append(0.05)
loss_history["val_loss"].append(0.05)
plt.figure()
plt.title('loss during training')  # title
plt.plot(loss_history["epoch"], loss_history["train_loss"], label="train_loss")
plt.plot(loss_history["epoch"], loss_history["val_loss"], label="valid_loss")
plt.legend()
plt.grid()
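# Save before show(): with interactive backends the figure is discarded once
# the show() window is closed, so a later savefig() would write a blank image.
plt.savefig('loss figure')
plt.show()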
3. Result
Simple and quite practical.