1 介绍
- 在 PyTorch 中,
nn.ModuleDict
是一个方便的容器,用于存储一组子模块(即nn.Module
对象)的字典 - 这个容器主要用于动态地管理多个模块,并通过键来访问它们,类似于 Python 的字典
2 特点
- 组织性
nn.ModuleDict
提供了一种将多个模块有序组织在一起的方法。- 这有助于让代码更加结构化,易于理解和维护
- 动态操作
- 可以像操作普通字典那样添加或删除模块
- 例如使用
module_dict['key'] = module
添加模块,使用del module_dict['key']
删除模块
- 自动参数注册
- 当将模块添加到
ModuleDict
中时,它们的参数会自动注册到整个网络中,确保在模型训练时这些参数可以被识别和更新
- 当将模块添加到
3 例子
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layers = nn.ModuleDict({
'linear': nn.Linear(10, 20),
'activation': nn.ReLU()
})
def forward(self, x):
x = self.layers['linear'](x)
x = self.layers['activation'](x)
return x
- 在构造函数中,我们使用
ModuleDict
来存储一个线性层和一个激活层,并在前向传播forward
方法中通过键名访问这些层
4 主要方法
clear | 清空 ModuleDict 中的所有条目 |
items | 返回一个可迭代对象,包含 |
keys | 返回一个可迭代对象,包含 |
pop | 从 ModuleDict 中移除指定的键,并返回其对应的模块 |
update(modules) | 使用另一个映射或键值对迭代器更新 ModuleDict ,如果存在相同的键,则会覆盖原有的条目 |
values | 返回一个可迭代对象,包含 ModuleDict 的所有模块值 |