Visual Prediction on the Web
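This article uses Flask to build a small web app with an HTML login page and a page for uploading an image and displaying the prediction. The backend loads the model trained in the previous article and talks to the frontend: you choose an image, and the prediction result is shown on the page.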
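Contents
- Visual Prediction on the Web
- 1. Configure the Flask environment
- 2. Login page
- 3. Prediction page
- 4. Classification model, weights, and classes
- 5. Building and running the main program
- 5.1 Main program
- 5.2 Run
- 6. Project package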
1. Configure the Flask environment
Activate the environment created in the previous article, then install Flask and Flask-CORS:
conda activate pyweb
pip install flask
pip install flask_cors
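A quick way to confirm that these packages, plus the ones main.py needs from the previous article's environment, are importable in the active pyweb environment is sketched below with only the standard library. The module list is taken from main.py's imports; Pillow is imported as PIL, and the function name is illustrative.

```python
import importlib.util

# Top-level modules imported by main.py; torch, torchvision and PIL are assumed
# to come from the environment set up in the previous article.
REQUIRED_MODULES = ["flask", "flask_cors", "torch", "torchvision", "PIL"]


def missing_modules(names):
    """Return the names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```

find_spec only locates the module without importing it, so the check is cheap even for torch.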
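Download the static folder and place it in the project root, in the same directory as main.py (Flask serves the static and templates folders from the application's root directory by default).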
Link: https://pan.baidu.com/s/1_c6pIk8RTN-46QFm6W_O8g
Extraction code: i2kd
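The original folder listing is not reproduced here, but the expected layout can be inferred from the paths that login.html, predict.html and main.py reference in this article. The sketch below checks for them; the list is assembled from those references and may be incomplete compared with the downloaded static archive, and the helper name is hypothetical.

```python
from pathlib import Path

# Relative paths referenced by login.html, predict.html and main.py in this article.
EXPECTED_PATHS = [
    "main.py",
    "templates/login.html",
    "templates/predict.html",
    "static/css/style.css",
    "static/images/user.png",
    "static/js/jquery.min.js",
    "models/class_indices.json",
    "models/mobilenet_v2.py",
    "models/MobileNetV2.pth",
]


def missing_paths(project_root, expected=EXPECTED_PATHS):
    """Return the expected relative paths that do not exist under project_root."""
    root = Path(project_root)
    return [p for p in expected if not (root / p).exists()]


if __name__ == "__main__":
    # Run from the project root; prints nothing when everything is in place.
    for path in missing_paths("."):
        print("missing:", path)
```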
2. Login page
Create login.html and place it in the templates folder:
<!DOCTYPE html>
<html>
<head>
<title>Login</title>
<meta charset="utf-8">
<link href="../static/css/style.css" rel='stylesheet' type='text/css' />
<meta name="viewport" content="width=device-width, initial-scale=1">
<script type="application/x-javascript"> addEventListener("load", function() { setTimeout(hideURLbar, 0); }, false); function hideURLbar(){ window.scrollTo(0,1); } </script>
</head>
<body>
<!-----start-main---->
<div class="main">
<div class="login-form">
<h1>Login</h1>
<div class="head">
<img src="../static/images/user.png" alt=""/>
</div>
<form action="/login" method="get"> <!-- action points the form at the /login route -->
<input type="text" class="text" name="username" placeholder="USERNAME">
<input type="password" name="password" placeholder="Password">
<div class="submit">
<input type="submit" value="LOGIN" >
</div>
<p><a href="#">Forgot Password ?</a></p>
</form>
</div>
<!--//End-login-form-->
<!-----start-copyright---->
<div class="copy-right">
<p>HHXC浩瀚星辰<a target="_blank" href=""></a></p>
</div>
<!-----//end-copyright---->
</div>
<!-----//end-main---->
</body>
</html>
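Because the form is submitted with method="get", the browser appends the fields to the URL as a query string, and that query string is what request.args in main.py reads. The standard-library sketch below shows the URL the browser would request with the demo credentials and how it parses back into per-field values; it illustrates the encoding only and does not call Flask.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

# Field names from login.html; the values are the demo credentials checked in main.py
fields = {"username": "admin", "password": "123456"}

# URL the browser requests when the GET form is submitted
url = "/login?" + urlencode(fields)
print(url)  # /login?username=admin&password=123456

# Roughly what Flask exposes through request.args (parse_qs returns a list per key)
query = parse_qs(urlsplit(url).query)
print(query["username"][0], query["password"][0])  # admin 123456
```

This is also why the password shows up in the address bar, browser history and server access logs. Switching the form to method="post" would require the /login route to accept POST and read request.form instead of request.args.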
The resulting page looks like this:
3. Prediction page
Create predict.html and place it in the templates folder as well:
<!DOCTYPE html>
<html>
<head>
<title>Upload an Image and Predict on the Web</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<script src="{{ url_for('static', filename='js/jquery.min.js') }}"></script>
</head>
<body>
<div style="text-align: center">
<h1 class="hfont">Image Classification Model Prediction and Visualization</h1>
</div>
<!--<h3>Please choose an image file: PNG/JPG/JPEG/SVG/GIF</h3>-->
<div style="text-align: left;margin-left:500px;margin-top:100px;" >
<div style="float:left;">
<a href="javascript:;" class="file">Choose file
<input type="file" name="file" id="file0"><br>
</a>
<img src="" id="img0" style="margin-top:20px;width: 35rem;height: 30rem;">
</div>
<div style="float:left;margin-left:50px;">
<select id="modelSelect">
<option value="AlexNet">AlexNet</option>
<option value="MobileNetV2" selected>MobileNetV2</option>
<option value="DenseNet121">DenseNet121</option>
<option value="ResNet34">ResNet34</option>
</select>
<input type="button" id="b0" onclick="test()" value="Predict">
<pre id="out" style="width:320px;height:50px;line-height: 50px;margin-top:20px;"></pre>
</div>
</div>
<script type="text/javascript">
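{# Listen for changes on the file0 input and get a URL for the selected file #}
$("#file0").change(function(){
var objUrl = getObjectURL(this.files[0]) ;// create an object URL for the selected file
console.log("objUrl = "+objUrl);
{# Show the selected image in img0 #}
if (objUrl) {
$("#img0").attr("src", objUrl);
}
});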
function test() {
    var fileobj = $("#file0")[0].files[0];
    if (!fileobj) {
        alert("Please choose an image first.");
        return;
    }
    var model = $("#modelSelect").val(); // name of the selected model
    var form = new FormData();
    form.append("file", fileobj);
    form.append("model", model); // send the model name along with the image
    $.ajax({
        type: 'POST',
        url: "predict",
        data: form,
        processData: false, // send the FormData object as-is instead of serializing it
        contentType: false, // let the browser set the multipart/form-data header with its boundary
        success: function (arg) {
            console.log(arg);
            var flower = '';
            // one line per class; use arg.result.slice(0, 1) to show only the most probable class
            arg.result.forEach(e => {
                flower += `<div style="border-bottom: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">${e}</div>`;
            });
            document.getElementById("out").innerHTML = flower;
        },
        error: function () {
            console.log("Server-side processing error");
        }
    });
}
function getObjectURL(file) {
var url = null;
if(window.createObjectURL!=undefined) {
url = window.createObjectURL(file) ;
}else if (window.URL!=undefined) { // mozilla(firefox)
url = window.URL.createObjectURL(file) ;
}else if (window.webkitURL!=undefined) { // webkit or chrome
url = window.webkitURL.createObjectURL(file) ;
}
return url ;
}
</script>
<style>
.hfont{
font-size: 50px; /* size */
font-weight: 400; /* weight (400 = normal) */
font-family: "Microsoft YaHei", sans-serif; /* font */
}
.file {
position: relative;
/*display: inline-block;*/
background: #CCC ;
border: 1px solid #CCC;
padding: 4px 4px;
overflow: hidden;
text-decoration: none;
text-indent: 0;
width:100px;
height:30px;
line-height: 30px;
border-radius: 5px;
color: #333;
font-size: 13px;
}
.file input {
position: absolute;
font-size: 13px;
right: 0;
top: 0;
opacity: 0;
border: 1px solid #333;
padding: 4px 4px;
overflow: hidden;
text-indent: 0;
width:100px;
height:30px;
line-height: 30px;
border-radius: 5px;
color: #FFFFFF;
}
#b0{
background: #1899FF;
border: 1px solid #CCC;
padding: 4px 10px;
overflow: hidden;
text-indent: 0;
width:60px;
height:28px;
line-height: 20px;
border-radius: 5px;
color: #FFFFFF;
font-size: 13px;
}
/*.gradient{*/
/*filter:alpha(opacity=100 finishopacity=50 style=1 startx=0,starty=0,finishx=0,finishy=150) progid:DXImageTransform.Microsoft.gradient(startcolorstr=#fff,endcolorstr=#ccc,gradientType=0);*/
/*-ms-filter:alpha(opacity=100 finishopacity=50 style=1 startx=0,starty=0,finishx=0,finishy=150) progid:DXImageTransform.Microsoft.gradient(startcolorstr=#fff,endcolorstr=#ccc,gradientType=0);!*IE8*!*/
/*background:#1899FF; !* fallback for browsers without gradient support *!*/
/*background:-moz-linear-gradient(top, #fff, #1899FF);*/
/*background:-webkit-gradient(linear, 0 0, 0 bottom, from(#fff), to(#ccc));*/
/*background:-o-linear-gradient(top, #fff, #ccc);*/
/*}*/
</style>
</body>
</html>
The resulting page looks like this:
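The page sends the image and the chosen model name as multipart/form-data under the field names "file" and "model", which main.py reads via request.files["file"] and request.form.get('model'). To call the same endpoint from a script instead of the browser, a standard-library sketch that builds an equivalent request is shown below; the helper name, file name and URL (Flask's default development address) are assumptions for illustration.

```python
import urllib.request
import uuid


def build_predict_request(url, image_bytes, filename, model_name):
    """Build a multipart/form-data POST carrying the same fields predict.html sends."""
    boundary = uuid.uuid4().hex
    sep = b"--" + boundary.encode("ascii")
    crlf = b"\r\n"
    body = b"".join([
        # text field "model", like form.append("model", model)
        sep + crlf,
        b'Content-Disposition: form-data; name="model"' + crlf + crlf,
        model_name.encode("utf-8") + crlf,
        # file field "file", like form.append("file", fileobj)
        sep + crlf,
        ('Content-Disposition: form-data; name="file"; filename="%s"' % filename).encode("utf-8") + crlf,
        b"Content-Type: application/octet-stream" + crlf + crlf,
        image_bytes + crlf,
        # closing boundary
        sep + b"--" + crlf,
    ])
    headers = {"Content-Type": "multipart/form-data; boundary=" + boundary}
    return urllib.request.Request(url, data=body, headers=headers, method="POST")
```

With main.py running, passing the result of build_predict_request("http://127.0.0.1:5000/predict", image_bytes, "test.jpg", "MobileNetV2") to urllib.request.urlopen should return the same JSON the page receives, with one formatted line per class under "result".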
4. Classification model, weights, and classes
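Copy the model code, the weights trained in the previous article, and the class-index file into the models folder:
- class_indices.json
- mobilenet_v2.py
- MobileNetV2.pth
main.py builds the weights path as ./models/{model name}.pth, so the .pth file name must match the value chosen in the page's drop-down (MobileNetV2 here). Only MobileNetV2 is provided; the other drop-down entries would also need their own model definitions and weights.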
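main.py looks up class names with class_indict[str(index)], so class_indices.json must map string indices "0", "1", … to class names in the same order the network was trained with, and have as many entries as the network has outputs. If the training script used torchvision.datasets.ImageFolder, its class_to_idx attribute holds the opposite mapping (name to index); the sketch below inverts it. The function name and class names are hypothetical.

```python
import json


def build_class_indices(class_to_idx):
    """Invert a {class_name: index} mapping into the {"index": class_name} form main.py reads."""
    # JSON object keys are strings, so the indices are stored as "0", "1", ...
    return {str(idx): name for name, idx in class_to_idx.items()}


# Hypothetical class names; use the ones from your own training set.
example = build_class_indices({"class_a": 0, "class_b": 1, "class_c": 2})
print(json.dumps(example, indent=4))

# To write the file:
# with open("models/class_indices.json", "w", encoding="utf-8") as f:
#     json.dump(example, f, indent=4)
```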
5. Building and running the main program
5.1 Main program
main.py
import os
import io
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.mobilenet_v2 import MobileNetV2
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # allow cross-origin requests
app.secret_key = 'super_secret_key'  # signs the session cookie; the session is not used in this example


@app.route("/", methods=["GET", "POST"])
def root():
    return render_template("login.html")


@app.route("/login", methods=["GET"])
def login():
    print(request.args)
    username = request.args.get('username')
    password = request.args.get('password')
    # Simple hard-coded username/password check; change the credentials here
    if username == "admin" and password == "123456":
        return render_template("predict.html")
    else:
        return "Incorrect username or password! Please try again."


# select device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class_json_path = "models/class_indices.json"
# load class info: maps the class index (as a string) to its name
assert os.path.exists(class_json_path), "class json path does not exist..."
with open(class_json_path, "r", encoding="utf-8") as json_file:
    class_indict = json.load(json_file)

# Architectures the server can build; only MobileNetV2 ships with this project
MODEL_CLASSES = {"MobileNetV2": MobileNetV2}
# Loaded models, cached so the weights are read once rather than on every request
_model_cache = {}


def load_model(model_name):
    if model_name in _model_cache:
        return _model_cache[model_name]
    if model_name not in MODEL_CLASSES:
        raise ValueError("unsupported model: {}".format(model_name))
    weights_path = "./models/{}.pth".format(model_name)
    assert os.path.exists(weights_path), "weights path does not exist..."
    print(device)
    # create the model with one output per class in class_indices.json
    model = MODEL_CLASSES[model_name](num_classes=len(class_indict)).to(device)
    # load model weights
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    _model_cache[model_name] = model
    return model


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    # convert grayscale, RGBA (e.g. PNG with transparency), etc. to 3-channel RGB
    image = image.convert("RGB")
    return my_transforms(image).unsqueeze(0).to(device)


def get_prediction(image_bytes, model_name):
    try:
        model = load_model(model_name)
        tensor = transform_image(image_bytes=image_bytes)
        outputs = torch.softmax(model(tensor).squeeze(), dim=0)
        prediction = outputs.detach().cpu().numpy()
        template = "class:{:<15} probability:{:.3f}"
        index_pre = [(class_indict[str(index)], float(p)) for index, p in enumerate(prediction)]
        # sort by probability, highest first
        index_pre.sort(key=lambda x: x[1], reverse=True)
        text = [template.format(k, v) for k, v in index_pre]
        return_info = {"result": text}
    except Exception as e:
        # errors are returned as a one-line result so the page can display them
        return_info = {"result": [str(e)]}
    return return_info


@app.route("/predict", methods=["POST"])
@torch.no_grad()
def predict():
    image = request.files["file"]
    model_name = request.form.get('model')
    img_bytes = image.read()
    info = get_prediction(image_bytes=img_bytes, model_name=model_name)
    return jsonify(info)


if __name__ == '__main__':
    app.run(debug=True)
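To make the numbers in transform_image and get_prediction concrete without PyTorch, the sketch below re-implements two pieces in plain Python: the output size of Resize(255) followed by CenterCrop(224), assuming torchvision's behaviour for an integer size (the shorter edge becomes 255, the longer edge keeps the aspect ratio and is truncated to an int, and the crop window is centred with rounding), and the softmax, sort and format step that produces the strings the page shows. Function names are hypothetical and the logits are made-up inputs, not model output.

```python
import math

TEMPLATE = "class:{:<15} probability:{:.3f}"  # same format string as main.py


def resize_then_center_crop(width, height, resize=255, crop=224):
    """Return (resized_w, resized_h, left, top) for Resize(resize) + CenterCrop(crop)."""
    # Scale the shorter edge to `resize`, keep the aspect ratio, truncate to int
    if width <= height:
        new_w, new_h = resize, int(resize * height / width)
    else:
        new_w, new_h = int(resize * width / height), resize
    # Top-left corner of the centred crop window
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return new_w, new_h, left, top


def rank_predictions(logits, class_indict):
    """Softmax the logits, sort by probability and format them like get_prediction."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    pairs = [(class_indict[str(i)], e / total) for i, e in enumerate(exps)]
    pairs.sort(key=lambda x: x[1], reverse=True)
    return [TEMPLATE.format(name, p) for name, p in pairs]


print(resize_then_center_crop(640, 480))  # (340, 255, 58, 16)
lines = rank_predictions([0.1, 2.0, 1.0], {"0": "class_a", "1": "class_b", "2": "class_c"})
for line in lines:
    print(line)
```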
5.2 Run
python main.py
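Open http://127.0.0.1:5000/ (Flask's default development address) in a browser and log in with:
username: admin
password: 123456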
Note: the credentials can be changed in main.py; the check is marked with a comment in the code.
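main.py compares the submitted values against plaintext strings with ==. If you change the credentials, one option, sketched below with only the standard library, is to store a salted PBKDF2 hash instead of the password and compare digests in constant time with hmac.compare_digest. This is a swapped-in technique, not part of the original project; the helper names and the iteration count are illustrative assumptions.

```python
import hashlib
import hmac
import os

ITERATIONS = 200_000  # illustrative work factor; tune for your hardware


def hash_password(password, salt=None):
    """Return (salt, digest) for a password using PBKDF2-HMAC-SHA256."""
    salt = os.urandom(16) if salt is None else salt
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, ITERATIONS)
    return salt, digest


def check_password(password, salt, expected_digest):
    """Recompute the hash with the stored salt and compare it in constant time."""
    _, digest = hash_password(password, salt)
    return hmac.compare_digest(digest, expected_digest)


# Demo only: in practice compute these offline and keep just the salt and digest,
# not the plaintext password, in the source or a config file.
STORED_SALT, STORED_DIGEST = hash_password("123456")
```

Inside login(), the check would then read username == "admin" and check_password(password or "", STORED_SALT, STORED_DIGEST); the `or ""` guards against a missing field, because request.args.get returns None when the parameter is absent.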
Select an image.
Click Predict.
The prediction results are shown on the page.
6. Project package
PyWeb
The complete project has been packaged. To get it, follow the author and reply "PyWeb".