1. 前言
本文以一名前端新手的视角,记录了作者将自己的神经网络模型部署为后端、搭建网站实现应用的实战经历。目前已实现的网页应用有:
- AI 语音服务主页
- AI 语音识别
- AI 语音合成
- AI CP号码生成器
欢迎大家试用体验。本文将以博客《基于GAN的序列号码预测》中训练的 PyTorch 模型为例,介绍后端和前端的搭建并构建网站,以下是最终成果展示。
网址:http://www.funsound.cn:5002
2. 相关内容
涉及的知识点和工具如下,希望读者对其有一定的了解:
- 腾讯云服务器
- Html + JavaScript 进行UI设计
- pytorch 模型,onnx 模型导出
- python flask 后端
- 多进程服务实现并发访问
3. 后端工作
3.1 pytorch 模型转 onnx 模型
ONNX 是一种通用的神经网络模型格式,将模型导出为 ONNX 格式后,在服务器 CPU 上推理通常会更快。
# Instantiate the generator network.
generator = Generator(input_dim, output_dim)
# Load the trained weights. map_location='cpu' lets a checkpoint that was saved
# from GPU tensors load on the CPU-only inference server; without it torch.load
# tries to restore tensors onto their original device and fails when CUDA is
# unavailable.
generator.load_state_dict(torch.load('models/generator_model.pth', map_location='cpu'))
generator.eval()  # inference mode for layers such as dropout / batch-norm
# Export the model to ONNX.
# NOTE(review): export_onnx is not a standard torch.nn.Module method; presumably
# a helper on Generator that wraps torch.onnx.export with a dummy input of shape
# (batch_size, input_dim) — confirm. If that batch dimension is exported as a
# fixed size, inference must use the same batch_size.
generator.export_onnx('models/generator_model.onnx', (batch_size, input_dim))
加载onnx模型进行推理
# Load the ONNX model. Naming the execution provider explicitly keeps inference
# on the CPU and avoids the error onnxruntime (>= 1.9) raises when a
# GPU-enabled build is installed and no providers are given.
ort_session = ort.InferenceSession('models/generator_model.onnx',
                                   providers=['CPUExecutionProvider'])
# The exported graph has one input (noise) and one output (generated numbers).
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# Gaussian noise as the generator input; cast to float32 to match the exported
# graph (assumes the model was exported with torch's default float32 — confirm).
input_noise = np.random.randn(batch_size, input_dim).astype(np.float32)
# run() returns a list with one array per requested output.
generated_numbers = ort_session.run([output_name], {input_name: input_noise})[0]
将基于 ONNX 推理的 CP 号码生成算法封装为【generator.LOTTO_GENERATOR】。
3.2 多进程onnx服务
网站访问通常是多路并发的场景:面对众多用户的请求,后端先将请求送入任务队列等待处理,再采用多进程进行调度。
if __name__ == "__main__":
    from generator import LOTTO_GENERATOR  # our GAN-based number generator
    # Number of workers; one LOTTO_GENERATOR backend is built per worker.
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)
    # Fetch and print the state of every worker.
    res = get_workers_state(workers)
    print(res)
    # Submit 100 demo tasks.
    worker_dir = "demo"
    for _ in range(100):
        task_id = generate_random_string(length=11)  # random 11-char string as the task id
        task_dir = f"{worker_dir}/{task_id}"  # per-task directory
        task_inp = generate_random_number_string(length=8)  # random 8-digit string as task input
        task_prgs = f'{task_dir}/progress.txt'  # progress file path
        task_rst = f'{task_dir}/result.txt'  # result file path
        # os.makedirs instead of os.system('mkdir -p ...'): no shell is spawned,
        # it also works on Windows, and failures raise instead of being ignored.
        os.makedirs(task_dir, exist_ok=True)
        params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
        submit_task(workers=workers, params=params)  # enqueue on the least-busy worker
        time.sleep(0.01)  # wait 10 ms before submitting the next task
注意:代码中的多进程服务以异步方式处理用户请求。用户提交任务后即获得 task_id,主进程不会阻塞;用户根据 task_id 追踪自己任务的进度(task_prgs)和结果(task_rst)。
调度方式由子进程的忙碌程度决定:选取当前负载(队列中待处理的任务数 + 正在处理的任务)最少的子进程来处理用户请求。
def submit_task(workers, params: dict):
    """Enqueue *params* on the worker with the smallest current load.

    Load is ``worker.queue.qsize() + worker.working.value``; the two values are
    read without a lock, so the choice is a best-effort estimate. Ties go to
    the earliest worker in *workers*.

    Args:
        workers: sequence of worker handles exposing ``queue`` (with
            ``qsize()``/``put()``), ``working.value`` and ``wid``.
        params: task description put on the chosen worker's queue.

    Returns:
        The worker the task was assigned to.

    Raises:
        ValueError: if *workers* is empty.

    NOTE(review): if ``worker.queue`` is a ``multiprocessing.Queue``, ``qsize()``
    raises NotImplementedError on macOS.
    """
    if not workers:
        raise ValueError("submit_task: no workers available")
    # Pick the worker with the fewest queued + in-progress tasks.
    min_task_worker = min(workers, key=lambda worker: worker.queue.qsize() + worker.working.value)
    min_task_worker.queue.put(params)
    print(f'assign the task to worker-{min_task_worker.wid}')
    return min_task_worker
3.3 基于Flask搭建http访问接口
后端代码如下。假设服务器 IP 为 100.100.100.123、端口使用 5002,则构建了以下 HTTP 访问接口:
HTTP 地址的一般格式:【http://IP地址:端口/路由】
- http://100.100.100.123:5002/ 主页
- http://100.100.100.123:5002/lotto 提交任务(POST)【输入:用户幸运数字,输出:task_id】
- http://100.100.100.123:5002/get_worker_state 子进程负载状态(GET)【输入:无,输出:各子进程的负载状态】
- http://100.100.100.123:5002/get_task_prgs 任务完成进度(POST)【输入:task_id,输出:任务进度】
- http://100.100.100.123:5002/get_task_rst 任务结果(POST)【输入:task_id,输出:任务结果】
from flask import Flask, jsonify,render_template,request
from generator import LOTTO_GENERATOR
from workers import *
import datetime
import json
def get_now_time():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    return f"{datetime.datetime.now():%Y-%m-%d %H:%M:%S}"
def task_log(text, log_file="TASK.LOG"):
    """Append *text* followed by a newline to *log_file*.

    Args:
        text: line to record (any object; it is written via ``print``).
        log_file: path of the log file, created if missing.
    """
    # Explicit UTF-8: the default text encoding depends on the server locale,
    # so non-ASCII log text could otherwise raise UnicodeEncodeError.
    with open(log_file, 'a', encoding='utf-8') as f:
        print(text, file=f)
# Flask application object; the routes below register on it.
app = Flask(__name__)
# Root directory for per-task files, laid out as USER_DIR/<client ip>/<task_id>/.
USER_DIR = "user_data"
# task_id -> {'task_dir', 'task_prgs', 'task_rst'} paths; filled by /lotto and
# read by /get_task_prgs and /get_task_rst.
# NOTE(review): lives only in this process's memory — entries are never removed
# (it grows for the server's lifetime) and are lost on restart.
TASK_MAP = {}
"""
主页
"""
@app.route('/')
def index():
    """Serve the home page by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/lotto', methods=['POST'])
def predict():
    """Create a number-generation task for the caller and queue it on a worker.

    Request body (JSON): ``{"luck_num": <str>}``, an 8-digit string or an empty
    string according to the client code; the value is forwarded to the worker
    unchanged.

    Returns:
        The new task id as plain text (client IP, ``_``, then
        ``generate_random_string(20)``), or HTTP 400 when the body is not a
        JSON object containing ``luck_num``.
    """
    ip = request.remote_addr
    # silent=True: a missing/malformed JSON body yields None instead of an
    # exception, so a bad request gets a 400 here rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'luck_num' not in data:
        return "bad request: expected a JSON object with field 'luck_num'", 400
    task_inp = data['luck_num']

    task_id = ip + "_" + generate_random_string(20)
    user_id = ip
    task_dir = "%s/%s/%s" % (USER_DIR, user_id, task_id)
    task_prgs = f'{task_dir}/progress.txt'  # progress file path
    task_rst = f'{task_dir}/result.txt'  # result file path

    task_log(f"TIME:{get_now_time()}")
    task_log(f"TASK_ID:{task_id}")
    task_log("")

    # exist_ok=True replaces the racy "exists? then create" check.
    os.makedirs(task_dir, exist_ok=True)
    # Seed the progress file so /get_task_prgs can answer before a worker
    # picks the task up.
    with open(task_prgs, 'wt') as f:
        json.dump([0.0, 'start'], f, indent=4)
    # NOTE(review): TASK_MAP entries are never evicted (see its definition).
    TASK_MAP[task_id] = {'task_dir': task_dir,
                         'task_prgs': task_prgs,
                         'task_rst': task_rst, }

    params = {
        'task_id': task_id,
        'task_inp': task_inp,
        'task_prgs': task_prgs,
        'task_rst': task_rst
    }
    # `workers` is the module global assigned in the __main__ block; it does
    # not exist if this module is imported by another server (e.g. a WSGI host).
    submit_task(workers=workers, params=params)
    return task_id
"""
获得引擎状态
"""
@app.route('/get_worker_state', methods=['GET'])
def get_worker_state():
    """Report each worker's load as JSON ``{wid: load}``. Takes no input.

    Load is the worker's queued task count plus ``worker.working.value``
    (presumably 1 while a task is being processed, else 0 — confirm in workers.py).
    """
    # `workers` is the module global assigned in the __main__ block.
    return {worker.wid: worker.queue.qsize() + worker.working.value for worker in workers}
"""
获得任务进度
"""
@app.route('/get_task_prgs', methods=['POST'])
def get_task_prgs():
    """Return a task's progress record as JSON ``[progress, status]``.

    Expects ``{"task_id": <str>}``. A missing, non-string or unknown id yields
    ``[-1, 'no such task id']``; the clients stop polling on progress -1.
    """
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None
    # The isinstance guard also keeps an unhashable id (e.g. a JSON list) from
    # raising TypeError in the membership test below.
    if not isinstance(task_id, str) or task_id not in TASK_MAP:
        return jsonify([-1, 'no such task id'])
    task_prgs = TASK_MAP[task_id]['task_prgs']
    try:
        with open(task_prgs, 'rt') as f:
            content = json.load(f)
    except (OSError, ValueError):
        # Presumably the worker rewrites this file as it progresses, so a read
        # can catch it empty or half-written. Report "still running" instead of
        # an HTTP 500, which the clients in this article do not recover from.
        content = [0.0, 'pending']
    # jsonify handles a top-level JSON array; returning a bare list from a view
    # is only supported on Flask >= 2.2.
    return jsonify(content)
"""
获得任务结果
"""
@app.route('/get_task_rst', methods=['POST'])
def get_task_rst():
    """Return a task's stored result as JSON.

    Expects ``{"task_id": <str>}``. Replies ``{}`` when the id is missing,
    non-string or unknown, or when the result file cannot be read or parsed
    yet (e.g. the client asked before the task finished).
    """
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None
    # The isinstance guard also keeps an unhashable id from raising TypeError.
    if not isinstance(task_id, str) or task_id not in TASK_MAP:
        return jsonify({})
    task_rst = TASK_MAP[task_id]['task_rst']
    try:
        with open(task_rst, 'rt') as f:
            content = json.load(f)
    except (OSError, ValueError):
        # Result file not written yet, or caught mid-write.
        return jsonify({})
    # jsonify works for both JSON objects and top-level arrays on any Flask
    # version; a bare list return needs Flask >= 2.2.
    return jsonify(content)
if __name__ == '__main__':
    # Number of workers; a LOTTO_GENERATOR backend is built for each one.
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    # Assigned at module level so the route handlers can reach it.
    workers = init_workers(backends=backends, nj=nj)
    # Listen on every interface, port 5002 (Flask's built-in server).
    app.run(port=5002, host='0.0.0.0')
这样后端就搭建好了,此时有 4 个 ONNX 模型实例在后台监听并处理任务。
3.4 python客户端测试
import requests
import time
import json
# Server address: replace with your own server's public IP (or domain) and port.
server_url = 'http://100.100.100.123:5002'
# NOTE: not used by the requests below — requests sets
# 'Content-Type: application/json' automatically when a body is passed via json=.
headers = {'Content-Type': 'application/json'}
# Check the server's worker status.
def check_worker_status(timeout=10):
    """Ask the server for per-worker load and report whether any worker is idle.

    Args:
        timeout: seconds to wait for the HTTP response (new, optional); without
            a timeout an unreachable server would block forever.

    Returns:
        True if at least one worker reports a load of 0. False otherwise,
        including when the request fails or the reply is not valid JSON.
    """
    try:
        response = requests.get(f'{server_url}/get_worker_state', timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to get worker status: {err}")
        return False
    if response.status_code != 200:
        print("Failed to get worker status.")
        return False
    try:
        worker_status = response.json()
    except ValueError:
        print("Failed to get worker status.")
        return False
    idle_workers = [wid for wid, status in worker_status.items() if status == 0]
    if idle_workers:
        print("Idle workers available:", idle_workers)
        return True
    print("No idle workers available.")
    return False
# Submit a task.
def submit_task(json_data, timeout=10):
    """Submit a generation task to ``/lotto`` if some worker is idle.

    Args:
        json_data: request body, e.g. ``{'luck_num': ''}``.
        timeout: seconds to wait for the HTTP response (new, optional).

    Returns:
        The task id (the server's plain-text reply) on success, otherwise None.
    """
    # NOTE: the server queues tasks even when every worker is busy; refusing
    # here is a client-side policy.
    if not check_worker_status():
        print("No idle workers available. Task submission failed.")
        return None
    try:
        response = requests.post(f'{server_url}/lotto', json=json_data, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to submit task: {err}")
        return None
    if response.status_code != 200:
        print("Failed to submit task.")
        return None
    task_id = response.text
    print(f"Task submitted successfully. Task ID: {task_id}")
    return task_id
# Poll a task's progress until it finishes or fails.
def poll_task_progress(task_id, interval=3, max_polls=None, timeout=10):
    """Poll ``/get_task_prgs`` until the task reports completion or failure.

    The server answers ``[progress, status_text]``: ``progress == 1`` means
    done and ``progress == -1`` means failed or unknown task id.

    Args:
        task_id: id returned by ``submit_task``.
        interval: seconds to sleep between polls (default 3, as before).
        max_polls: give up after this many polls (>= 1) without a final
            state; None (default) polls indefinitely.
        timeout: seconds to wait for each HTTP response.

    Returns:
        True when the task completed; False on task failure, HTTP/network
        error, malformed reply, or when ``max_polls`` is exhausted.
    """
    json_data = {'task_id': task_id}  # identical payload on every poll
    polls = 0
    while True:
        try:
            response = requests.post(f'{server_url}/get_task_prgs', json=json_data, timeout=timeout)
        except requests.RequestException as err:
            print(f"Failed to get task progress: {err}")
            return False
        if response.status_code != 200:
            print("Failed to get task progress.")
            return False
        try:
            progress, text = response.json()
        except (ValueError, TypeError):  # invalid JSON or not a 2-item pair
            print("Failed to get task progress.")
            return False
        print(f"Progress: {progress}, Status: {text}")
        if progress == 1:
            print("Task completed successfully.")
            return True
        if progress == -1:
            print(f"Task failed: {text}")
            return False
        polls += 1
        if max_polls is not None and polls >= max_polls:
            print(f"Gave up after {polls} polls.")
            return False
        time.sleep(interval)
# Fetch a finished task's result.
def get_task_result(task_id, timeout=10):
    """Fetch the result of *task_id* from ``/get_task_rst``.

    Args:
        task_id: id returned by ``submit_task``.
        timeout: seconds to wait for the HTTP response (new, optional).

    Returns:
        The decoded JSON result (the server replies ``{}`` for an unknown id),
        or None when the request fails or the reply is not valid JSON.
    """
    try:
        response = requests.post(f'{server_url}/get_task_rst', json={'task_id': task_id}, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to get task result: {err}")
        return None
    if response.status_code != 200:
        print("Failed to get task result.")
        return None
    try:
        result = response.json()
    except ValueError:
        print("Failed to get task result.")
        return None
    print("Task result:", result)
    return result
# Entry point.
def main():
    """Submit one number-generation task, wait for it, then fetch its result.

    Returns:
        The task result, or None if submission, polling or fetching failed.
    """
    json_data = {'luck_num': ""}  # empty string: no lucky number
    # json_data = {'luck_num': "12345678"}  # alternatively, an 8-digit lucky number
    # Submit the generation task (not a TTS task, despite the old comment).
    task_id = submit_task(json_data)
    if not task_id:
        return None
    # Wait until the task finishes or fails.
    if not poll_task_progress(task_id):
        return None
    return get_task_result(task_id)


if __name__ == "__main__":
    main()
访问成功
4. 前端工作
4.1 JavaScript 访问 http 函数
JavaScript 调用 HTTP 接口的代码如下:
<script>
/* Submit a task: lock the button, POST the lucky number, then start polling. */
function submitTask() {
    var btn = document.querySelector("button");
    btn.disabled = true;
    btn.innerText = "正在生成...";

    // Send the user's lucky number only when the checkbox is ticked, else "".
    var wantsLucky = document.getElementById("use_lucky_number").checked;
    var luckField = document.getElementById("luck_input");
    var payload = { luck_num: wantsLucky ? luckField.value : "" };

    var req = new XMLHttpRequest();
    req.open("POST", "/lotto", true);
    req.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    req.onreadystatechange = function () {
        if (req.readyState != 4) {
            return;
        }
        if (req.status == 200) {
            // The server replies with the bare task id as plain text.
            checkProgress(req.responseText);
            return;
        }
        btn.disabled = false;
        btn.innerText = "生成";
        alert("任务提交失败,请重试。");
    };
    req.send(JSON.stringify(payload));
}
/* Poll the task's progress every 3 s until it completes (1) or fails (-1). */
function checkProgress(taskId) {
    // Re-enable the generate button and tell the user why generation stopped.
    function fail(message) {
        var button = document.querySelector("button");
        button.disabled = false;
        button.innerText = "生成";
        alert(message);
    }

    var xhr = new XMLHttpRequest();
    xhr.open("POST", "/get_task_prgs", true);
    xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    xhr.onreadystatechange = function () {
        if (xhr.readyState != 4) {
            return;
        }
        // Previously a non-200 reply (or a network error, status 0) silently
        // stopped polling and left the button disabled forever.
        if (xhr.status != 200) {
            fail("查询任务进度失败,请重试。");
            return;
        }
        var response;
        try {
            response = JSON.parse(xhr.responseText);
        } catch (e) {
            fail("查询任务进度失败,请重试。");
            return;
        }
        // The server replies [progress, status].
        var progress = response[0];
        var status = response[1];
        if (progress == 1) {
            getResult(taskId);
        } else if (progress == -1) {
            fail("任务失败: " + status);
        } else {
            setTimeout(function () { checkProgress(taskId); }, 3000);
        }
    };
    xhr.send(JSON.stringify({task_id: taskId}));
}
/* Fetch the finished task's result, render it, and re-enable the button. */
function getResult(taskId) {
    var xhr = new XMLHttpRequest();
    xhr.open("POST", "/get_task_rst", true);
    xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    xhr.onreadystatechange = function () {
        if (xhr.readyState != 4) {
            return;
        }
        var button = document.querySelector("button");
        try {
            if (xhr.status != 200) {
                throw new Error("HTTP " + xhr.status);
            }
            displayResult(JSON.parse(xhr.responseText));
        } catch (e) {
            // Previously an HTTP error or a render/parse exception left the
            // button stuck on "正在生成...".
            alert("获取任务结果失败,请重试。");
        } finally {
            button.disabled = false;
            button.innerText = "生成";
        }
    };
    xhr.send(JSON.stringify({task_id: taskId}));
}
/* Render each generated set as a row of balls: front numbers, then back numbers.
 * front_numbers[i] and back_numbers[i] are the two parts of set i. */
function displayResult(response) {
    // Missing fields (e.g. the server's {} reply for an unknown task) render
    // as nothing instead of throwing on `undefined.length`.
    var frontNumbers = (response && response.front_numbers) || [];
    var backNumbers = (response && response.back_numbers) || [];
    var resultContainer = document.getElementById("result");
    resultContainer.innerHTML = ""; // clear any previous result

    // Append one ball per number; innerText (not innerHTML) keeps values inert.
    function appendBalls(lotterySet, numbers, ballClass) {
        (numbers || []).forEach(function (number) {
            var numberBall = document.createElement("div");
            numberBall.className = "number-ball " + ballClass;
            numberBall.innerText = number;
            lotterySet.appendChild(numberBall);
        });
    }

    for (var i = 0; i < frontNumbers.length; i++) {
        var lotterySet = document.createElement("div");
        lotterySet.className = "lottery-set";
        appendBalls(lotterySet, frontNumbers[i], "front-ball");
        // Guarded: backNumbers may be shorter than frontNumbers.
        appendBalls(lotterySet, backNumbers[i], "back-ball");
        resultContainer.appendChild(lotterySet);
    }
}
</script>
4.2 制作网页index.html
Flask 提供了网页模板渲染功能(render_template),借助它可以这样设置主页路由:
@app.route('/')
def index():
    """Serve the home page by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
把上述 JS 脚本放入 index.html,并将该文件放在 Flask 项目的 templates/ 目录下(render_template 默认从该目录加载模板),就可以访问后端服务了。具体 HTML 界面的代码量较大,这里不予展示,感兴趣的同学可以参照上述 Python 客户端的访问逻辑,使用 GPT 为你编写 index.html。手机端访问效果如下:
5. 最后
上述是个人搭建自己网站部署AI应用的简单过程,完整源码后期整理上传,欢迎大家留言关注~