PyTorch常用工具(2)预训练模型

文章目录

  • 前言
  • 2 预训练模型

前言

在训练神经网络的过程中需要用到很多的工具,最重要的是数据处理、可视化和GPU加速。本章主要介绍PyTorch在这些方面常用的工具模块,合理使用这些工具可以极大地提高编程效率。

由于内容较多,本文分成了五篇文章(1)数据处理(2)预训练模型(3)TensorBoard(4)Visdom(5)CUDA与小结。

整体结构如下:

  • 1 数据处理
    • 1.1 Dataset
    • 1.2 DataLoader
  • 2 预训练模型
  • 3 可视化工具
  • 3.1 TensorBoard
  • 3.2 Visdom
  • 4 使用GPU加速:CUDA
  • 5 小结

全文链接:

  1. PyTorch中常用的工具(1)数据处理
  2. PyTorch常用工具(2)预训练模型
  3. PyTorch中常用的工具(3)TensorBoard
  4. PyTorch中常用的工具(4)Visdom
  5. PyTorch中常用的工具(5)使用GPU加速:CUDA

2 预训练模型

除了加载数据,并对数据进行预处理之外,torchvision还提供了深度学习中各种经典的网络结构以及预训练模型。这些模型封装在torchvision.models中,包括经典的分类模型:VGG、ResNet、DenseNet及MobileNet等,语义分割模型:FCN及DeepLabV3等,目标检测模型:Faster RCNN以及实例分割模型:Mask RCNN等。读者可以通过下述代码使用这些已经封装好的网络结构与模型,也可以在此基础上根据需求对网络结构进行修改:

from torchvision import models
# 仅使用网络结构,参数权重随机初始化
mobilenet_v2 = models.mobilenet_v2()
# 加载预训练权重
deeplab = models.segmentation.deeplabv3_resnet50(pretrained=True)

下面使用torchvision中预训练好的实例分割模型Mask RCNN进行一次简单的实例分割:

In: from torchvision import models
    from torchvision import transforms as T
    from torch import nn
    from PIL import Image
    import numpy as np
    import random
    import cv2

    # 加载预训练好的模型,不存在的话会自动下载
    # 预训练好的模型保存在 ~/.torch/models/下面
    detection = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    detection.eval()
    def predict(img_path, threshold):
        # 数据预处理,标准化至[-1, 1],规定均值和标准差
        img = Image.open(img_path)
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])
        img = transform(img)
        # 对图像进行预测
        pred = detection([img])
        # 对预测结果进行后处理:得到mask与bbox
        score = list(pred[0]['scores'].detach().numpy())
        t = [score.index(x) for x in score if x > threshold][-1]
        mask = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
        pred_boxes = [[(i[0], i[1]), (i[2], i[3])] \
                      for i in list(pred[0]['boxes'].detach().numpy())]
        pred_masks = mask[:t+1]
        boxes = pred_boxes[:t+1]
        return pred_masks, boxes

Transforms中涵盖了大部分对Tensor和PIL Image的常用处理,这些已在上文提到,本节不再详细介绍。需要注意的是转换分为两步,第一步:构建转换操作,例如transf = transforms.Normalize(mean=x, std=y);第二步:执行转换操作,例如output = transf(input)。另外还可以将多个处理操作用Compose拼接起来,构成一个处理转换流程。

In: # 随机颜色,以便可视化
    def color(image):
        colours = [[0, 255, 255], [0, 0, 255], [255, 0, 0]]
        R = np.zeros_like(image).astype(np.uint8)
        G = np.zeros_like(image).astype(np.uint8)
        B = np.zeros_like(image).astype(np.uint8)
        R[image==1], G[image==1], B[image==1] = colours[random.randrange(0,3)]
        color_mask = np.stack([R,G,B],axis=2)
        return color_mask
In: # 对mask与bounding box进行可视化
    def result(img_path, threshold=0.9, rect_th=1, text_size=1, text_th=2):
        masks, boxes = predict(img_path, threshold)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for i in range(len(masks)):
            color_mask = color(masks[i])
            img = cv2.addWeighted(img, 1, color_mask, 0.5, 0)
            cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(255,0,0), thickness=rect_th)
        return img
In: from matplotlib import pyplot as plt
    img=result('data/demo.jpg')
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    img_result = plt.imshow(img)

TensorBoard界面

上述代码完成了一个简单的实例分割任务。如上图所示,Mask RCNN能够分割出该图像中的部分实例,读者可考虑对预训练模型进行微调,以适应不同场景下的不同任务。注意:上述代码均在CPU上进行,速度较慢,读者可以考虑将数据与模型转移至GPU上,具体操作可以参考第4节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/282970.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2023年终总结|回顾学习Tensorflow、Keras的历程

2023年4月,初探TensorFlow2.0,对比了1.0版本的差异。接着,学习了TensorFlow2.0的常量矩阵、四则运算以及常用函数。学习了数据切割、张量梯度计算、遍历元素、类别索引转换等技巧,并掌握了CNN输出特征图形状的计算方法。 在数据处…

【漏洞复现】企望制造ERP系统 RCE漏洞

漏洞描述 企望制造ERP系统是畅捷通公司开发的一款领先的生产管理系统,它以集成化管理为核心设计理念,通过模块化机制,帮助企业实现生产、采购、库存等方面的高效管理。该系统存在RCE远程命令执行漏洞,恶意攻击者可利用此漏洞进行…

流逝的时光

文章目录 创作历程展望2024 创作历程 自2019.6.28注册csdn,期间断断续续的通过其查找相应资料,受益颇多 今研一,发现论文看了又忘,于是借此平台来记录,可以看到基本都是基于原论文进行翻译,并没有所思所想&…

今年努力输出的嵌入式Linux视频

今年努力了一波,几个月周六日无休,自己在嵌入式linux工作有些年头,结合自己也是一直和SLAM工程师对接,所以输出了一波面向SLAM算法工程师Linux课程,当然嵌入式入门的同学也可以学习。下面是合作的官方前面发的宣传文章…

《Spring Cloud学习笔记:微服务保护Sentinel + JMeter快速入门》

Review 解决了服务拆分之后的服务治理问题:Nacos解决了服务治理问题OpenFeign解决了服务之间的远程调用问题网关与前端进行交互,基于网关的过滤器解决了登录校验的问题 流量控制:避免因为突发流量而导致的服务宕机。 隔离和降级&#xff1a…

浅学lombok

Lombok(Project Lombok)是一个用于 Java 编程语言的开源库,旨在减少 Java 代码中的冗余和样板代码,提高开发人员的生产力。它通过使用注解来自动生成 Java 类的常见方法和代码,从而使开发人员能够编写更简洁、更具可读…

Django Rest Framework(DRF)框架搭建步骤,包含部分错误解决

一、初步搭建项目 1.使用PyCharm 2021创建Djiango项目,配置如下(假设应用名叫djiango_python) Python (3.6, 3.7, 3.8, 3.9, 3.10, 3.11)> 当前版本 3.8.6Django &a…

雪花算法(Snowflake)介绍和Java实现

1、雪花算法介绍 (1) 雪花算法(SnowFlake)是分布式微服务下生成全局唯一ID,并且可以做到去中心化的常用算法,最早是Twitter公司在其内部的分布式环境下生成ID的方式。 雪花算法的名字可以这么理解,世界上没有两片完全相同的雪花,…

shell shell脚本编写常用命令 语法 shell 脚本工具推荐

shell 脚本 计算机语言 Shebang 定义解释器 主要定义,您的脚本是用什么语言写的 #!/usr/bin/python //定义这是一个python语言#!/bin/bash //定义这是一个shell语言 echo SHELL我们执行的 linux 命令的时候,其实是使用 /bin/bash 这个二进制文…

【模拟电路】基础理论与实际应用

一、毫安时和毫瓦时 二、开关电路 三、继电器 四、半导体 五、二极管 六、三极管 七、三极管应用案例 一、毫安时和毫瓦时 毫安时(mAh)和毫瓦时(mWh)是两个不同的物理量,它们分别表示电量和能量的度量单位。下面的图…

手把手教你绘制和解读实用R列线图(Nomogram):从入门到精通

一、引言 列线图(Nomogram)是一种常用的数据可视化工具,它能够直观地展示多个变量之间的关系,并帮助我们理解和解释复杂的数据模式。通过绘制列线图,我们可以将各种变量的影响和相互关联转化为图形化的表示&#xff0c…

2024-01-01 事业-代号s-科特勒《营销管理》-分析

摘要: 2024-01-01 事业-代号s-科特勒《营销管理》-分析 科特勒《营销管理》-分析 营销管理 - 思维导图 01 理解营销管理 这本书不仅从概念出发介绍了营销管理的定义、职能和计划,还拆解了每一个管理环节策划的具体实施方法。通过下面这张思维导图,我们…

考研后SpringBoot复习2—容器底层相关注解

考研后SpringBoot复习2 SpringBoot底层注解学习 与容器功能相关的注解与springboot的底层原理密切相关 组件添加注解configuration Spring Ioc容器部分回顾 包括在配置中注册,开启包扫描和注解驱动开发等需要在进行重新的学习回顾 实例 package com.dzu.boot;imp…

2022–2023学年2021级计算机科学与技术专业数据库原理 (A)卷

一、单项选择题(每小题1.5分,共30分) 1、构成E—R模型的三个基本要素是( B )。 A.实体、属性值、关系 B.实体、属性、联系 C.实体、实体集、联系 D.实体、实体…

【第5期】前端Vue使用Proxy+Vuex(store、mutations、actions)跨域调通本地后端接口

本期简介 本期要点 本地开发前后端如何跨域调用全局请求、响应处理拦截器处理封装HTTP请求模块编写API请求映射到后端API数据的状态管理 一、 本地开发前后端如何跨域调用 众所周知,只要前端和后端的域名或端口不一样,就存在跨域访问,例如&…

模型量化之AWQ和GPTQ

什么是模型量化 模型量化(Model Quantization)是一种通过减少模型参数表示的位数来降低模型计算和存储开销的技术。一般来说,模型参数在深度学习模型中以浮点数(例如32位浮点数)的形式存储,而模型量化可以…

appium入门基础

介绍 appium支持在不同平台的UI自动化,如web,移动端,桌面端等。还支持使用java,python,js等语言编写自动化代码。主要用于自动化测试脚本,省去重复的手动操作。 Appium官网 安装 首先必须环境有Node.js用于安装Appium。 总体来…

OpcUaHelper实现西门子OPC Server数据交互

Opc ua客户端类库,基于.net 4.6.1创建,基于官方opc ua基金会跨平台库创建,方便的实现和OPC Server进行数据交互。 FormBrowseServer 在开发客户端之前,需要使用本窗口来进行查看服务器的节点状态,因为在请求服务器的节点数据之前,必须知道节点的名称,而节点的名称可以…

Docker之网络配置

目录 1.网络概念 网络相关的有ip,子网掩码,网关,DNS,端口号 1.1 ip是什么? ip是唯一定位一台网上计算机 Ip地址的分类: IPV4: 4字节32位整数,并分成4段8位的二进制数,每8位之间用圆点隔开,每8位整数可以转换为一个0~255的十进制整数 【例…

在香橙派5 Plus上搭建Gitlab

作为一个码农,一定知道Github这个最大的成人交友网站。但是Github在国内不稳定,经常拉不下来代码,也就无法推送代码。为了更方便的使用,顺便更好地了解Git工具,决定在香橙派5 Plus上搭建一个属于自己的代码仓库。 1、…