基于flask的猫狗图像预测案例

📚博客主页:knighthood2001
公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

假设,你有模型,有训练好的模型文件,有模型推理代码,就可以把他放到flask上进行展示。

项目架构

在这里插入图片描述

  • index.html是模板文件
  • app.py是项目运行的入口
  • best_model.pth是训练好的模型参数
  • model.py是神经网络模型,这里采用的是GoogleNet网络。
  • model_reasoning.py是模型推理,通过这里面的代码,我们可以在本地进行猫狗图片的预测。

运行图

在这里插入图片描述

点击选择文件
在这里插入图片描述
图片下面就显示预测结果了。
在这里插入图片描述

项目完整代码与讲解

index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>图像分类</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #result {
            margin-top: 10px;
        }
        #preview-image {
            max-width: 400px;
            margin-top: 20px;
        }
    </style>
</head>
<body>
    <h1>图像分类</h1>
    <form id="upload-form" action="/predict" method="post" enctype="multipart/form-data">
        <input type="file" name="file" accept="image/*" onchange="previewImage(event)">
        <input type="submit" value="预测">
    </form>

    <img id="preview-image" src="" alt="">
    <br>
    <div id="result"></div>

    <script>
        document.getElementById('upload-form').addEventListener('submit', async (e) => {
            e.preventDefault();  // 阻止默认的表单提交行为
            const formData = new FormData(); // 创建一个新的FormData对象,用于封装表单数据
            formData.append('file', document.querySelector('input[type=file]').files[0]);  // 添加表单数据
            // 使用fetch API发送POST请求到'/predict'路径,并将formData作为请求体
            const response = await fetch('/predict', {
                method: 'POST',
                body: formData
            });
            // 获取响应的JSON数据
            const result = await response.json();
            // 将预测结果显示在页面上ID为'result'的元素中
            document.getElementById('result').innerText = `预测结果: ${result.prediction}`;
        });

        function previewImage(event) {
            const file = event.target.files[0];  // 获取上传的文件对象
            const reader = new FileReader();  // 创建一个FileReader对象,用于读取文件内容

            // 清空上一次的预测结果
            document.getElementById('result').innerText = '';
            
            // 当文件读取完成后,将文件内容显示在页面上ID为'preview-image'的元素中
            reader.onload = function(event) {
                document.getElementById('preview-image').setAttribute('src', event.target.result);
            }
            // 如果用户选择了文件,则开始读取文件内容
            if (file) {
                reader.readAsDataURL(file); // 将文件读取为DataURL格式,这样可以直接用作img元素的src属性
            }
        }
    </script>
</body>
</html>

前端我练的不多,很多解释已经在代码中讲了。

model.py

这是GoogleNet的网络架构

import torch
from torch import nn
from torchsummary import summary
# 定义一个Inception模块
class Inception(nn.Module):
    def __init__(self, in_channels, c1, c2, c3, c4):  # 这些参数,所在的位置都会发送变化,所有需要这个参数
        super(Inception, self).__init__()
        self.ReLU = nn.ReLU()

        # 路线1,单1×1卷积层
        self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)

        # 路线2,1×1卷积层, 3×3的卷积
        self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)

        # 路线3,1×1卷积层, 5×5的卷积
        self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)

        # 路线4,3×3的最大池化, 1×1的卷积
        self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)
        self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)


    def forward(self, x):
        p1 = self.ReLU(self.p1_1(x))
        p2 = self.ReLU(self.p2_2(self.ReLU(self.p2_1(x))))
        p3 = self.ReLU(self.p3_2(self.ReLU(self.p3_1(x))))
        p4 = self.ReLU(self.p4_2(self.p4_1(x)))
        return torch.cat((p1, p2, p3, p4), dim=1)

class GoogLeNet(nn.Module):
    def __init__(self, Inception, in_channels, out_channels):
        super(GoogLeNet, self).__init__()
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b3 = nn.Sequential(
            Inception(192, 64, (96, 128), (16, 32), 32),
            Inception(256, 128, (128, 192), (32, 96), 64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b4 = nn.Sequential(
            Inception(480, 192, (96, 208), (16, 48), 64),
            Inception(512, 160, (112, 224), (24, 64), 64),
            Inception(512, 128, (128, 256), (24, 64), 64),
            Inception(512, 112, (128, 288), (32, 64), 64),
            Inception(528, 256, (160, 320), (32, 128), 128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b5 = nn.Sequential(
            Inception(832, 256, (160, 320), (32, 128), 128),
            Inception(832, 384, (192, 384), (48, 128), 128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, out_channels))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GoogLeNet(Inception, 1, 10).to(device)
    print(summary(model, (1, 224, 224)))

model_reasoning.py

import torch
from torchvision import transforms
from model import GoogLeNet, Inception
from PIL import Image

def test_model(model, test_file):
    # 设定测试所用到的设备,有GPU用GPU没有GPU用CPU
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    classes = ['猫', '狗']
    print(classes)
    image = Image.open(test_file)

    # normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])
    # # 定义数据集处理方法变量
    # test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])

    # 定义数据集处理方法变量
    test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    image = test_transform(image)

    # 添加批次维度,变成[1,3,224,224]
    image = image.unsqueeze(0)

    with torch.no_grad():
        model.eval()
        image = image.to(device)  # 图片也要放到设备当中
        output = model(image)
        print(output.tolist())
        pre_lab = torch.argmax(output, dim=1)
        result = pre_lab.item()
    print("预测值:", classes[result])
    return classes[result]

def test_special_model(best_model_file, test_file):
    # 加载模型
    model = GoogLeNet(Inception, in_channels=3, out_channels=2)
    model.load_state_dict(torch.load(best_model_file))

    # 模型的推理判断
    return test_model(model, test_file)

if __name__ == "__main__":
    # # 加载模型
    # model = GoogLeNet(Inception, in_channels=3, out_channels=2)
    # model.load_state_dict(torch.load('best_model.pth'))
    # # 模型的推理判断
    # test_model(model, "test_data/images.jfif")

    test_special_model("best_model.pth", "static/1.jpg")

这段代码与之前的模型推理代码不同的是,我添加了test_special_model函数,方便后续app.py中可以直接调用这个函数进行模型推理。

app.py

import os
from flask import Flask, request, jsonify, render_template

from model_reasoning import test_special_model
from model_reasoning import test_model
app = Flask(__name__)

# 定义路由
@app.route('/')
def index():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # 获取上传的文件
        file = request.files['file']
        if file:
            # 调用模型进行预测
            # # 加载模型
            # model = GoogLeNet(Inception, in_channels=3, out_channels=2)
            # basedir = os.path.abspath(os.path.dirname(__file__))
            #
            # model.load_state_dict(torch.load(basedir + '/best_model.pth'))
            # result = test_model(model, file)

            basedir = os.path.abspath(os.path.dirname(__file__))
            best_model_file = basedir + '/best_model.pth'
            result = test_special_model(best_model_file, file)
            return jsonify({'prediction': result})
        else:
            return jsonify({'error': 'No file found'})


if __name__ == '__main__':
    app.run(debug=True)

如果没有上文中的test_special_model函数,那么这里你就需要

   # 加载模型
   model = GoogLeNet(Inception, in_channels=3, out_channels=2)
   basedir = os.path.abspath(os.path.dirname(__file__))
   
   model.load_state_dict(torch.load(basedir + '/best_model.pth'))
   result = test_model(model, file)

并且还需要导入相应的库。

best_model.pth

最重要的是,你需要训练好的一个模型。

有需要的,可以联系我,我直接把这个项目代码发你。省得你还需要配置项目架构。

小插曲

我为什么会使用绝对路径,因为我在使用相对路径后,代码提示找不到这个路径。

    basedir = os.path.abspath(os.path.dirname(__file__))
    best_model_file = basedir + '/best_model.pth'

然后,我刚刚又试了一下,发现使用相对路径,又可以运行成功了。

真是不可思议(这个小插曲花了我大半个小时)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/784722.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uni-app 封装http请求

1.引言 前面一篇文章写了使用Pinia进行全局状态管理。 这篇文章主要介绍一下封装http请求&#xff0c;发送数据请求到服务端进行数据的获取。 感谢&#xff1a; 1.yudao-mall-uniapp: 芋道商城&#xff0c;基于 Vue Uniapp 实现&#xff0c;支持分销、拼团、砍价、秒杀、优…

2024年6月总结 | 软件开发技术月度回顾(第一期)

最新技术资源&#xff08;建议收藏&#xff09; https://www.grapecity.com.cn/resources/ Hello&#xff0c;大家好啊&#xff01;随着欧洲杯和奥运会的临近&#xff0c;2024 年下半年的序幕也随之拉开。回顾 2024 年上半年的技术圈&#xff0c;我们看到了一系列令人振奋的进展…

ELfK logstash filter模块常用的插件 和ELFK部署

ELK之filter模块常用插件 logstash filter模块常用的插件&#xff1a; filter&#xff1a;表示数据处理层&#xff0c;包括对数据进行格式化处理、数据类型转换、数据过滤等&#xff0c;支持正则表达式 grok 对若干个大文本字段进行再分割成一些小字段 (?<字段名…

51单片机嵌入式开发:5、按键、矩阵按键操作及protues仿真

按键、矩阵按键操作及protues仿真 1 按键介绍1.1 按键种类1.2 按键应用场景 2 按键电路3 按键软件设计3.1 按键实现3.2 按键滤波方法3.3 矩阵按键软件设计3.4 按键Protues 仿真 4 按键操作总结 提示 1 按键介绍 1.1 按键种类 按键是一种用于控制电子设备或电路连接和断开的按…

LLM之RAG实战(四十一)| 使用LLamaIndex和Gemini构建高级搜索引擎

Retriever 是 RAG&#xff08;Retrieval Augmented Generation&#xff09;管道中最重要的部分。在本文中&#xff0c;我们将使用 LlamaIndex 实现一个结合关键字和向量搜索检索器的自定义检索器&#xff0c;并且使用 Gemini大模型来进行多个文档聊天。 通过本文&#xff0c;我…

Face_recognition实现人脸识别

这里写自定义目录标题 欢迎使用Markdown编辑器一、安装人脸识别库face_recognition1.1 安装cmake1.2 安装dlib库1.3 安装face_recognition 二、3个常用的人脸识别案例2.1 识别并绘制人脸框2.2 提取并绘制人脸关键点2.3 人脸匹配及标注 欢迎使用Markdown编辑器 本文基于face_re…

Python 安装Numpy 出现异常信息

文章目录 前言一、包源二、安装完成异常 前言 安装Python Numpy包出现异常问题 Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location. 一、包源 使用默认的包源出现超时异常&#xff0c;改用清华包源 pip …

娱乐圈幕后揭秘孙俪天选打工人

【娱乐圈幕后揭秘&#xff1a;孙俪“天选打工人”背后的热议风暴】在聚光灯下光鲜亮丽的娱乐圈&#xff0c;每一位明星的日常备受瞩目。近日&#xff0c;实力派演员孙俪在社交媒体上分享了一段片场棚拍的趣事&#xff0c;本是无心之举&#xff0c;意外引爆了网络热议的导火索。…

这几类人,千万不要买纯电车

文 | AUTO芯球 作者 | 响铃 纯电车的冤大头真是太多了&#xff0c; 我之前劝过&#xff0c;有些人不适合买纯电车&#xff0c; 你们看&#xff0c;果然吧&#xff0c;麦卡锡最近的一份报告就披露了 去年啊&#xff0c;22%的人在买了电车后后悔了&#xff0c; 这些人说了&a…

面试常考题---128陷阱(详细)

1.问题引入 分别引入了int和Integer变量&#xff0c;并进行比较 int b 128; int b1 128;Integer d 127; Integer d1 127;Integer e 128; Integer e1 128;System.out.println(bb1); System.out.println(dd1); System.out.println(ee1); System.out.println(e.equals(e1)…

kafka系列之offset超强总结及消费后不提交offset情况的分析总结

概述 每当我们调用Kafka的poll()方法或者使用Spring的KafkaListener(其实底层也是poll()方法)注解消费Kafka消息时&#xff0c;它都会返回之前被写入Kafka的记录&#xff0c;即我们组中的消费者还没有读过的记录。 这意味着我们有一种方法可以跟踪该组消费者读取过的记录。 如前…

【力扣高频题】014.最长公共前缀

经常刷算法题的小伙伴对于 “最长”&#xff0c;“公共” 两个词一定不陌生。与此相关的算法题目实在是太多了 &#xff01;&#xff01;&#xff01; 之前的 「动态规划」 专题系列文章中就曾讲解过两道相关的题目&#xff1a;最长公共子序列 和 最长回文子序列 。 关注公众…

跨境电商代购系统与电商平台API结合的化学反应

随着全球化的不断推进和互联网技术的飞速发展&#xff0c;跨境电商已成为国际贸易的重要组成部分。跨境电商代购系统作为连接国内外消费者与商品的桥梁&#xff0c;不仅为消费者提供了更多元化的购物选择&#xff0c;也为商家开辟了更广阔的市场空间。在这一过程中&#xff0c;…

如何将heic转jpg格式?四种图片格式转换方法【附教程】

如何把heic转jpg格式&#xff1f;heic是用于存储静态图像和图形的压缩格式&#xff0c;旨在以更小的文件大小保持高质量的图像。HEIC格式自iOS 11和macOS High Sierra&#xff08;10.13&#xff09;内测开始&#xff0c;被苹果设置为图片存储的默认格式&#xff0c;广泛应用于i…

【VUE基础】VUE3第四节—核心语法之computed、watch、watcheffect

computed 接受一个 getter 函数&#xff0c;返回一个只读的响应式 ref 对象。该 ref 通过 .value 暴露 getter 函数的返回值。它也可以接受一个带有 get 和 set 函数的对象来创建一个可写的 ref 对象。 创建一个只读的计算属性 ref&#xff1a; <template><div cl…

【一次成功】清华大学和智谱AI公司的ChatGLM-4-9B-Chat-1M大模型本地化部署教程

【一次成功】清华大学和智谱AI公司的ChatGLM-4-9B-Chat-1M大模型本地化部署教程 一、环境准备二、ChatGLM-4-9B-Chat-1M简介三、模型下载2.1 安装Git LFS2.2 初始化仓库2.3 同步Git文件2.4 拉取文件2.5 下载完毕 四、python和源码安装与下载4.1 安装python4.2 下载源码4.3 安装…

Monaco 中添加 CodeLens

CodeLens 会在指定代码行上添加一行可点击的文字&#xff0c;点击时可以触发定义的命令&#xff0c;效果如下&#xff1a; 通过调用 API 注册 LensProvider&#xff0c;点击时触发 Command&#xff0c;首先要注册命令&#xff0c;通过 editor.addCommand () 方法进行注册。三个…

韦东山嵌入式linux系列-LED驱动程序

之前学习STM32F103C8T6的时候&#xff0c;学习过对应GPIO的输出&#xff1a; 操作STM32的GPIO需要3个步骤&#xff1a; 使用RCC开启GPIO的时钟、使用GPIO_Init函数初始化GPIO、使用输入/输出函数控制GPIO口。 【STM32】GPIO输出-CSDN博客 这里再看看STM32MP157的GPIO引脚使用…

高通开发系列 - 使用QFIL工具单刷某个镜像文件

By: fulinux E-mail: fulinux@sina.com Blog: https://blog.csdn.net/fulinus 喜欢的盆友欢迎点赞和订阅! 你的喜欢就是我写作的动力! 返回:专栏总目录 目录 背景过程记录背景 有时候设备中刷的是user版本,无法使用fastboot刷单个镜像,这个时候该怎么办呢? 要解决在user…

Websocket在Java中的实践——整合Rabbitmq和STOMP

大纲 Rabbitmq开启STOMP支持 服务端依赖参数参数映射类配置类逻辑处理类 测试测试页面Controller测试案例 在《Websocket在Java中的实践——STOMP通信的最小Demo》一文中&#xff0c;我们使用enableSimpleBroker启用一个内置的内存级消息代理。本文我们将使用Rabbitmq作为消息代…