TVRNet网络PyTorch实现

文章目录

    • 文章地址
    • 网络各层结构
    • 代码实现

文章地址

  • An End-to-End Traffic Visibility Regression Algorithm
  • 文章通过训练搜集得到的真实道路图像数据集(Actual Road dense image Dataset, ARD),通过专业的能见度计和多人标注,获得可靠的能见度标签数据集。构建网络,进行训练,获得了较好的能见度识别网络。网络包括特征提取​、多尺度映射​、特征融合​、非线性输出(回归范围为[0,1],需要经过(0,0),(1,1)改用修改的sigmoid函数,相较于ReLU更好)。结构如下​
    在这里插入图片描述

网络各层结构

在这里插入图片描述

  • 我认为红框位置与之相应的参数不匹配,在Feature Extraction部分Reshape之后得到的特征图大小为4124124。紧接着接了一个卷积层Conv,显示输入是3128128
  • 第二处红框,MaxPool的kernel设置为88,特征图没有进行padding,到全连接层的输入变为64117*117,参数不对应
    在这里插入图片描述

代码实现

"""
    Based on the ideas of the below paper, using PyTorch to build TVRNet.
    Reference: Qin H, Qin H. An end-to-end traffic visibility regression algorithm[J]. IEEE Access, 2021, 10: 25448-25454.​
    @weishuo
"""

import torch
from torch import nn
import math


class Inception(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_planes, out_planes, kernel_size=5, padding=2)
        self.conv7 = nn.Conv2d(in_planes, out_planes, kernel_size=7, padding=3)

    def forward(self, x):
        out_1 = self.conv1(x)
        out_3 = self.conv3(x)
        out_5 = self.conv5(x)
        out_7 = self.conv7(x)

        out = torch.cat((out_1, out_3, out_5, out_7), dim=1)
        return out

def modify_sigmoid(x):
    return 1 / (1 + torch.exp(-10*(x-0.5)))

class TVRNet(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(TVRNet, self).__init__()
        # (B, 3, 224, 224)  ——>  (B, 3, 220, 220)
        self.FeatureExtraction_onestep = nn.Sequential(nn.Conv2d(in_planes, 20, kernel_size=5, padding=0),
                                                       nn.ReLU(inplace=True),)
        self.FeatureExtraction_maxpool = nn.MaxPool2d((5, 1))

        self.MultiScaleMapping = nn.Sequential(Inception(4, 16),
                                               nn.ReLU(inplace=True),
                                               nn.MaxPool2d(kernel_size=8))

        self.FeatureIntegration = nn.Sequential(nn.Linear(46656, 100),
                                                nn.ReLU(inplace=True),
                                                nn.Dropout(0.4),
                                                nn.Linear(100, out_planes))

        self.NonLinearRegression = modify_sigmoid


    def forward(self, x):
        x = self.FeatureExtraction_onestep(x)
        x = x.view((x.shape[0], 1, x.shape[1], -1))
        x = self.FeatureExtraction_maxpool(x)
        x = x.view(x.shape[0], x.shape[2], int(math.sqrt(x.shape[3])), int(math.sqrt(x.shape[3])))
        # print(x.shape)

        x = self.MultiScaleMapping(x)
        # print(x.shape)
        x = x.view(x.shape[0], -1)

        x = self.FeatureIntegration(x)
        out = self.NonLinearRegression(x)

        return out


if __name__ == '__main__':
    a = torch.randn(1,3,224,224)
    net = TVRNet(3,3)
    b = net(a)
    print(b.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/107252.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt QWebEngine 更换语言

背景 使用Qt QWebEngine开发的应用,在一些场景下,会显示英文文本,比如右键、JS弹出的对话框,所以需要进行汉化,更改语言。 准备翻译文件 Qt有提供翻译好的ts文件,我们可以直接下载ts文件qtwebengine_zh_…

世界前沿技术发展报告2023《世界航空技术发展报告》(二)军用飞机技术

(二)军用飞机技术 1.作战飞机1.1 美俄对第五代战斗机进行升级改进1.2 美欧第六代战斗机技术取得新进展1.3 美国B-21隐身轰炸机正式亮相 2.支援飞机2.1 美国空军拟研制翼身融合布局运输/加油机2.2 美欧厂商积极参加北约未来预警机技术研究项目2.3 美国空军…

1817_ChibiOS的RT线程

全部学习汇总: GreyZhang/g_ChibiOS: I found a new RTOS called ChibiOS and it seems interesting! (github.com) 1. 关于线程,有几个概念需要弄清楚:声明、生命循环、延迟、线程引用、线程队列、线程时间、优先级管理、调度。 2. 两个声明…

线性表的定义和基本操作

线性表的定义和基本操作 一、线性表的定义 线性表(Linear List)是具有相同数据类型的n(n>0)个数据元素的有限序列,其中n为表长,当n0时线性表是一个空表。若用L命名线性表,则其一般表示为 L (a1,a2,...,ai,ai1,.…

JMeter + Ant + Jenkins持续集成-接口自动化测试

需要安装的工具: jdk1.8jmeter3.2ant1.9jenkins2.1 1、Jdkwin7系统如何安装jdk及环境变量的配置-百度经验 安装包安装设置环境变量验证是否安装正确 Java -version检查,如下就代表安装成功了,环境变量设置就去搜索了,网上很多…

TimeGPT:时间序列预测的第一个基础模型

时间序列预测领域在最近的几年有着快速的发展,比如N-BEATS、N-HiTS、PatchTST和TimesNet。 大型语言模型(llm)最近在ChatGPT等应用程序中变得非常流行,因为它们可以适应各种各样的任务,而无需进一步的训练。 这就引出了一个问题:时间序列的…

File相关方法2

一.获取当前目录下所有一级文件名称 1.代码 package org.example;import java.io.File;public class day03 {public static void main(String[] args) {//获取当前目录下所有一级文件名称final File f1 new File("d:/temp");final String[] name f1.list();for (…

VDA到Excel方案介绍之自定义邮件接收主题

VDA标准是德国汽车工业协会(Verband der Automobilindustrie,简称VDA)制定的一系列汽车行业标准。这些标准包括了汽车生产、质量管理、供应链管理、环境保护、安全性能等方面的规范和指南。VDA标准通常被德国和国际上的汽车制造商采用&#x…

学习笔记:二分图

二分图 引入 二分图又被称为二部图。 二分图就是可以二分答案的图。 二分图是节点由两个集合组成,且两个集合内部没有边的图。换言之,存在一种方案,将节点划分成满足以上性质的两个集合。 性质 如果两个集合中的点分别染成黑色和白色&am…

目录和文件操作

在自己电脑任一盘符中新建以OS_Test命名的文件夹,并在该文件夹中新建新建3个以.txt,3个 .xlsx为扩展名的文件(文件名由代码随机生成,长度为8,由字母数字组成)。,请写一个程序,删除掉…

stm32的ADC采样率如何通过Time定时器进行控制

ADC采样率是个跟重要的概念. 手册上说可以通过Timer定时器进行触发ADC采样. 可我这边悲剧的是, 无论怎么样. ADC都会进行采样. 而且就算是TIM停掉也是一样会进行采样. 这就让我摸不着头脑了… 我想通过定时器动态更改ADC的采样频率. 结果不随我愿… 这到底是什么问题呢? 一…

el-table(vue2中)滚动条被固定列盖住

一、项目场景: vue2 el-table 二、问题描述 1、现场图片: 2、全局css环境配置了滚动条高度为6px /* 全局滚动条配置 */ ::-webkit-scrollbar {width: 6px;height: 6px; }::-webkit-scrollbar-track {background-color: #f1f1f1; }::-webkit-scrollbar-…

STM32 定时器配置不当导致误差(精度)偏大的问题发现与解决

通用定时器TIM2/3/4/5,PWM输出1Khz的波形 一开始初始化代码如下: void MX_TIM2_Init(void)//1kHz {TIM_ClockConfigTypeDef sClockSourceConfig {0};TIM_MasterConfigTypeDef sMasterConfig {0};TIM_OC_InitTypeDef sConfigOC {0};htim2.Instance T…

AI与Prompt:解锁软件开发团队的魔法咒语,在复杂任务上生成正确率更高的代码

AI与Prompt:解锁软件开发团队的魔法咒语 写在最前面论文:基于ChatGPT的自协作代码生成将团队协作理论应用于代码生成的研究自协作框架原理1、DOL任务分配2、共享黑板协作3、Instance实例化 案例说明简单任务:基本操作,生成的结果1…

【MySQL架构篇】逻辑架构

逻辑架构 文章目录 逻辑架构1. 服务器处理客户端请求2. Connectors3. 第一层:连接层4. 第二层:服务层5. 第三层:存储引擎6. 存储层7. 小结 1. 服务器处理客户端请求 首先 MySQL 是典型的 C/S 架构,即 Client/Server 架构&#xf…

Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能 前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。 实现代码 import tensorflow as tf from sklearn.datasets import…

在CentOS 7中手工打造和运行xml文件配置的Servlet,然后使用curl、浏览器、telnet等三种工具各自测试

下载Openjdk并配置环境变量 https://jdk.java.net/java-se-ri/11-MR2是官网下载Openjdk 11的地方。 sudo wget https://download.java.net/openjdk/jdk11.0.0.1/ri/openjdk-11.0.0.1_linux-x64_bin.tar.gz下载openjdk 11。 sudo mkdir -p /usr/openjdk11创建目录&#xff…

一张图系列 - “kv cache“

我觉得回答这个问题需要知道3个知识点: 1、multi-head-attention是如何计算的?attention的数学公式? kv cache是如何存储和传递的? 2、kv cache 的原理步骤是什么?为什么降低了消耗? 3、kv cache 代码模…

C++:stl中set(multiset)和map(multimap)的介绍和使用

本文主要从概念、常用接口和使用方法方面介绍set(multiset)和map(multimap)。 目录 一、概念介绍 1.关联式容器 2.键值对 3. 树形结构的关联式容器 二、set和multiset 1.set的介绍 2.set使用 1. set模板参数列表 2. set构造 3. set迭代器 4. set容量 5. set修改操…

正则表达式包含数字和字符匹配

至少6位。 pattern : (?.[0-9])(?.[A-Za-z])[0-9A-Za-z]{6,} 正则表达式中的“?”是一个正向预查字符,它的意思是匹配前一个字符出现的最少一次。具体来说,当一个匹配出现时,它会检查前一个字符是否符合要求,如果符合&#xf…