深度学习5

一、模型保存与加载

 1、序列化方式

        保存方式:torch.save(model, "model.pkl")

        打开方式:model = torch.load("model.pkl", map_location="cpu")

​
import torch
import torch.nn as nn


class MyModle(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

model = MyModle(input_size=128, output_size=32)
# 序列化方式保存模型对象
torch.save(model, "model.pkl")


# 注意设备问题
model = torch.load("model.pkl", map_location="cpu")
print(model)

​

2、保存模型参数

        设置需要保存的模型参数:

save_dict = {

        "init_params": {

            "input_size": 128,  # 输入特征数

            "output_size": 32,  # 输出特征数

        },

        "accuracy": 0.99,  # 模型准确率

        "model_state_dict": model.state_dict(),

        "optimizer_state_dict": optimizer.state_dict(),

    }

保存模型参数:torch.save(save_dict, "名称.pth"),一般使用 pth 作为后缀

创建新模型时调用保存的模型参数:

加载模型参数:torch.load("名称.pth")

input_size =  save_dict["init_params"]["input_size"]

import torch
import torch.nn as nn
import torch.optim as optim

class MyModle(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

save_dict = torch.load("模型参数保存名称.pth")
model = MyModle(
    input_size=save_dict["init_params"]["input_size"],
    output_size=save_dict["init_params"]["output_size"],
)
# 初始化模型参数
model.load_state_dict(save_dict["model_state_dict"])
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 初始化优化器参数
optimizer.load_state_dict(save_dict["optimizer_state_dict"])
# 打印模型信息
print(save_dict["accuracy"])
print(model)

二、数据增强

        具体参考官方文档:Illustration of transforms — Torchvision 0.20 documentation

1、官方代码-主体

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

plt.rcParams["savefig.bbox"] = "tight"
torch.manual_seed(0)

orig_img = Image.open("../../data/1.png")

def plot(imgs, title, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    plt.title(title)

    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title="Original image")
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

    plt.show()

2、固定转换

2.1、pad 边缘填充

        就是在照片周围添加黑色框区域

padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot([orig_img] + padded_imgs, "v2.Pad")

2.2、resize 大小调整

resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot([orig_img] + resized_imgs, "v2.Resize")

 2.3、center crop 中心裁剪

center_crops = [
    v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)
]
plot([orig_img] + center_crops, "v2.CenterCrop")

2.4、five crop 周边裁剪

(top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(
    orig_img
)
plot(
    [orig_img] + [top_left, top_right, bottom_left, bottom_right, center], "v2.FiveCrop"
)

3、随机转换

3.1、RandomRotation 随机旋转

rotater = v2.RandomRotation(degrees=(0, 180)) # 随机从0-180获取一个数值
rotated_imgs = [rotater(orig_img) for _ in range(4)] # 根据随机数值得到角度转变
plot([orig_img] + rotated_imgs)

3.2、RandomAffine  随机仿射

affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot([orig_img] + affine_imgs)

4、数据增强整合

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms, datasets, utils


def test001():
    # 定义数据增强和预处理步骤
    transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(10),  # 随机旋转 ±10 度
        transforms.RandomResizedCrop( 32, scale=(0.8, 1.0) ),  # 随机裁剪到 32x32,缩放比例在0.8到1.0之间
        transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
            ),  # 随机调整亮度、对比度、饱和度、色调
        transforms.ToTensor(),  # 转换为 Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化
        ]
    )

    # 加载 CIFAR-10 数据集,并应用数据增强
    trainset = datasets.CIFAR10(
        root="../../data", train=True, download=True, transform=transform
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=4, shuffle=True, num_workers=2
    )

    # 显示增强后的图像
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    def imshow(img):
        img = img / 2 + 0.5  # 反归一化
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    imshow(utils.make_grid(images))

test001()

三、神经网络

1、人工神经元

        接收多个输入的信息并进行加权求和,使用激活函数处理得到最后结果。

        

人工神经元的设置方法是对比生物神经元的结构:

生物神经元人工神经元
细胞核节点 (加权求和 + 激活函数)
树突输入
轴突带权重的连接
突触输出

2、神经网络

        由大量人工神经元按层次结构连接而成的计算模型,上一层的神经元输出作为下一层神经元的输入,层之间的神经元并无连接。

2.1、结构

        输入层:整个神经网络的第一层,负责接收外部数据,不做任何计算。

        隐藏层:位于神经网络输入层与输出层之间的内容,进行特征提取、转化、计算等操作,一般为多层神经元组成。

        输出层:接收隐藏层的计算结果,产生预测结果或分类结果

2.2、全连接神经网络

        每一层的单个神经元都与上一层的所有神经元连接,一般用于图像分类、文本等。

3、参数初始化(权重、偏置)

        权重和偏置:model.weight、model.bias 

        初始化使用 torch.nn.init 库的方法

3.1、固定值初始化--全零化、全1化、常数化

        将参数所有数据变成固定值,一般不用于初始化权重(会破坏对称性),用于初始化偏置。

torch.nn.init.zeros_(model.weight) :参数为初始化对象,只能一个tensor。

torch.nn.init.ones_(model.weight) :参数为初始化对象,只能一个tensor。

torch.nn.init.constant_(model.weight) :参数1为初始化对象,只能一个tensor;参数2为设置的浮点数。

from  torch.nn import Linear
import torch.nn

model = Linear(4,1)
# 参数只有一个,初始化对象,不能一次初始化多个对象
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)
print(model.weight,model.bias)


torch.nn.init.ones_(model.weight)
torch.nn.init.ones_(model.bias)
print(model.weight,model.bias)

model = Linear(4,1)
# 使用叶子节点来初始化为0
model.weight.detach().zero_()
model.bias.detach().zero_()
print(model.weight,model.bias)

model = Linear(4,1)
# 参数第一个为初始化对象,第二个为设定的浮点数;
# 不能一次初始化多个对象
torch.nn.init.constant_(model.weight,5.)
torch.nn.init.constant_(model.bias,5.)
print(model.weight,model.bias)

3.2、随机初始化 

        normal_、uniform_:将权重初始化为随机的小值,通常从正态分布或均匀分布中采样;能避免对称性破坏。

from  torch.nn import Linear
import torch.nn

model = Linear(4,1)
# 参数第一个为初始化对象;
# 参数2、3为均值和标准差,默认 0,1标准正太分布
torch.nn.init.normal_(model.weight) 
torch.nn.init.normal_(model.bias) 
print(model.weight,model.bias)

# 参数第二个和第三个为下界和上界,默认0-1
torch.nn.init.uniform_(model.weight, 0,1) 
torch.nn.init.uniform_(model.bias) 
print(model.weight,model.bias)

3.3、Xavier 初始化

        对随机初始化添加取值限制。

        

平衡了输入和输出的方差,适合Sigmoid 和 Tanh 激活函数或浅层网络。

from  torch.nn import Linear
import torch.nn

model = Linear(4,1)
# 参数第一个为初始化对象;
# 第二个参数 gain 是缩放因子
torch.nn.init.xavier_normal_(model.weight) 
print(model.weight)

torch.nn.init.xavier_uniform_(model.weight) 
print(model.weight)

"""
常见的 gain 值:
    线性激活函数:gain = 1.0(默认值)
    Sigmoid 激活函数:gain = 1.0
    Tanh 激活函数:gain = 5/3(约等于 1.653)
    ReLU 激活函数:gain = sqrt(2)(约等于 1.414)
    Leaky ReLU 激活函数:gain = sqrt(2 / (1 + negative_slope^2)),其中 negative_slope 是 Leaky ReLU 的负斜率,默认值为 0.01。
"""

3.4、He初始化 (kaiming 初始化 )

        专门为 ReLU 激活函数设计;权重从以下分布中采样   W\sim\mathrm{N}\left(0,\frac{2}{n_\mathrm{in}}\right)n_{\text{in}} 是当前层的输入神经元数量。

from  torch.nn import Linear
import torch.nn

model = Linear(4,1)
# 参数第一个为初始化对象;
# a 为负斜率的值(relu负数为0,所以此参数只有在relu衍生的函数有效 leaky_relu)
# nonlinearity 默认 leaky_relu
# mode 默认 fan-in 使用输入单元数量计算初始化值

torch.nn.init.kaiming_normal_(model.weight) 
print(model.weight)

torch.nn.init.kaiming_uniform_(model.weight) 
print(model.weight)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/923748.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis五大基本类型——Zset有序集合命令详解(命令用法详解+思维导图详解)

目录 一、Zset有序集合类型介绍 二、常见命令 1、ZADD 2、ZCARD 3、ZCOUNT 4、ZRANGE 5、ZREVRANGE 6、ZRANGEBYSCORE 7、ZREVRANGEBYSCORE 8、ZPOPMAX 9、ZPOPMIN 10、ZRANK 11、ZREVRANK 12、ZSCORE 13、ZREM 14、ZREMRANGEBYRANK 15、ZREMRANGEBYSCORE 16…

ARM架构 AArch64 基础知识介绍

介绍 aarch64是 ARM 架构的 64 位版本,它是 ARMv8 架构的一部分,被设计用来提供更高的性能和更大的地址空间,同时保持与 32 位 ARM 架构的兼容性。AArch64 是 ARMv8 的 64 位指令集架构(ISA),它提供了丰富的…

Rust中Tracing 应用指南

欢迎来到这篇全面的Rust跟踪入门指南。Rust 的tracing是一个用于应用程序级别的诊断和调试的库。它提供了一种结构化的、异步感知的方式来记录日志和跟踪事件。与传统的日志记录相比,tracing能够更好地处理复杂的异步系统和分布式系统中的事件跟踪,帮助开…

极狐GitLab 17.6 正式发布几十项与 DevSecOps 相关的功能【三】

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料: 极狐GitLab 官网极狐…

WinFrom调用webapi接口另一个方法及其应用实例

1.调用接口方法 代码如下&#xff1a; public class WebAPI{#region WebAPI调用 public async Task<string> Call_Webapi(string Url, string Json) //url传入的是接口名称&#xff0c;json传入的是接口参数{string responseBody string.Empty; //responseBod…

elasticsearch的索引模版使用方法

5 索引模版⭐️⭐️⭐️⭐️⭐️ 索引模板就是创建索引时要遵循的模板规则索引模板仅对新创建的索引有效&#xff0c;已经创建的索引并不受索引模板的影响 5.1 索引模版的基本使用 1.查看所有的索引模板 GET 10.0.0.91:9200/_index_template2.创建自定义索引模板 xixi &…

从零开始学GeoServer源码(二)添加支持arcgis切片功能

文章目录 参考文章环境背景1、配置打包好的程序1.1、下载GeoServer的war包1.2、下载GeoWebCache1.3、拷贝jar包1.4、修改配置文件1.4.1、拷贝geowebcache-arcgiscache-context.xml1.4.2、修改geowebcache-core-context.xml1.4.3、修改geowebcache-servlet.xml 1.5、配置切片信息…

Redis 可观测最佳实践

Redis 介绍 Redis 是一个开源的高性能键值对&#xff08;key-value&#xff09;数据库。它通常用作数据库、缓存和消息代理。Redis 支持多种类型的数据结构&#xff0c;Redis 通常用于需要快速访问的场景&#xff0c;如会话缓存、全页缓存、排行榜、实时分析等。由于其高性能和…

HarmonyOs鸿蒙开发实战(21)=>组件间通信@ohos/liveeventbus

1.简介 LiveEventBus是一款消息总线&#xff0c;具有生命周期感知能力&#xff0c;支持Sticky&#xff0c;支持跨进程&#xff0c;支持跨APP发送消息。 2.下载安装 ohpm install ohos/liveeventbus 3.订阅&#xff0c;注册监听 4.发送事件 5. 完成 > 记得关注博主&#xff…

深度学习使用LSTM实现时间序列预测

大家好&#xff0c;LSTM是一种特殊的循环神经网络&#xff08;RNN&#xff09;架构&#xff0c;它被设计用来解决传统RNN在处理长序列数据时的梯度消失和梯度爆炸问题&#xff0c;特别是在时间序列预测、自然语言处理和语音识别等领域中表现出色。LSTM的核心在于其独特的门控机…

用Tauri框架构建跨平台桌面应用:1、Tauri快速开始

Tauri 是一个构建适用于所有主流桌面和移动平台的轻快二进制文件的框架。开发者们可以集成任何用于创建用户界面的可以被编译成 HTML、JavaScript 和 CSS 的前端框架&#xff0c;同时可以在必要时使用 Rust、Swift 和 Kotlin 等语言编写后端逻辑。 Tauri 是什么&#xff1f; |…

Elasticsearch对于大数据量(上亿量级)的聚合如何实现?

大家好&#xff0c;我是锋哥。今天分享关于【Elasticsearch对于大数据量&#xff08;上亿量级&#xff09;的聚合如何实现&#xff1f;】面试题。希望对大家有帮助&#xff1b; Elasticsearch对于大数据量&#xff08;上亿量级&#xff09;的聚合如何实现&#xff1f; 1000道 …

C语言:C语言实现对MySQL数据库表增删改查功能

基础DOME可以用于学习借鉴&#xff1b; 具体代码 #include <stdio.h> #include <mysql.h> // mysql 文件&#xff0c;如果配置ok就可以直接包含这个文件//宏定义 连接MySQL必要参数 #define SERVER "localhost" //或 127.0.0.1 #define USER "roo…

Learn Git Branching 学习笔记

网址&#xff1a;Learn Git Branching 一、基础篇 1.1 git commit 1.1.1 示例&#xff08;git commit&#xff09; git commit 1.1.2 题目&#xff08;两次提交记录&#xff09; git commit git commit 前 后 1.2 git branch 1.2.1 示例&#xff08;git branch <>、git …

设计模式:责任链实现数据流风格的数据处理

数据流风格 数据流风格是软件架构中的一种风格&#xff0c;主要是面向数据&#xff0c;用于进行流式的数据处理&#xff1b;数据流风格的代表有管道-过滤器风格和批处理序列风格&#xff0c;这里主要是指管道-过滤器风格。 管道-过滤器风格就像其名字一样&#xff0c;是以一个…

BERT简单理解;双向编码器优势

目录 BERT简单理解 一、BERT模型简单理解 二、BERT模型使用举例 三、BERT模型的优势 双向编码器优势 BERT简单理解 (Bidirectional Encoder Representations from Transformers)模型是一种预训练的自然语言处理(NLP)模型,由Google于2018年推出。以下是对BERT模型的简…

摄像头原始数据读取——V4L2(mmap模式,V4L2_MEMORY_MMAP)

摄像头原始数据读取——V4L2(mmap模式,V4L2_MEMORY_MMAP) 内存映射模式&#xff0c;是将设备在内核态申请的用于存储视频数据的物理内存映射到用户空间&#xff0c;使得用户应用程序可以直接访问和操作设备数据物理内存&#xff0c;避免了数据的拷贝。因此采集速度较快&#x…

SpringCloud框架学习(第五部分:SpringCloud Alibaba入门和 nacos)

目录 十二、SpringCloud Alibaba入门简介 1. 基本介绍 2.作用 3.版本选型 十三、 SpringCloud Alibaba Nacos服务注册和配置中心 1.简介 2.各种注册中心比较 3.下载安装 4.Nacos Discovery服务注册中心 &#xff08;1&#xff09; 基于 Nacos 的服务提供者 &#xf…

Ubuntu下用Docker部署群晖系统---Virtual DSM --zerotier实现连接

Ubuntu下用Docker部署群晖系统—Virtual DSM --zerotier实现连接 1. Docker 安装 安装最新docker curl -fsSL get.docker.com -o get-docker.sh sudo sh get-docker.sh sudo docker run hello-world2.docker-compose 安装 sudo pip install docker-compose测试安装是否成功…

【排版教程】Word、WPS 分节符(奇数页等) 自动变成 分节符(下一页) 解决办法

毕业设计排版时&#xff0c;一般要求每章节的起始页为奇数页&#xff0c;空白页不显示页眉和页脚。具体做法如下&#xff1a; 1 Word 在一个章节的内容完成后&#xff0c;在【布局】中&#xff0c;点击【分隔符】&#xff0c;然后选择【奇数页】 这样在下一章节开始的时&…