Python语言 AI框架:Mindspore
1.模型构建
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
nn.ReLU(),
nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
model = Network()
print(model)
2.模型保存
mindspore.save_checkpoint(model, "model.ckpt")
3.模型导出-mindir格式
除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
4.加载保存模型
要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint和load_param_into_net方法加载参数。
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) # 正确输出[ ]
5.加载导出模型
nn.GraphCell
仅支持图模式。
mindspore.set_context(mode=mindspore.GRAPH_MODE)
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
MindIR
同时保存了Checkpoint
和模型结构,因此需要定义输入Tensor
来获取输入shape。
运行结果-模型保存情况如下: