通过火山云API来实现:流式大模型语音对话

这里我们需要在火山云语音控制台开通大模型的流式语音对话、获取豆包模型的apiKey,开通语音合成项目。


这里使用的豆包模型是Doubao-lite,延迟会更低一些
配置说明
这里一共有四个文件,分别是主要的fastAPI、LLM、STT、文件
TTS中需要配置

appid = "123"   #填写控制台的APPID
token = "XXXX"  #填写控制台上的Access Token
cluster = "XXXXX"  #填写语音生成的组id
voice_type = "BV034_streaming"   #这里是生成声音的类型选择

host = "openspeech.bytedance.com"  #无需更改
api_url = f"wss://{host}/api/v1/tts/ws_binary" #无需更改

LLM中配置

 # 初始化客户端,传入 API 密钥
   self.client = Ark(api_key="XXXX")

在STT的146行中配置

header = {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Access-Key": "XXXXX",  #和TTS配置内容相同
            "X-Api-App-Key": "123",  #和TTS配置内容相同
            "X-Api-Request-Id": reqid
        }

 还有前端HTML的配置中记得根据自己服务的所在ip更改配置
 

前端测试html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>WebSocket 音频传输测试</title>
    <style>
        body {
            font-family: Arial, sans-serif;
        }
        #status {
            margin-bottom: 10px;
        }
        #messages {
            border: 1px solid #ccc;
            height: 200px;
            overflow-y: scroll;
            padding: 10px;
        }
        #controls {
            margin-top: 10px;
        }
        #controls button {
            margin-right: 5px;
        }
        #latency {
            margin-top: 10px;
            font-weight: bold;
        }
    </style>
</head>
<body>

<h1>WebSocket 音频传输测试</h1>

<!-- 显示当前连接状态 -->
<div id="status">状态:未连接</div>

<!-- 显示日志消息 -->
<div id="messages"></div>

<!-- 控制按钮 -->
<div id="controls">
    <button id="startButton">开始录音并发送</button>
    <button id="stopButton" disabled>停止录音</button>
</div>

<!-- 延迟显示区域 -->
<div id="latency"></div>

<script>
    // WebSocket 服务器地址,请根据实际情况替换
    const wsUrl = 'ws://127.0.0.1:8000/ws';

    // 全局变量
    let socket = null; // WebSocket 实例
    const messagesDiv = document.getElementById('messages'); // 日志消息显示区域
    const statusDiv = document.getElementById('status'); // 连接状态显示区域
    const startButton = document.getElementById('startButton'); // 开始录音按钮
    const stopButton = document.getElementById('stopButton'); // 停止录音按钮
    let recordingAudioContext; // 音频录制上下文
    let audioInput; // 音频输入节点
    let processor; // 音频处理节点

    // 播放相关变量
    let playbackAudioContext;
    let playbackQueue = [];
    let playbackTime = 0;
    let isPlaying = false;

    // 延迟测量变量
    let overSentTime = null; // 记录发送 'over' 的时间
    let latencyMeasured = false; // 标记是否已经测量延迟

    /**
     * 向日志区域添加消息
     * @param {string} message - 要记录的消息
     */
    function logMessage(message) {
        const p = document.createElement('p');
        p.textContent = message;
        messagesDiv.appendChild(p);
        messagesDiv.scrollTop = messagesDiv.scrollHeight; // 自动滚动到最新消息
    }

    /**
     * 初始化Playback AudioContext
     */
    function initializePlayback() {
        playbackAudioContext = new (window.AudioContext || window.webkitAudioContext)();
        logMessage('Playback AudioContext 已创建');
    }

    /**
     * 解码并添加到播放队列
     * @param {ArrayBuffer} data - 接收到的音频数据
     */
    function appendToPlaybackQueue(data) {
        playbackAudioContext.decodeAudioData(data, (audioBuffer) => {
            playbackQueue.push(audioBuffer);
            schedulePlayback();
        }, (error) => {
            logMessage('解码音频数据时出错:' + error);
        });
    }

    /**
     * 调度播放队列中的音频缓冲区
     */
    function schedulePlayback() {
        if (isPlaying) return;
        if (playbackQueue.length === 0) return;

        // 获取下一个缓冲区
        const buffer = playbackQueue.shift();

        // 创建一个缓冲源
        const source = playbackAudioContext.createBufferSource();
        source.buffer = buffer;
        source.connect(playbackAudioContext.destination);

        // 如果 playbackTime 小于当前时间,则更新为当前时间
        if (playbackTime < playbackAudioContext.currentTime) {
            playbackTime = playbackAudioContext.currentTime;
        }

        // 计划在 playbackTime 播放
        source.start(playbackTime);
        logMessage(`Scheduled buffer to play at ${playbackTime.toFixed(2)}s`);

        // 更新 playbackTime
        playbackTime += buffer.duration;

        // 标记为正在播放
        isPlaying = true;

        // 当缓冲源播放结束时
        source.onended = () => {
            isPlaying = false;
            // 继续播放队列中的下一个缓冲区
            schedulePlayback();
        };
    }

    /**
     * 创建并连接 WebSocket
     */
    function createWebSocket() {
        if (socket !== null && (socket.readyState === WebSocket.OPEN || socket.readyState === WebSocket.CONNECTING)) {
            logMessage('WebSocket 已经连接或正在连接中');
            return;
        }

        socket = new WebSocket(wsUrl);
        socket.binaryType = 'arraybuffer';

        socket.onopen = function () {
            statusDiv.textContent = '状态:已连接';
            logMessage('WebSocket 连接已打开');
            startButton.disabled = false; // 启用开始录音按钮
        };

        socket.onmessage = function (event) {
            // 如果接收到的是字符串且内容为 'over'
            if (typeof event.data === 'string' && event.data === 'over') {
                logMessage('收到结束信号: over');
                // 标记 MediaSource 结束
                return;
            }

            // 如果接收到的是二进制数据(ArrayBuffer)
            if (event.data instanceof ArrayBuffer) {
                logMessage('接收到音频数据');

                // 检查是否已经发送 'over' 并且尚未测量延迟
                if (overSentTime !== null && !latencyMeasured) {
                    let receiveTime = performance.now();
                    let latency = receiveTime - overSentTime;
                    logMessage(`延迟时间:${latency.toFixed(2)} 毫秒`);
                    document.getElementById('latency').textContent = `延迟时间:${latency.toFixed(2)} 毫秒`;
                    latencyMeasured = true; // 标记为已测量
                }

                appendToPlaybackQueue(event.data); // 解码并添加到播放队列
            }
        };

        socket.onerror = function (error) {
            statusDiv.textContent = '状态:连接错误';
            logMessage('WebSocket 发生错误:' + error.message);
        };

        socket.onclose = function (event) {
            // 根据关闭代码判断关闭原因
            if (event.code === 1000) { // 正常关闭
                statusDiv.textContent = '状态:已断开连接';
                logMessage('WebSocket 正常关闭');
            } else {
                statusDiv.textContent = '状态:连接错误';
                logMessage(`WebSocket 关闭,代码:${event.code}, 原因:${event.reason}`);
            }
            startButton.disabled = false; // 启用开始录音按钮
            stopButton.disabled = true; // 禁用停止录音按钮
        };
    }

    /**
     * 初始化音频播放
     */
    function initializeAudioPlayback() {
        initializePlayback();
    }

    /**
     * 开始录音并通过 WebSocket 发送音频数据
     */
    function startRecording() {
        // 创建并连接 WebSocket
        createWebSocket();

        // 请求访问麦克风
        navigator.mediaDevices.getUserMedia({ audio: true })
            .then(function (stream) {
                // 创建音频上下文,设置采样率为16000Hz
                recordingAudioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });

                // 创建音频源节点,连接到麦克风输入流
                audioInput = recordingAudioContext.createMediaStreamSource(stream);

                // 创建脚本处理节点,用于处理音频数据
                processor = recordingAudioContext.createScriptProcessor(4096, 1, 1);

                // 连接音频节点
                audioInput.connect(processor);
                processor.connect(recordingAudioContext.destination);

                // 当音频处理节点有音频数据可处理时触发
                processor.onaudioprocess = function (e) {
                    const audioData = e.inputBuffer.getChannelData(0); // 获取单声道音频数据
                    const int16Data = floatTo16BitPCM(audioData); // 将浮点数PCM数据转换为16位PCM
                    const wavBuffer = encodeWAV(int16Data, recordingAudioContext.sampleRate); // 编码为WAV格式

                    // 如果WebSocket连接打开,则发送WAV数据
                    if (socket && socket.readyState === WebSocket.OPEN) {
                        socket.send(wavBuffer);
                    }
                };

                logMessage('开始录音并发送音频数据');
                startButton.disabled = true; // 禁用开始录音按钮
                stopButton.disabled = false; // 启用停止录音按钮
                initializeAudioPlayback(); // 初始化音频播放
            })
            .catch(function (err) {
                // 如果无法访问麦克风,则记录错误消息
                logMessage('无法访问麦克风:' + err.message);
            });
    }

    /**
     * 停止录音并关闭音频节点
     */
    function stopRecording() {
        // 断开并释放音频处理节点
        if (processor) {
            processor.disconnect();
            processor = null;
        }

        // 断开并释放音频输入节点
        if (audioInput) {
            audioInput.disconnect();
            audioInput = null;
        }

        // 关闭音频上下文
        if (recordingAudioContext) {
            recordingAudioContext.close();
            recordingAudioContext = null;
        }

        logMessage('停止录音');
        startButton.disabled = false; // 启用开始录音按钮
        stopButton.disabled = true; // 禁用停止录音按钮

        // 通过WebSocket发送结束信号
        if (socket && socket.readyState === WebSocket.OPEN) {
            socket.send("over"); // 与后端约定的结束信号
            // 记录发送 'over' 的时间
            overSentTime = performance.now();
            latencyMeasured = false;
            logMessage('发送结束信号 "over"');
        }
    }

    /**
     * 将浮点数PCM数据转换为16位PCM数据
     * @param {Float32Array} float32Array - 浮点数PCM数据
     * @returns {Int16Array} 16位PCM数据
     */
    function floatTo16BitPCM(float32Array) {
        const int16Array = new Int16Array(float32Array.length);
        for (let i = 0; i < float32Array.length; i++) {
            // 限制值在[-1, 1]范围内
            let s = Math.max(-1, Math.min(1, float32Array[i]));
            // 转换为16位整数
            int16Array[i] = s < 0 ? s * 0x8000 : s * 0x7FFF;
        }
        return int16Array;
    }

    /**
     * 编码PCM数据为WAV格式
     * @param {Int16Array} samples - 16位PCM数据
     * @param {number} sampleRate - 采样率
     * @returns {ArrayBuffer} WAV格式数据
     */
    function encodeWAV(samples, sampleRate) {
        const buffer = new ArrayBuffer(44 + samples.length * 2); // WAV头部44字节 + PCM数据
        const view = new DataView(buffer);

        /* RIFF identifier */
        writeString(view, 0, 'RIFF');
        /* 文件长度 */
        view.setUint32(4, 36 + samples.length * 2, true);
        /* RIFF类型 */
        writeString(view, 8, 'WAVE');
        /* 格式块标识符 */
        writeString(view, 12, 'fmt ');
        /* 格式块长度 */
        view.setUint32(16, 16, true);
        /* 音频格式(1为PCM) */
        view.setUint16(20, 1, true);
        /* 声道数(1为单声道) */
        view.setUint16(22, 1, true);
        /* 采样率 */
        view.setUint32(24, sampleRate, true);
        /* 字节率(采样率 * 声道数 * 每个样本的字节数) */
        view.setUint32(28, sampleRate * 2, true);
        /* 块对齐(声道数 * 每个样本的字节数) */
        view.setUint16(32, 2, true);
        /* 每个样本的位数 */
        view.setUint16(34, 16, true);
        /* 数据块标识符 */
        writeString(view, 36, 'data');
        /* 数据块长度 */
        view.setUint32(40, samples.length * 2, true);

        // 写入PCM采样数据
        let offset = 44;
        for (let i = 0; i < samples.length; i++, offset += 2) {
            view.setInt16(offset, samples[i], true);
        }

        return buffer;
    }

    /**
     * 将字符串写入DataView
     * @param {DataView} view - DataView实例
     * @param {number} offset - 写入起始位置
     * @param {string} string - 要写入的字符串
     */
    function writeString(view, offset, string) {
        for (let i = 0; i < string.length; i++) {
            view.setUint8(offset + i, string.charCodeAt(i));
        }
    }

    // 事件绑定

    /**
     * 绑定开始录音按钮的点击事件
     */
    startButton.addEventListener('click', function () {
        startRecording();
    });

    /**
     * 绑定停止录音按钮的点击事件
     */
    stopButton.addEventListener('click', function () {
        stopRecording();
    });

    /**
     * 页面加载完成后不再自动连接到WebSocket服务器
     * 连接将在用户点击“开始录音并发送”时创建
     */
    window.onload = function () {
        logMessage('请点击 "开始录音并发送" 按钮以开始录音');
    };
</script>

</body>
</html>

后端fastAPI的服务入口 

import asyncio
import re
import time
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from STT import generate_Ws, segment_data_processor
from LLM import LLMDobaoClient
from TTS import long_sentence,create_tts_ws


router = APIRouter()


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    logger.info("WebSocket connection accepted")
    audio_result = ""
    audio_ws=None
    tts_ws=None
    seq = 1  # 将 语音识别序列号seq 初始化放在循环外部
    llm_client = LLMDobaoClient()
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    try:
        while True:
            message = await websocket.receive()
            if 'bytes' in message:
                data = message['bytes']
                # print("接收到数据的大小:", len(data))
                if audio_ws is None:
                   audio_ws=await generate_Ws()
                   tts_ws=await create_tts_ws()
                if data is not None:
                    audio_result = await segment_data_processor(audio_ws,data, seq)
                    if audio_result is not None:
                        print("识别结果:",audio_result)
                    seq += 1

            elif 'text' in message:
                # 大模型交互
                llm_client.add_user_message(audio_result)
                audio_result = "" # 清空识别结果
                #TTS的ws连接
                #这里是TTS的语音文件的保存,如果需要请取消下面和TTS中的相关注释
                # file_to_save = open("test.mp3", "ab")  # 使用追加模式打开文件,以便保存多个段落的音频       
                file_to_save="123"
                result=""
                seq=1
                for response in llm_client.stream_response():
                    result += response
                    if len(result) > 100:
                        # 查找最接近50个字符的标点符号位置
                        cut_pos = find_nearest_punctuation(result, 50)
                        
                        # 拼接缓存并截取到cut_pos位置的文本
                        full_result =  result[:cut_pos+1]  # 包含标点符号
                        print(full_result)
                        await long_sentence(full_result, file_to_save, tts_ws,websocket)  # 处理分割后的文本
                        # 更新 result 和缓存
                        result = result[cut_pos+1:]  # 剩余未处理的部分

                # 处理结束时的剩余缓存
                if result:
                    print(result)
                    await long_sentence(result, file_to_save, tts_ws,websocket)
                await websocket.send_text("over")
                print("============ 结束 ==============")

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
        if audio_ws is not None:
            await audio_ws.close()
            await audio_ws.close()
    except Exception as e:
        logger.error(f"WebSocket Error: {e}")
        if audio_ws is not None:
            await audio_ws.close()
            await tts_ws.close()
        await websocket.close()

def find_nearest_punctuation(text, max_length):
    """查找距离max_length最接近的标点符号位置"""
    # 使用正则表达式查找所有标点符号
    punctuation_matches = [m.start() for m in re.finditer(r'[,。!?;]', text)]
    
    # 如果没有标点符号,返回max_length作为分割点
    if not punctuation_matches:
        return max_length
    
    # 找到离max_length最近的标点符号
    nearest_punctuation = max_length
    for pos in punctuation_matches:
        if pos <= max_length:
            nearest_punctuation = pos
        else:
            break  # 当位置超过max_length时停止遍历

    return nearest_punctuation

LLM文件代码
 

from volcenginesdkarkruntime import Ark
from app.config.settings import settings

class LLMDobaoClient:
    def __init__(self):
        # 初始化客户端,传入 API 密钥
        self.client = Ark(api_key="")
        # 存储对话消息的列表
        self.messages = []

    def add_user_message(self,  content):
        """添加一条用户消息到对话历史中"""
        self.messages.append({"role": "user", "content": content})
    
    def add_system_message(self,content):
        """添加一条用户消息到对话历史中"""
        self.messages.append({"role": "system", "content": content})

    def add_assistant_message(self, content):
        """添加一条消息到对话历史中"""
        self.messages.append({"role": "assistant", "content": content})

    def clear_messages(self):
        """清除所有对话历史中的消息"""
        self.messages = []
        
    def print_messages(self):
        """打印对话历史中的消息"""
        print(self.messages)

    
    def stream_response(self, model="ep-20241013161850-bnqsx"):
        """基于当前消息流式获取模型的响应,并将完整响应添加到消息中"""
        print("----- 流式请求开始 -----")
        full_response = ""
        stream = self.client.chat.completions.create(
            model=model,
            messages=self.messages,
            stream=True
        )
        for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            full_response += content
            yield content
            # print(content, end="")
        # print()  # 在流式输出完成后添加一个换行符
        # 将完整响应添加到对话历史中
        self.add_assistant_message(full_response)

if __name__ == "__main__":
    llm_client = LLMDobaoClient()

    # 添加初始消息
    llm_client.add_system_message("你是豆包,是由字节跳动开发的 AI 人工智能助手")
    llm_client.add_user_message("请你讲个小故事")

    # 流式获取响应
    for response in llm_client.stream_response():
        # 这里可以处理每个响应片段
        pass
    llm_client.print_messages()
    print("\n流式请求完成。")

STT文件代码
 

import asyncio
import gzip
import json
import uuid
import traceback
import websockets
from app.config.settings import settings
# from settings import settings

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Message Type:
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001


# 生成请求头
def generate_header(
        message_type=FULL_CLIENT_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00
):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits)
    serialization_method(4 bits) message_compression(4 bits)
    reserved (8bits) 保留字段
    """
    header = bytearray()
    header_size = 1
    header.append((PROTOCOL_VERSION << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    return header

# 添加序列号信息
def generate_before_payload(sequence: int):
    before_payload = bytearray()
    before_payload.extend(sequence.to_bytes(4, 'big', signed=True))  # sequence
    return before_payload

# 解析服务器响应
def parse_response(res):
    """
    protocol_version(4 bits), header_size(4 bits),
    message_type(4 bits), message_type_specific_flags(4 bits)
    serialization_method(4 bits) message_compression(4 bits)
    reserved (8bits) 保留字段
    header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) )
    payload 类似与http 请求体
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {
        'is_last_package': False,
    }
    payload_msg = None
    payload_size = 0
    if message_type_specific_flags & 0x01:
        # receive frame with sequence
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['payload_sequence'] = seq
        payload = payload[4:]

    if message_type_specific_flags & 0x02:
        # receive last package
        result['is_last_package'] = True

    if message_type == FULL_SERVER_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result

#建立ws连接
async def generate_Ws():
    print("开始建立ws连接")
    try:
        reqid = str(uuid.uuid4())
        request_params = {
            "user": {
                "uid": reqid,
            },
            "audio": {
                'format': settings.format,
                "rate": settings.framerate,
                "bits": settings.bits,
                "channel": settings.nchannels,
                "codec": "raw"
            }
        }
        payload_bytes = gzip.compress(json.dumps(request_params).encode('utf-8'))
        full_client_request = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
        full_client_request.extend(generate_before_payload(sequence=1))
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
        full_client_request.extend(payload_bytes)
        header = {
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Access-Key": "XXXXXXXX",
            "X-Api-App-Key": "XXXXXXX",
            "X-Api-Request-Id": reqid
        }
        # 使用 await 获取实际的 ws 对象
        ws = await websockets.connect("wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", extra_headers=header, max_size=1000000000)
        print("连接成功")
        await ws.send(full_client_request)
        res = await ws.recv()
        result = parse_response(res)
        print("******************")
        print("sauc result", result)
        print("******************")
        return ws  # 返回 ws 对象
    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket connection closed with error: {e}")
    except websockets.exceptions.InvalidStatusCode as e:
        print(f"WebSocket connection failed with status code: {e.status_code}")
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"Exception type: {type(e)}")
        print("Stack trace:")
        traceback.print_exc()


#发送数据
async def segment_data_processor(ws,audio_data, seq):
    try:
        # 压缩当前的音频数据分段
        payload_bytes = gzip.compress(audio_data)
    except OSError as e:
        print(f"压缩音频数据时出错: {e}")
        return None
    try:
        # 生成音频数据的请求头,如果是最后一段,使用负序列的标志
        audio_only_request = bytearray(generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=POS_SEQUENCE))
        # if seq == -1:
        #     audio_only_request = bytearray(generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=NEG_WITH_SEQUENCE))
        # 将当前音频段的序列号添加到请求中
        audio_only_request.extend(generate_before_payload(sequence=seq))

        # 将音频段数据的大小(4字节)附加到请求中
        audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))

        # 将压缩后的音频数据附加到请求中并发送
        audio_only_request.extend(payload_bytes)

        # 发送请求
        await ws.send(audio_only_request)

        # 接收服务器响应
        res = await ws.recv()
        
        # 解析服务器响应
        result = parse_response(res)
        # json_start_index = audio_result.find(b'{')
        # json_data = audio_result[json_start_index:]
        # decoded_str = json_data.decode('utf-8')
        # parsed_result = json.loads(decoded_str)

        return result["payload_msg"]["result"]["text"]

    except websockets.exceptions.ConnectionClosedError as e:
        print(f"WebSocket 连接关闭,状态码: {e.code}, 原因: {e.reason}")
        return None

    except websockets.exceptions.WebSocketException as e:
        print(f"WebSocket 连接错误: {e}")
        return None

    except Exception as e:
        print(f"处理音频段时发生未知错误: {e}")
        return None


#接收数据
async def receive_data(ws):
    while True:
        res = await ws.recv()
        # print(res)
        result = parse_response(res)
        # print("******************")
        print("sauc result", result)
        # print("******************")
        return result
#在这里创建一个主函数调用generate_Ws
async def main():
    ws = await generate_Ws()
    if ws is not None:
        print("ws is not None")
        ws.close()

if __name__ == "__main__":
    asyncio.run(main())

TTS文件代码
 

import asyncio
import websockets
import uuid
import json
import gzip
import copy

MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}

appid = "123"
token = "XXXXXX"
cluster = "XXXXXXX"
voice_type = "BV034_streaming"
host = "openspeech.bytedance.com"
api_url = f"wss://{host}/api/v1/tts/ws_binary"

# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')

request_json = {
    "app": {
        "appid": appid,
        "token": "access_token",
        "cluster": cluster
    },
    "user": {
        "uid": "388808087185088"
    },
    "audio": {
        "voice_type": "xxx",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "xxx",
        "text": "字节跳动语音合成。",
        "text_type": "plain",
        "operation": "xxx"
    }
}

# 分割长句子并逐段合成音频
async def long_sentence(text,file,ws,websocket):
    # 将长句子分成较短的段落
    # segments = [text[i:i+50] for i in range(0, len(text), 50)]  
    # for i, segment in enumerate(segments):
        request_json["request"]["text"] = text
        await test_submit(request_json, file,ws,websocket)

async def create_tts_ws():
    header = {"Authorization": f"Bearer; {token}"}
    ws=await websockets.connect(api_url, extra_headers=header, ping_interval=None)
    return ws


# 异步函数,提交文本请求以进行语音合成
async def test_submit(request_json, file,ws,websocket):
    submit_request_json = copy.deepcopy(request_json)
    submit_request_json["audio"]["voice_type"] = voice_type
    submit_request_json["request"]["reqid"] = str(uuid.uuid4())
    submit_request_json["request"]["operation"] = "submit"
    payload_bytes = str.encode(json.dumps(submit_request_json))
    payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
    full_client_request = bytearray(default_header) 
    full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
    full_client_request.extend(payload_bytes)  # payload
    # print("\n------------------------ test 'submit' -------------------------")
    # print("request json: ", submit_request_json)
    # print("\nrequest bytes: ", full_client_request)
    await ws.send(full_client_request)
    while True:
        res = await ws.recv()
        done = await parse_response(res, file,websocket)
        if done:
            break


# 解析服务器返回的响应消息
async def parse_response(res, file,websocket):
    # 解析响应头部的各个字段
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size*4]
    payload = res[header_size*4:]
    # print(f"            Protocol version: {protocol_version:#x} - version {protocol_version}")
    # print(f"                 Header size: {header_size:#x} - {header_size * 4} bytes ")
    # print(f"                Message type: {message_type:#x} - {MESSAGE_TYPES[message_type]}")
    # print(f" Message type specific flags: {message_type_specific_flags:#x} - {MESSAGE_TYPE_SPECIFIC_FLAGS[message_type_specific_flags]}")
    # print(f"Message serialization method: {serialization_method:#x} - {MESSAGE_SERIALIZATION_METHODS[serialization_method]}")
    # print(f"         Message compression: {message_compression:#x} - {MESSAGE_COMPRESSIONS[message_compression]}")
    # print(f"                    Reserved: {reserved:#04x}")
    # if header_size != 1:
        # print(f"           Header extensions: {header_extensions}")
    
    # 根据消息类型对响应进行处理
    if message_type == 0xb:  # 处理音频服务器响应
        if message_type_specific_flags == 0:  # 无序列号作为ACK
            # print("                Payload size: 0")
            return False
        else:
            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload = payload[8:]
            # print(f"             Sequence number: {sequence_number}")
            # print(f"                Payload size: {payload_size} bytes")
        # file.write(payload)
        await websocket.send_bytes(payload)
        if sequence_number < 0:  # 如果序列号为负,表示结束
            return True
        else:
            return False
    elif message_type == 0xf:  # 处理错误消息
        code = int.from_bytes(payload[:4], "big", signed=False)
        msg_size = int.from_bytes(payload[4:8], "big", signed=False)
        error_msg = payload[8:]
        if message_compression == 1:
            error_msg = gzip.decompress(error_msg)
        error_msg = str(error_msg, "utf-8")
        print(f"          Error message code: {code}")
        print(f"          Error message size: {msg_size} bytes")
        print(f"               Error message: {error_msg}")
        return True
    elif message_type == 0xc:  # 处理前端消息
        msg_size = int.from_bytes(payload[:4], "big", signed=False)
        payload = payload[4:]
        if message_compression == 1:
            payload = gzip.decompress(payload)
        print(f"            Frontend message: {payload}")
    else:
        print("undefined message type!")
        return True

# 主程序入口
if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    long_text = "这是一个很长的句子,需要分成多个段落来逐步合成语音,以便处理。"  # 示例长句子
    try:
        loop.run_until_complete(test_long_sentence(long_text))
    finally:
        loop.close()

        


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/901588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

洛谷 U411986 数的范围(二分模板)

题意&#xff1a;在一个有序序列里面找某个值的初始出现下标和最后出现下标&#xff0c;如果该值不存在&#xff0c;输出-1 -1。 整数二分模板题&#xff0c;该题主要用来练习如何写两种情况下的二分函数的代码模板。 1&#xff09;upper_bound函数&#xff1a;用来寻找边界点A…

鸿蒙是必经之路

少了大嘴的发布会&#xff0c;老实讲有点让人昏昏入睡。关于技术本身的东西&#xff0c;放在后面。 我想想来加把油~ 鸿蒙发布后褒贬不一&#xff0c;其中很多人不太看好鸿蒙&#xff0c;一方面是开源性、一方面是南向北向的利益问题。 不说技术的领先点&#xff0c;我只扯扯…

香橙派5(RK3588)使用npu加速yolov5推理的部署过程

香橙派5使用npu加速yolov5推理的部署过程 硬件环境 部署过程 模型训练(x86主机) 在带nvidia显卡(最好)的主机上进行yolo的配置与训练, 获取最终的best.pt模型文件, 详见另一篇文档 模型转换(x86主机) 下载airockchip提供的yolov5(从pt到onnx) 一定要下这个版本的yolov5, …

【力扣 + 牛客 | SQL题 | 每日三题】大厂笔试真题W1,W4

1. 力扣603&#xff1a;连续空余的座位 1.1 题目&#xff1a; 表: Cinema ------------------- | Column Name | Type | ------------------- | seat_id | int | | free | bool | ------------------- Seat_id 是该表的自动递增主键列。 在 PostgreSQL 中&#…

练习LabVIEW第十九题

学习目标&#xff1a; 刚学了LabVIEW&#xff0c;在网上找了些题&#xff0c;练习一下LabVIEW&#xff0c;有不对不好不足的地方欢迎指正&#xff01; 第十九题&#xff1a; 创建一个程序把另外一个VI的前面板显示在Picture控件中 开始编写&#xff1a; 在前面板放置一个二…

C语言教程——数组(2)

目录 系列文章目录 前言 4、数组作为函数参数 4.1冒泡函数的错误设计 4.2数组名是什么&#xff1f; 总结 前言 我们知道一维数组是连续存放的&#xff0c;随着数组下标的增长&#xff0c;地址是由低到高依次存放的&#xff0c;二维数组&#xff0c;也是在内存里面是连续存放的…

Linux | 配置docker环境时yum一直出错的解决方法

yum出错 Centos 7版本出错问题补充&#xff1a;什么是yumyum 和 apt 有什么区别&#xff1f; Centos 7版本 [rootlocalhost yum.repos.d]# cat /etc/redhat-release CentOS Linux release 7.9.2009 (Core)出错问题 问题1 Could not retrieve mirrorlist http://mirrorlist.ce…

SQLite 3.47.0 发布,大量新功能来袭

SQLite 开发团队于 2024 年 10 月 21 日发布了 SQLite 3.47.0 版本&#xff0c;我们来了解一下新版本的改进功能。 触发器增强 SQLite 3.47.0 版本开始&#xff0c;触发器函数 RAISE() 的 error-message 参数可以支持任意 SQL 表达式。在此之前&#xff0c;该参数只能是字符串…

论1+2+3+4+... = -1/12 的不同算法

我们熟知自然数全加和&#xff0c; 推导过程如下&#xff0c; 这个解法并不难&#xff0c;非常容易看懂&#xff0c;但是并不容易真正理解。正负交错和无穷项计算&#xff0c;只需要保持方程的形态&#xff0c;就可以“预知”结果。但是这到底说的是什么意思&#xff1f;比如和…

Nodejs使用pkg打包为可执行文件

安装pkg npm install -g pkg查看pkg命令 pkg --help修改package.json 新增bin入口配置 {"name": "takescreenshot","version": "1.0.0","bin": "app.js", // 新增bin入口配置"scripts": {"t…

day10:ssh服务-跳板机

一&#xff0c;ssh服务概述 ssh服务概述 ssh&#xff08;Secure Shell&#xff09;是一种用于在不安全网络中进行安全登录、远程执行命令及传输文件的网络协议。它通过加密技术来保证通信的保密性和完整性&#xff0c;主要用于替代不安全的telnet、rlogin、rsh等协议。ssh通常…

计算机视觉-边缘检测实验报告

实验一 边缘检测实验 一、实验目的 1&#xff0e;理解并掌握 Sobel 算子和 Canny 算子的基本原理和应用。 2&#xff0e;学习如何在图像处理中使用这两种算子进行边缘检测。 3&#xff0e;比较 Sobel 算子和 Canny 算子的性能&#xff0c;了解各自的优缺点。 4&#xff0…

【mysql进阶】4-3. 页结构

页面结构 ⻚在MySQL运⾏的过程中起到了⾮常重要的作⽤&#xff0c;为了能发挥更好的性能&#xff0c;可以结合⾃⼰系统的业务场景和数据⼤⼩&#xff0c;对⻚相关的系统变量进⾏调整&#xff0c;⻚的⼤⼩就是⼀个⾮常重要的调整项。同时关于⻚的结构也要有所了解&#xff0c;以…

HTTP协议讲解

前瞻&#xff1a; 认识URL 1.ipport 2.平时上网&#xff0c;就是进程间通信 3.上网行为&#xff0c;1.获取资源 2.上传数据 相当于I/O 4.http协议采用tcp协议 网页 图片 音乐其实都是资源 Http请求 http request Method&#xff1a;Get/Post资源/路径&#xff1a…

MyBatis缓存详解(一级缓存、二级缓存、缓存查询顺序)

固态硬盘缺陷&#xff1a;无法长时间使用&#xff0c;而磁盘只要不消磁&#xff0c;只要不受到磁影响&#xff0c;就可以长期使用&#xff0c;因此绝大多数企业还是使用磁盘来存储数据 像mysql这种关系型数据库中的数据存储在磁盘中&#xff0c;为方便查询&#xff0c;减少系统…

Linux文件类型和根目录结构

Linux文件类型和根目录结构 1.文件类型 字符文件类型说明~普通文件类似于Windows的记事本d目录文件类似于windows文件夹c字符设备文件串行端口设备&#xff0c;顺序读写&#xff0c;键盘b块设备文件可供存储的接口设备&#xff0c;随机读写&#xff0c;硬盘p管道文件用于进程…

工程项目管理软件怎么选?推荐7款实用工具

本文提及的有主流7款工程项目管理系统软件有: 1. Worktile&#xff1b;2. 广联达BIM5D&#xff1b;3. 泛普软件&#xff1b;4. 明源云工程&#xff1b;5. 飞书&#xff1b;6. Smartsheet&#xff1b;7. Procore。 很多工程项目管理人员常常头疼如何有效地管理多个项目&#xff…

保研考研机试攻略:python笔记(1)

&#x1f428;&#x1f428;&#x1f428;宝子们好呀 ~ 我来更新欠大家的python笔记了&#xff0c;从这一篇开始我们来学下python&#xff0c;当然&#xff0c;如果只是想应对机试并且应试语言以C和C为主&#xff0c;那么大家对python了解一点就好&#xff0c;重点可以看高分篇…

【机器学习】——numpy教程

文章目录 1.numpy简介2.初始化numpy3.ndarry的使用3.1numpy的属性3.2numpy的形状3.3ndarray的类型 4numpy生成数组的方法4.1生成0和1数组4.2从现有的数组生成4.3生成固定范围的数组4.4生成随机数组 5.数组的索引、切片6.数组的形状修改7.数组的类型修改8.数组的去重9.ndarray的…