视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

目标学习任务

检测出已经分割出的图像的分类

2 使用pytorch

pytorch 非常简单就可以做到训练和加载

2.1 准备数据

在这里插入图片描述
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行

里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
在这里插入图片描述
train.txt 如下图所示
在这里插入图片描述
val.txt 也是同样如此

3 show me the code

3.1 装载数据类

新增一个loaddata.py 文件

import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):
    def __init__(self, root, datatxt, transform=None, target_transform=None):
        super(LoadData, self).__init__()
        file_txt = open(datatxt,'r')
        imgs = []
        for line in file_txt:
            line = line.rstrip()
            words = line.split('|')
            imgs.append((words[0], words[1]))

        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        random.shuffle(self.imgs)
        name, label = self.imgs[index]
        img = Image.open(self.root + name).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs)

LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小

3.2 网络类

定义一个网络类,只有两个输出

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d((2, 2))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc1 = nn.Linear(36*36*32, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool1(F.relu(self.conv2(x)))
        x = x.view(-1, 36*36*32)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.3 主要流程

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Net

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


classes = ['工程车','卡宴']
transform = transforms.Compose(
   [transforms.Resize((152, 152)),transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',
                 datatxt='./data/'+'train.txt',
                 transform=transform)
test_data=LoadData(root ='./data/val/',
                datatxt='./data/'+'val.txt',
                transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)

def imshow(img):
   img = img / 2 + 0.5     # unnormalize
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()


net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
   running_loss = 0.0
   for i, data in enumerate(train_loader, 0):
       inputs, labels = data
       optimizer.zero_grad()
       outputs = net(inputs)
       loss = criterion(outputs, labels)
       loss.backward()
       optimizer.step()

       running_loss += loss.item()
       if i % 200 == 0:
           print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 200))
           running_loss = 0.0

print('Finished Training')

PATH = './test.pth'
torch.save(net.state_dict(), PATH)

net = Net()
net.load_state_dict(torch.load(PATH))

correct = 0
total = 0
with torch.no_grad():
   for data in test_loader:
       images, labels = data
       outputs = net(images)
       _, predicted = torch.max(outputs.data, 1)
       total += labels.size(0)
       correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
   100 * correct / total))

在这里插入图片描述
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。

3.4 识别代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import Net

PATH = './test.pth'
transform = transforms.Compose(
    [transforms.Resize((152, 152)),transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])



net = Net()
net.load_state_dict(torch.load(PATH))

img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():
    outputs = net(img)
    _, predicted = torch.max(outputs.data, 1)
    print("the 102 img lable is ",predicted)

如下图所示,102 为卡宴识别为1 正确
在这里插入图片描述

后记

后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/33794.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【JavaSE】初步认识

目录 【1】Java语言概述 【1.1】Java是什么 【1.2】Java语言重要性 【1.3】Java语言发展简史 【1.4】Java语言特性 【1.5】 Java开发环境安装 【2】初识Java的main方法 【2.1】main方法示例 【2.2】运行Java程序 【3】注释 【3.1】基本规则 【3.2】注释规范 【4】…

C语言--消失的数字

文章目录 1.法一&#xff1a;映射法2.法二&#xff1a;异或法3.法三&#xff1a;差值法4.法四&#xff1a;排序查找 1.法一&#xff1a;映射法 时间复杂度&#xff1a;O&#xff08;N&#xff09; 空间复杂度&#xff1a;O&#xff08;N&#xff09; #include<stdio.h>…

Tree树形控件做权限时,保持选项联动的同时,解决数据无法回显的问题

项目需求&#xff1a; 要求树形控件要有父子联动&#xff0c;也就是选择父级选项&#xff0c;子级也要选中&#xff0c;那么check-strictly属性就不能设置死,我的是 :check-strictlycheckStrictly,在data中赋值有变量。我之前设置check-strictly&#xff0c;就没了联动效果&…

补码的定义

补码的定义 补码的概念引入 补码的定义 例题

智能相机的功能介绍

智能视觉检测相机主要是应用在工业检测领域图像分析识别、视觉检测判断。相机具有颜色有无判别、颜色面积计算、轮廓查找定位、物体特征灰度匹配、颜色或灰度浓淡检测、物体计数、尺寸测量、条码二维码识别读取、尺寸测量、机械收引导定位、字符识别等功能。相机带有HDMI高清视…

设计模式3:单例模式:volatile关键字能不能解决多线程计数问题?

先说结论不能&#xff1a; 代码实测下&#xff1a; public class Counter {public volatile static int count 0;public static void inc() {//这里延迟1毫秒&#xff0c;使得结果明显try {Thread.sleep(1);} catch (InterruptedException e) {}count;}public static void ma…

6.4.4 观察文件类型:file

如果你想要知道某个文件的基本数据&#xff0c;例如是属于 ASCII 或者是 data 文件&#xff0c;或者是 binary&#xff0c; 且其中有没有使用到动态函数库 &#xff08;share library&#xff09; 等等的信息&#xff0c;就可以利用 file 这个指令来检阅。举例来说&#xff1a;…

Linux vs windows 他们之间的区别

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​ 目录 前言 一.windows与Linux区别 二.Linux与Windows操作对比 三.Linux与Windows命令 …

如何克服自动化测试中的壁垒和问题?

随着自动化测试技术的快速发展和普及&#xff0c;自动化测试已经成为各个行业广泛应用的重要测试手段。然而&#xff0c;自动化测试中仍然存在壁垒和问题&#xff0c;这些问题可能对测试效果产生影响&#xff0c;甚至会影响整个项目的进程。在本文中&#xff0c;我们将探讨如何…

Mysql批量插入1000条数据

使用mysql的存储过程 1.现有如下一张表&#xff1a;site_row 2.创建存储过程 CREATE PROCEDURE p01 () BEGIN declare i int; set i1;while i<1000 doINSERT INTO site_row(row_id,row_num) VALUES ( i,i);set ii1; end WHILE;END; 3.执行存储过程 CALL p01(); 4.查看效…

UE4/5动画系列(3.通过后期处理动画蓝图的头部朝向Actor,两种方法:1.通过动画层接口的look at方法。2.通过control rig的方法)

目录 蓝图 点积dot Yaw判断 后期处理动画蓝图 动画层接口 ControlRig: 蓝图 首先我们创建一个actor类&#xff0c;这个actor类是我们要看的东西&#xff0c;actor在哪&#xff0c;我们的动物就要看到哪里&#xff08;同样&#xff0c;这个我们也是做一个父类&#xff0…

爆肝整理,性能测试-测试工具选型(各个对比)卷起来...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 性能测试和功能测…

Elasticsearch:跨集群复制应用场景及实操 - Cross Cluster Replication

通过跨集群复制&#xff08;Cross Cluster Replication - CCR&#xff09;&#xff0c;你可以跨集群将索引复制并实现&#xff1a; 在数据中心中断时继续处理搜索请求防止搜索量影响索引吞吐量通过在距用户较近的地理位置处理搜索请求来减少搜索延迟 跨集群复制采用主动 - 被…

2核4G服务器_4M带宽_CPU性能测评_60G系统盘

阿里云2核4G服务器297元一年、4M公网带宽、60G系统盘&#xff0c;阿里云轻量应用服务器2核4G4M带宽配置一年297.98元&#xff0c;2核2G3M带宽轻量服务器一年108元12个月&#xff0c;如下图&#xff1a; 目录 阿里云2核4G4M轻量应用服务器 2核4G服务器限制条件 轻量服务器介…

Web安全——PHP基础

PHP基础 一、PHP简述二、基本语法格式三、数据类型、常量以及字符串四、运算符五、控制语句1、条件控制语句2、循环控制语句3、控制语句使用 六、php数组1、数组的声明2、数组的操作2.1 数组的合拼2.2 填加数组元素2.3 添加到指定位置2.4 删除某一个元素2.5 unset 销毁指定的元…

PyTorch开放神经网络交换(Open Neural Network Exchange)ONNX通用格式模型的熟悉

我们在深度学习中可以发现有很多不同格式的模型文件&#xff0c;比如不同的框架就有各自的文件格式&#xff1a;.model、.h5、.pb、.pkl、.pt、.pth等等&#xff0c;各自有标准就带来互通的不便&#xff0c;所以微软、Meta和亚马逊在内的合作伙伴社区一起搞一个ONNX(Open Neura…

【Spring Cloud系列】-负载均衡(Load Balancer,LB)

【Spring Cloud系列】-负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09; 文章目录 【Spring Cloud系列】-负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09;一、什么是负载均衡&#xff08;Load Balancer&#xff0c;LB&#xff09;二、负载均衡的主要…

vue2、vue3分别配置echarts多图表的同步缩放

文章目录 ⭐前言⭐使用dataZoom api实现echart的同步缩放&#x1f496; vue2实现echarts多图表同步缩放&#x1f496; vue3实现echarts多图表同步缩放 ⭐结束 ⭐前言 大家好&#xff01;我是yma16&#xff0c;本文分享在vue2和vue3中配置echarts的多图表同步缩放 背景&#xf…

教你如何使用Nodejs搭建HTTP web服务器并发布上线公网

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 转载自内网穿透工具的文章&#xff1a;使用Nodejs搭建HTTP服务&#xff0c;并实现公网远程访问「内网穿透」 前言 Node.js…