作为一个初学者,发现构建一个简单的线性模型都能看到nn.Module的身影,初学者疑惑了,nn.Module到底是干什么的,如此形影不离,了解之后,很牛。
1、nn.Module是所有层的父类,比如Linear、BatchNorm2d、Conv2d、ReLU、Sigmoid、ConvTranposed、Dropout等等这些都是它的儿子(子类),你可以直接拿来使用。
2、nn.Module还支持一个nn.Module嵌套另一个nn.Module。
3、并且可以自动完成forward,你只需要nn.Sequential()这个容器就可以了,代码示例如下:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(BasicNet(),
nn.ReLU(),
nn.Linear(3, 2))
def forward(self, x):
return self.net(x)
4、深度学习中参数可谓是量产,如果手动进行参数管理将会是一个庞大的工程,导师问你在干什么就不是在训练模型了,而是处理那些无处安放的参数,而使用nn.Module就提供的parameters就可以秒出结果,代码示例如下:
for name, t in net.named_parameters():
print('parameters:', name, t.shape)
# parameters: net.0.net.weight torch.Size([3, 4])
# parameters: net.0.net.bias torch.Size([3])
# parameters: net.2.weight torch.Size([2, 3])
# parameters: net.2.bias torch.Size([2])
然后将参数直接传入优化器进行优化,代码示例如下:
optimizer=optim.SGD(net.parameters(),lr=1e-3)
5、有很多的孩子,并且你还可以很简单的知道他孩子长什么样,我先介绍一下他的孩子们:
我们可以通过 net.named_children()了解他的亲孩子(children),也就是直系亲属,net.named_modules()了解他所有的孩子(modules),直系亲属外亲都算,代码示例如下:
for name, m in net.named_children():
print('children:', name, m)
# children: net Sequential(
# (0): BasicNet(
# (net): Linear(in_features=4, out_features=3, bias=True)
# )
# (1): ReLU()
# (2): Linear(in_features=3, out_features=2, bias=True)
# )
for name, m in net.named_modules():
print('modules:', name, m)
# modules: Net(
# (net): Sequential(
# (0): BasicNet(
# (net): Linear(in_features=4, out_features=3, bias=True)
# )
# (1): ReLU()
# (2): Linear(in_features=3, out_features=2, bias=True)
# )
# )
# modules: net Sequential(
# (0): BasicNet(
# (net): Linear(in_features=4, out_features=3, bias=True)
# )
# (1): ReLU()
# (2): Linear(in_features=3, out_features=2, bias=True)
# )
# modules: net.0 BasicNet(
# (net): Linear(in_features=4, out_features=3, bias=True)
# )
# modules: net.0.net Linear(in_features=4, out_features=3, bias=True)
# modules: net.1 ReLU()
# modules: net.2 Linear(in_features=3, out_features=2, bias=True)
6、可以非常方便的将网络运行在不同的设备,这里的设备是指cuda、gpu之类的,代码示例如下:
device = torch.device('cuda')
net = Net()
net.to(device)
7、方便对模型进行保存和加载,一个模型一般需要训练好久,但是我们并没有如此连续的时间,我们可以将现在训练好的模型进行保存,下次加载继续训练,代码示例如下:
# 加载
net.load_state_dict(torch.load('ckpt.mdl'))
# 保存
torch.save(net.state_dict(),'ckpt.mdl')
8、方便进行训练和测试状态的切换,代码示例如下:
# 训练
net.train()
# 测试
net.eval()
9、可以实现自己构建的模型,代码示例如下:
# 以下是一个展平的实现
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, input):
return input.view(input.size(0), -1) #[b,打平],保留b,其他的全部打平
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(1*14*14, 10))
def forward(self, x):
return self.net(x)
# 构建自己的线性层
class MyLinear(nn.Module):
def __init__(self, inp, outp):
super(MyLinear, self).__init__()
# 这里使用nn.Parameter代表了可以回传给nn.model,进行更新,所以不用写requires_grad = True
# requires_grad = True
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
def forward(self, x):
x = x @ self.w.t() + self.b
return x
如果你还有什么更好的idea,欢迎分享!!!