这里我们需要在火山云语音控制台开通大模型的流式语音对话、获取豆包模型的apiKey,开通语音合成项目。
这里使用的豆包模型是Doubao-lite,延迟会更低一些
配置说明
这里一共有四个文件,分别是主要的fastAPI服务入口文件,以及LLM、STT、TTS三个模块文件
TTS中需要配置
appid = "123" #填写控制台的APPID
token = "XXXX" #填写控制台上的Access Token
cluster = "XXXXX" #填写语音生成的组id
voice_type = "BV034_streaming" #这里是生成声音的类型选择
host = "openspeech.bytedance.com" #无需更改
api_url = f"wss://{host}/api/v1/tts/ws_binary" #无需更改
LLM中配置
# 初始化客户端,传入 API 密钥
self.client = Ark(api_key="XXXX")
在STT的146行中配置
header = {
"X-Api-Resource-Id": "volc.bigasr.sauc.duration",
"X-Api-Access-Key": "XXXXX", #和TTS配置内容相同
"X-Api-App-Key": "123", #和TTS配置内容相同
"X-Api-Request-Id": reqid
}
还有前端HTML的配置中记得根据自己服务的所在ip更改配置
前端测试html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>WebSocket 音频传输测试</title>
<style>
body {
font-family: Arial, sans-serif;
}
#status {
margin-bottom: 10px;
}
#messages {
border: 1px solid #ccc;
height: 200px;
overflow-y: scroll;
padding: 10px;
}
#controls {
margin-top: 10px;
}
#controls button {
margin-right: 5px;
}
#latency {
margin-top: 10px;
font-weight: bold;
}
</style>
</head>
<body>
<h1>WebSocket 音频传输测试</h1>
<!-- 显示当前连接状态 -->
<div id="status">状态:未连接</div>
<!-- 显示日志消息 -->
<div id="messages"></div>
<!-- 控制按钮 -->
<div id="controls">
<button id="startButton">开始录音并发送</button>
<button id="stopButton" disabled>停止录音</button>
</div>
<!-- 延迟显示区域 -->
<div id="latency"></div>
<script>
// WebSocket 服务器地址,请根据实际情况替换
const wsUrl = 'ws://127.0.0.1:8000/ws';
// 全局变量
let socket = null; // WebSocket 实例
const messagesDiv = document.getElementById('messages'); // 日志消息显示区域
const statusDiv = document.getElementById('status'); // 连接状态显示区域
const startButton = document.getElementById('startButton'); // 开始录音按钮
const stopButton = document.getElementById('stopButton'); // 停止录音按钮
let recordingAudioContext; // 音频录制上下文
let audioInput; // 音频输入节点
let processor; // 音频处理节点
// 播放相关变量
let playbackAudioContext;
let playbackQueue = [];
let playbackTime = 0;
let isPlaying = false;
// 延迟测量变量
let overSentTime = null; // 记录发送 'over' 的时间
let latencyMeasured = false; // 标记是否已经测量延迟
/**
 * Append one line of text to the on-page log and keep it scrolled to the end.
 * @param {string} message - Text to display.
 */
function logMessage(message) {
  const entry = Object.assign(document.createElement('p'), { textContent: message });
  messagesDiv.appendChild(entry);
  // Keep the newest entry visible.
  messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
/**
 * Create the AudioContext used to play audio chunks coming back from the server.
 */
function initializePlayback() {
  const AudioCtor = window.AudioContext || window.webkitAudioContext;
  playbackAudioContext = new AudioCtor();
  logMessage('Playback AudioContext 已创建');
}
/**
 * Decode a received audio chunk and enqueue it for gapless playback.
 * @param {ArrayBuffer} data - Raw audio bytes from the server.
 */
function appendToPlaybackQueue(data) {
  const onDecoded = (audioBuffer) => {
    playbackQueue.push(audioBuffer);
    schedulePlayback();
  };
  const onError = (error) => {
    logMessage('解码音频数据时出错:' + error);
  };
  playbackAudioContext.decodeAudioData(data, onDecoded, onError);
}
/**
 * Pop the next decoded buffer off the queue and schedule it on the playback
 * clock so consecutive chunks play back-to-back without gaps.
 * Re-entrancy is guarded by `isPlaying`; `onended` chains the next chunk.
 */
function schedulePlayback() {
  if (isPlaying) return;
  if (playbackQueue.length === 0) return;
  // Take the next decoded buffer.
  const buffer = playbackQueue.shift();
  // One-shot source node for this chunk.
  const source = playbackAudioContext.createBufferSource();
  source.buffer = buffer;
  source.connect(playbackAudioContext.destination);
  // If the schedule has fallen behind the context clock, catch up to "now".
  if (playbackTime < playbackAudioContext.currentTime) {
    playbackTime = playbackAudioContext.currentTime;
  }
  // Start this chunk at the scheduled time.
  source.start(playbackTime);
  logMessage(`Scheduled buffer to play at ${playbackTime.toFixed(2)}s`);
  // Advance the schedule by this chunk's duration.
  playbackTime += buffer.duration;
  // Busy until this chunk finishes.
  isPlaying = true;
  // When the chunk ends, clear the flag and schedule the next queued chunk.
  source.onended = () => {
    isPlaying = false;
    schedulePlayback();
  };
}
/**
 * Create the WebSocket connection (no-op if already open/connecting) and wire
 * up the open/message/error/close handlers for the audio round-trip.
 */
function createWebSocket() {
  if (socket !== null && (socket.readyState === WebSocket.OPEN || socket.readyState === WebSocket.CONNECTING)) {
    logMessage('WebSocket 已经连接或正在连接中');
    return;
  }
  socket = new WebSocket(wsUrl);
  // The server sends raw audio bytes; receive them as ArrayBuffer, not Blob.
  socket.binaryType = 'arraybuffer';
  socket.onopen = function () {
    statusDiv.textContent = '状态:已连接';
    logMessage('WebSocket 连接已打开');
    startButton.disabled = false; // allow recording once connected
  };
  socket.onmessage = function (event) {
    // Text frame 'over' marks the end of the server's audio stream.
    if (typeof event.data === 'string' && event.data === 'over') {
      logMessage('收到结束信号: over');
      return;
    }
    // Binary frames carry synthesized audio chunks.
    if (event.data instanceof ArrayBuffer) {
      logMessage('接收到音频数据');
      // Measure latency once: from sending 'over' to the first audio chunk.
      if (overSentTime !== null && !latencyMeasured) {
        let receiveTime = performance.now();
        let latency = receiveTime - overSentTime;
        logMessage(`延迟时间:${latency.toFixed(2)} 毫秒`);
        document.getElementById('latency').textContent = `延迟时间:${latency.toFixed(2)} 毫秒`;
        latencyMeasured = true;
      }
      appendToPlaybackQueue(event.data); // decode and enqueue for playback
    }
  };
  socket.onerror = function () {
    // BUG FIX: the 'error' event is a plain Event with no `.message`
    // property, so the old `error.message` always logged "undefined".
    statusDiv.textContent = '状态:连接错误';
    logMessage('WebSocket 发生错误');
  };
  socket.onclose = function (event) {
    // Distinguish a clean close (code 1000) from an abnormal one.
    if (event.code === 1000) {
      statusDiv.textContent = '状态:已断开连接';
      logMessage('WebSocket 正常关闭');
    } else {
      statusDiv.textContent = '状态:连接错误';
      logMessage(`WebSocket 关闭,代码:${event.code}, 原因:${event.reason}`);
    }
    startButton.disabled = false; // allow a fresh session
    stopButton.disabled = true;
  };
}
/**
 * Prepare the playback side (called when recording starts).
 * Thin wrapper kept so capture and playback setup stay separable.
 */
function initializeAudioPlayback() {
  initializePlayback();
}
/**
 * Start capturing microphone audio (16000 Hz, mono) and stream each captured
 * block to the backend as a small standalone WAV blob over the WebSocket.
 * NOTE(review): ScriptProcessorNode is deprecated in favor of AudioWorklet;
 * it still works in current browsers but may be removed eventually.
 */
function startRecording() {
  // Make sure the WebSocket exists before audio starts flowing.
  createWebSocket();
  // Ask for microphone access.
  navigator.mediaDevices.getUserMedia({ audio: true })
  .then(function (stream) {
    // Capture context pinned to a 16000 Hz sample rate.
    recordingAudioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
    // Source node fed by the microphone stream.
    audioInput = recordingAudioContext.createMediaStreamSource(stream);
    // 4096-sample processor, mono in / mono out.
    processor = recordingAudioContext.createScriptProcessor(4096, 1, 1);
    // Wire microphone -> processor -> destination (keeps the graph running).
    audioInput.connect(processor);
    processor.connect(recordingAudioContext.destination);
    // Fired for every 4096-sample block of captured audio.
    processor.onaudioprocess = function (e) {
      const audioData = e.inputBuffer.getChannelData(0); // mono channel
      const int16Data = floatTo16BitPCM(audioData); // float -> 16-bit PCM
      const wavBuffer = encodeWAV(int16Data, recordingAudioContext.sampleRate); // wrap in a WAV header
      // Ship the chunk if the socket is open.
      if (socket && socket.readyState === WebSocket.OPEN) {
        socket.send(wavBuffer);
      }
    };
    logMessage('开始录音并发送音频数据');
    startButton.disabled = true; // prevent double-start
    stopButton.disabled = false;
    initializeAudioPlayback(); // prepare the playback context as well
  })
  .catch(function (err) {
    // Microphone unavailable or permission denied.
    logMessage('无法访问麦克风:' + err.message);
  });
}
/**
 * Stop recording: release the capture nodes/context, flip the buttons back,
 * and send the agreed end-of-utterance signal ("over") to the backend.
 */
function stopRecording() {
  // Release the capture graph in reverse order of construction.
  if (processor != null) {
    processor.disconnect();
    processor = null;
  }
  if (audioInput != null) {
    audioInput.disconnect();
    audioInput = null;
  }
  if (recordingAudioContext != null) {
    recordingAudioContext.close();
    recordingAudioContext = null;
  }
  logMessage('停止录音');
  startButton.disabled = false; // allow a new recording
  stopButton.disabled = true;
  // Tell the backend the utterance is finished and start the latency clock.
  if (socket && socket.readyState === WebSocket.OPEN) {
    socket.send("over"); // end-of-input marker agreed with the backend
    overSentTime = performance.now();
    latencyMeasured = false;
    logMessage('发送结束信号 "over"');
  }
}
/**
 * Convert Float32 PCM samples in [-1, 1] to signed 16-bit PCM.
 * Out-of-range values are clamped before scaling.
 * @param {Float32Array} float32Array - Input samples.
 * @returns {Int16Array} 16-bit PCM samples.
 */
function floatTo16BitPCM(float32Array) {
  const int16Array = new Int16Array(float32Array.length);
  float32Array.forEach((sample, i) => {
    const clamped = Math.min(1, Math.max(-1, sample));
    // Negative full scale is -32768 (0x8000); positive full scale is 32767 (0x7FFF).
    int16Array[i] = clamped < 0 ? clamped * 0x8000 : clamped * 0x7FFF;
  });
  return int16Array;
}
/**
 * Wrap 16-bit PCM samples in a minimal 44-byte mono WAV header.
 * @param {Int16Array} samples - 16-bit PCM samples.
 * @param {number} sampleRate - Sample rate in Hz.
 * @returns {ArrayBuffer} Complete WAV file bytes.
 */
function encodeWAV(samples, sampleRate) {
  const dataByteLength = samples.length * 2;
  const buffer = new ArrayBuffer(44 + dataByteLength); // 44-byte header + PCM data
  const view = new DataView(buffer);
  // --- RIFF chunk descriptor ---
  writeString(view, 0, 'RIFF');
  view.setUint32(4, 36 + dataByteLength, true); // file size minus 8 bytes
  writeString(view, 8, 'WAVE');
  // --- fmt sub-chunk (PCM, mono, 16-bit) ---
  writeString(view, 12, 'fmt ');
  view.setUint32(16, 16, true); // fmt chunk size
  view.setUint16(20, 1, true); // audio format: 1 = PCM
  view.setUint16(22, 1, true); // channels: mono
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * 2, true); // byte rate = rate * channels * bytes/sample
  view.setUint16(32, 2, true); // block align = channels * bytes/sample
  view.setUint16(34, 16, true); // bits per sample
  // --- data sub-chunk ---
  writeString(view, 36, 'data');
  view.setUint32(40, dataByteLength, true);
  // Little-endian PCM payload.
  samples.forEach((sample, i) => {
    view.setInt16(44 + i * 2, sample, true);
  });
  return buffer;
}
/**
 * Write an ASCII string byte-by-byte into a DataView.
 * @param {DataView} view - Destination view.
 * @param {number} offset - Byte offset to start writing at.
 * @param {string} string - Characters to write (one byte per UTF-16 code unit).
 */
function writeString(view, offset, string) {
  // split('') iterates code units, matching charCodeAt's indexing.
  string.split('').forEach((ch, i) => {
    view.setUint8(offset + i, ch.charCodeAt(0));
  });
}
// --- Event wiring ---
// Start button: begin capturing microphone audio and streaming it out.
startButton.addEventListener('click', function () {
  startRecording();
});
// Stop button: stop capture and send the end-of-utterance signal.
stopButton.addEventListener('click', function () {
  stopRecording();
});
// The WebSocket is no longer opened on page load; it is created lazily when
// the user clicks the start button.
window.onload = function () {
  logMessage('请点击 "开始录音并发送" 按钮以开始录音');
};
</script>
</body>
</html>
后端fastAPI的服务入口
import asyncio
import re
import time
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from STT import generate_Ws, segment_data_processor
from LLM import LLMDobaoClient
from TTS import long_sentence,create_tts_ws
router = APIRouter()
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Full-duplex voice-chat endpoint.

    Protocol with the frontend:
      * binary frames: WAV-wrapped PCM chunks, forwarded to the streaming ASR;
      * text frame ("over"): end of utterance — run the LLM on the transcript
        and stream TTS audio back as binary frames, then send "over".
    """
    await websocket.accept()
    logger.info("WebSocket connection accepted")
    audio_result = ""
    audio_ws = None  # streaming-ASR connection, created lazily on first audio
    tts_ws = None    # TTS connection, created lazily alongside audio_ws
    seq = 1  # ASR frame sequence number; lives outside the loop on purpose
    llm_client = LLMDobaoClient()
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    try:
        while True:
            message = await websocket.receive()
            if 'bytes' in message:
                data = message['bytes']
                # Open both upstream connections on the first audio frame.
                if audio_ws is None:
                    audio_ws = await generate_Ws()
                    tts_ws = await create_tts_ws()
                if data is not None:
                    audio_result = await segment_data_processor(audio_ws, data, seq)
                    if audio_result is not None:
                        print("识别结果:", audio_result)
                    seq += 1
            elif 'text' in message:
                # End of utterance: feed the transcript to the LLM.
                llm_client.add_user_message(audio_result)
                audio_result = ""  # clear the transcript for the next turn
                # Audio saving is disabled; pass a placeholder instead of a
                # file handle (uncomment here and in TTS to save audio).
                # file_to_save = open("test.mp3", "ab")
                file_to_save = "123"
                result = ""
                seq = 1  # restart the ASR sequence for the next utterance
                for response in llm_client.stream_response():
                    result += response
                    if len(result) > 100:
                        # Split near 50 chars at the closest punctuation so TTS
                        # can start speaking before the LLM stream finishes.
                        cut_pos = find_nearest_punctuation(result, 50)
                        full_result = result[:cut_pos + 1]  # keep the punctuation
                        print(full_result)
                        await long_sentence(full_result, file_to_save, tts_ws, websocket)
                        result = result[cut_pos + 1:]  # unspoken remainder
                # Flush whatever text is left when the stream ends.
                if result:
                    print(result)
                    await long_sentence(result, file_to_save, tts_ws, websocket)
                await websocket.send_text("over")
                print("============ 结束 ==============")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
        # BUG FIX: the original closed audio_ws twice (the second call was
        # unconditional, raising AttributeError when audio_ws was still None)
        # and never closed tts_ws.
        if audio_ws is not None:
            await audio_ws.close()
        if tts_ws is not None:
            await tts_ws.close()
    except Exception as e:
        logger.error(f"WebSocket Error: {e}")
        # BUG FIX: guard tts_ws as well — it is None until audio arrives.
        if audio_ws is not None:
            await audio_ws.close()
        if tts_ws is not None:
            await tts_ws.close()
        await websocket.close()
def find_nearest_punctuation(text, max_length):
    """Return the index of the last Chinese punctuation mark at or before
    ``max_length``; fall back to ``max_length`` itself when none qualifies
    (including when ``text`` has no punctuation at all)."""
    candidate = max_length
    for match in re.finditer(r'[,。!?;]', text):
        pos = match.start()
        if pos > max_length:
            # Matches come in increasing order; nothing later can qualify.
            break
        candidate = pos
    return candidate
LLM文件代码
from volcenginesdkarkruntime import Ark
from app.config.settings import settings
class LLMDobaoClient:
    """Thin wrapper around the Volcengine Ark chat API that keeps a running
    message history and streams completions."""

    def __init__(self, api_key=""):
        """Create the Ark client.

        The key can now be injected instead of being hard-coded (the default
        preserves the original empty-string behavior).
        """
        self.client = Ark(api_key=api_key)
        # Conversation history as OpenAI-style message dicts.
        self.messages = []

    def add_user_message(self, content):
        """Append a user message to the conversation history."""
        self.messages.append({"role": "user", "content": content})

    def add_system_message(self, content):
        """Append a system message to the conversation history."""
        self.messages.append({"role": "system", "content": content})

    def add_assistant_message(self, content):
        """Append an assistant message to the conversation history."""
        self.messages.append({"role": "assistant", "content": content})

    def clear_messages(self):
        """Drop the whole conversation history."""
        self.messages = []

    def print_messages(self):
        """Debug helper: dump the conversation history to stdout."""
        print(self.messages)

    def stream_response(self, model="ep-20241013161850-bnqsx"):
        """Stream the model's reply, yielding each text delta as it arrives.

        The full reply is appended to the history after the stream finishes.
        """
        print("----- 流式请求开始 -----")
        full_response = ""
        stream = self.client.chat.completions.create(
            model=model,
            messages=self.messages,
            stream=True
        )
        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            # BUG FIX: delta.content can be None (e.g. role/stop chunks);
            # concatenating None to a str would raise TypeError.
            if not content:
                continue
            full_response += content
            yield content
        # Record the assistant's complete turn.
        self.add_assistant_message(full_response)
if __name__ == "__main__":
    # Smoke test: run one streamed completion and dump the history.
    llm_client = LLMDobaoClient()
    # Seed the conversation.
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    llm_client.add_user_message("请你讲个小故事")
    # Consume the stream; each chunk could be processed here.
    for response in llm_client.stream_response():
        pass
    llm_client.print_messages()
    print("\n流式请求完成。")
STT文件代码
import asyncio
import gzip
import json
import uuid
import traceback
import websockets
from app.config.settings import settings
# from settings import settings
# --- Binary protocol constants for the streaming-ASR WebSocket frames ---
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001
# Message Type (high nibble of header byte 1, see generate_header):
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
# Message Type Specific Flags (low nibble of header byte 1):
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001  # frame carries a positive sequence number
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
NEG_SEQUENCE_1 = 0b0011
# Message Serialization (high nibble of header byte 2):
NO_SERIALIZATION = 0b0000
JSON = 0b0001
# Message Compression (low nibble of header byte 2):
NO_COMPRESSION = 0b0000
GZIP = 0b0001
# Build the fixed 4-byte request header
def generate_header(
        message_type=FULL_CLIENT_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00
):
    """Build the 4-byte protocol header.

    Layout: protocol_version(4b) | header_size(4b) | message_type(4b) |
    message_type_specific_flags(4b) | serialization(4b) | compression(4b) |
    reserved(8b).
    """
    header_size = 1  # in 4-byte units
    return bytearray((
        (PROTOCOL_VERSION << 4) | header_size,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
# Encode the frame sequence number
def generate_before_payload(sequence: int):
    """Return the sequence number as 4 big-endian signed bytes (bytearray)."""
    return bytearray(sequence.to_bytes(4, 'big', signed=True))
# Parse one binary frame from the ASR server
def parse_response(res):
    """Parse one server frame into a result dict.

    Header layout (mirrors generate_header):
      protocol_version(4b) header_size(4b)
      message_type(4b) message_type_specific_flags(4b)
      serialization_method(4b) message_compression(4b)
      reserved(8b)
      header_extensions: 4 * (header_size - 1) bytes
      payload: frame body, similar to an HTTP body
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {
        'is_last_package': False,
    }
    payload_msg = None
    payload_size = 0
    if message_type_specific_flags & 0x01:
        # Frame carries a 4-byte signed sequence number before the body.
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['payload_sequence'] = seq
        payload = payload[4:]
    if message_type_specific_flags & 0x02:
        # Flag bit 1 set: this is the server's last frame.
        result['is_last_package'] = True
    if message_type == FULL_SERVER_RESPONSE:
        # 4-byte size, then the (possibly compressed) serialized body.
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        # ACK carries a sequence number; size/body only if 8+ bytes present.
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        # Error frame: 4-byte code, 4-byte size, then the message body.
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    # Undo transport-level compression/serialization on the body.
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
# Establish the streaming-ASR WebSocket connection
async def generate_Ws():
    """Open the ASR WebSocket, send the initial full-client request describing
    the audio format, and return the connected socket.

    Returns None when connecting or the handshake fails (errors are printed,
    not raised), so callers must check the return value.
    """
    print("开始建立ws连接")
    try:
        reqid = str(uuid.uuid4())
        # Audio parameters come from the project settings module.
        request_params = {
            "user": {
                "uid": reqid,
            },
            "audio": {
                'format': settings.format,
                "rate": settings.framerate,
                "bits": settings.bits,
                "channel": settings.nchannels,
                "codec": "raw"
            }
        }
        # Full-client request: header + seq + payload size + gzipped JSON body.
        payload_bytes = gzip.compress(json.dumps(request_params).encode('utf-8'))
        full_client_request = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
        full_client_request.extend(generate_before_payload(sequence=1))
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
        full_client_request.extend(payload_bytes)
        # Auth headers: fill in with the same console credentials as TTS.
        header = {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Access-Key": "XXXXXXXX",
            "X-Api-App-Key": "XXXXXXX",
            "X-Api-Request-Id": reqid
        }
        # await to obtain the actual connected ws object
        ws = await websockets.connect("wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", extra_headers=header, max_size=1000000000)
        print("连接成功")
        await ws.send(full_client_request)
        # Read and log the handshake response before handing the socket back.
        res = await ws.recv()
        result = parse_response(res)
        print("******************")
        print("sauc result", result)
        print("******************")
        return ws  # return the connected socket
    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket connection closed with error: {e}")
    except websockets.exceptions.InvalidStatusCode as e:
        print(f"WebSocket connection failed with status code: {e.status_code}")
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"Exception type: {type(e)}")
        print("Stack trace:")
        traceback.print_exc()
# Send one audio segment and read back the recognition result
async def segment_data_processor(ws, audio_data, seq):
    """Send one gzip-compressed audio chunk (frame number ``seq``) over the
    ASR socket and return the recognized text, or None on any failure.
    """
    try:
        payload_bytes = gzip.compress(audio_data)
    except OSError as e:
        print(f"压缩音频数据时出错: {e}")
        return None
    try:
        # Audio-only frame with a positive sequence number.
        audio_only_request = bytearray(generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=POS_SEQUENCE))
        audio_only_request.extend(generate_before_payload(sequence=seq))
        # 4-byte big-endian payload size, then the compressed audio itself.
        audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))
        audio_only_request.extend(payload_bytes)
        await ws.send(audio_only_request)
        # Read and parse the server's reply for this frame.
        res = await ws.recv()
        result = parse_response(res)
        # BUG FIX: the expected keys may be absent (e.g. on an error frame);
        # the old chained ["payload_msg"]["result"]["text"] surfaced that as a
        # generic "unknown error". Return None cleanly instead.
        payload_msg = result.get("payload_msg")
        if not isinstance(payload_msg, dict):
            return None
        return payload_msg.get("result", {}).get("text")
    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket 连接关闭,状态码: {e.code}, 原因: {e.reason}")
        return None
    except websockets.exceptions.WebSocketException as e:
        print(f"WebSocket 连接错误: {e}")
        return None
    except Exception as e:
        print(f"处理音频段时发生未知错误: {e}")
        return None
# Receive and parse one frame from the ASR socket
async def receive_data(ws):
    """Receive a single frame from ``ws``, log it, and return the parsed dict.

    BUG FIX: the original wrapped this in ``while True`` but returned on the
    first iteration, so the loop was dead code and misleading.
    """
    res = await ws.recv()
    result = parse_response(res)
    print("sauc result", result)
    return result
# Standalone entry point: open a connection to verify credentials/config
async def main():
    ws = await generate_Ws()
    if ws is not None:
        print("ws is not None")
        # BUG FIX: close() is a coroutine; without await the socket was never
        # actually closed (and emitted a "coroutine never awaited" warning).
        await ws.close()
if __name__ == "__main__":
    asyncio.run(main())
TTS文件代码
import asyncio
import websockets
import uuid
import json
import gzip
import copy
# Decoding tables for the TTS binary protocol (used for debug printing).
MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}
# --- Console credentials: fill these in from the speech console ---
appid = "123"  # console APPID
token = "XXXXXX"  # console Access Token
cluster = "XXXXXXX"  # cluster id of the speech-synthesis project
voice_type = "BV034_streaming"  # which synthesized voice to use
host = "openspeech.bytedance.com"  # service host, no need to change
api_url = f"wss://{host}/api/v1/tts/ws_binary"  # endpoint, no need to change
# Default 4-byte header for a full client request:
# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')
# Request template; the "xxx" placeholder fields are overwritten by
# test_submit (on a deep copy) before every submission.
request_json = {
    "app": {
        "appid": appid,
        "token": "access_token",  # literal placeholder; auth is sent via the Bearer header in create_tts_ws
        "cluster": cluster
    },
    "user": {
        "uid": "388808087185088"
    },
    "audio": {
        "voice_type": "xxx",  # replaced with the configured voice_type per request
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "xxx",  # replaced with a fresh uuid per request
        "text": "字节跳动语音合成。",
        "text_type": "plain",
        "operation": "xxx"  # replaced with "submit" per request
    }
}
# Synthesize one text chunk (the caller currently does the segmentation)
async def long_sentence(text, file, ws, websocket):
    """Put ``text`` into the shared request template and submit it for TTS.

    NOTE: this mutates the module-level request_json; test_submit deep-copies
    it before filling in per-request fields. The commented lines are the old
    local 50-character segmentation, now handled by the caller.
    """
    # segments = [text[i:i+50] for i in range(0, len(text), 50)]
    # for i, segment in enumerate(segments):
    request_json["request"]["text"] = text
    await test_submit(request_json, file, ws, websocket)
async def create_tts_ws():
    """Open the TTS WebSocket authenticated with the console token.

    The "Bearer; <token>" format (including the semicolon) is intentional —
    it is the auth scheme this service endpoint accepts.
    """
    header = {"Authorization": f"Bearer; {token}"}
    # ping_interval=None: disable client pings for this long-lived stream.
    ws = await websockets.connect(api_url, extra_headers=header, ping_interval=None)
    return ws
# Submit one synthesis request and pump server frames until finished
async def test_submit(request_json, file, ws, websocket):
    """Send a 'submit' TTS request for the current text, then read server
    frames through parse_response until the final frame arrives.

    ``file`` is only used when the file-saving lines are uncommented.
    """
    # Work on a copy so the shared template is not polluted with per-request values.
    submit_request_json = copy.deepcopy(request_json)
    submit_request_json["audio"]["voice_type"] = voice_type
    submit_request_json["request"]["reqid"] = str(uuid.uuid4())
    submit_request_json["request"]["operation"] = "submit"
    payload_bytes = str.encode(json.dumps(submit_request_json))
    payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
    # Frame = fixed header + 4-byte payload size + gzipped JSON payload.
    full_client_request = bytearray(default_header)
    full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
    full_client_request.extend(payload_bytes)  # payload
    await ws.send(full_client_request)
    # Drain the response stream; parse_response returns True on the last frame.
    while True:
        res = await ws.recv()
        done = await parse_response(res, file, websocket)
        if done:
            break
# Parse one TTS server frame; returns True when the stream is finished
async def parse_response(res, file, websocket):
    """Parse a TTS server frame and forward audio to the frontend socket.

    Returns True when this was the final frame (negative sequence number,
    an error, or an unknown message type), False when more frames follow.
    """
    # Header fields — same 4-byte layout as the client request header.
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size*4]
    payload = res[header_size*4:]
    # Dispatch on the message type (see MESSAGE_TYPES above).
    if message_type == 0xb:  # audio-only server response
        if message_type_specific_flags == 0:  # no sequence number: ACK only
            return False
        else:
            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload = payload[8:]
            # file.write(payload)  # uncomment (with the file plumbing) to also save the audio
            # Forward the audio chunk straight to the frontend WebSocket.
            await websocket.send_bytes(payload)
            if sequence_number < 0:  # negative sequence marks the last chunk
                return True
            else:
                return False
    elif message_type == 0xf:  # error message from the server
        code = int.from_bytes(payload[:4], "big", signed=False)
        msg_size = int.from_bytes(payload[4:8], "big", signed=False)
        error_msg = payload[8:]
        if message_compression == 1:
            error_msg = gzip.decompress(error_msg)
        error_msg = str(error_msg, "utf-8")
        print(f" Error message code: {code}")
        print(f" Error message size: {msg_size} bytes")
        print(f" Error message: {error_msg}")
        return True
    elif message_type == 0xc:  # frontend/metadata message (logged only)
        msg_size = int.from_bytes(payload[:4], "big", signed=False)
        payload = payload[4:]
        if message_compression == 1:
            payload = gzip.decompress(payload)
        print(f" Frontend message: {payload}")
    else:
        print("undefined message type!")
        return True
# Standalone entry point
if __name__ == '__main__':
    # BUG FIX: the original called test_long_sentence(), which is not defined
    # anywhere in this file, so running the module raised NameError. It also
    # used the deprecated asyncio.get_event_loop() pattern. This demo now
    # opens a TTS connection and synthesizes the text, writing the audio to a
    # local file via a tiny adapter mimicking the FastAPI WebSocket interface.
    class _FileSink:
        # Adapter so parse_response's websocket.send_bytes() writes to disk.
        def __init__(self, path):
            self._f = open(path, "wb")

        async def send_bytes(self, data):
            self._f.write(data)

        def close(self):
            self._f.close()

    async def _demo(text):
        ws = await create_tts_ws()
        sink = _FileSink("tts_demo.mp3")
        try:
            await long_sentence(text, None, ws, sink)
        finally:
            sink.close()
            await ws.close()

    long_text = "这是一个很长的句子,需要分成多个段落来逐步合成语音,以便处理。"  # demo sentence
    asyncio.run(_demo(long_text))