基于深度学习的婴儿啼哭识别项目详解

基于深度学习的婴儿啼哭识别项目详解

  • 基于深度学习的婴儿啼哭识别项目详解
    • 一、项目背景
      • 1.1 项目背景
      • 1.2 数据说明
    • 二、PaddleSpeech环境准备
    • 三、数据预处理
      • 3.1 数据解压缩
      • 3.2 查看声音文件
      • 3.3 音频文件长度处理
    • 四、自定义数据集与模型训练
      • 4.1 自定义数据集
      • 4.2 模型训练
      • 4.3 模型训练
    • 五、模型测试
    • 六、注意事项

基于深度学习的婴儿啼哭识别项目详解

一、项目背景

婴儿啼哭声是婴儿沟通需求的重要信号,对于父母和护理者而言至关重要。本项目基于PaddleSpeech框架,致力于构建婴儿啼哭识别系统,通过深度学习将啼哭声翻译成成人语言,帮助理解婴儿的需求和状态。
在这里插入图片描述

1.1 项目背景

婴儿啼哭声是一种生物报警器,传递婴儿的生理和心理需求。有效地识别啼哭声有助于提高婴儿护理的效率和质量。

1.2 数据说明

项目使用六类人工添加噪声的哭声作为训练数据集,分别代表不同的婴儿需求,如苏醒、换尿布、要抱抱、饥饿、困乏、不舒服。噪声数据来自Noisex-92标准数据库。

二、PaddleSpeech环境准备

安装PaddleSpeech和PaddleAudio,确保环境准备就绪。

!python -m pip install -q -U pip --user
!pip install paddlespeech paddleaudio -U -q

三、数据预处理

3.1 数据解压缩

解压缩训练数据集,获取音频文件。

!unzip -qoa data/data41960/dddd.zip

3.2 查看声音文件

通过可视化展示音频波形,了解样本数据的特征。

from paddleaudio import load
data, sr = load(file='train/awake/awake_0.wav', mono=True, dtype='float32')  
print('wav shape: {}'.format(data.shape))
print('sample rate: {}'.format(sr))
plt.figure()
plt.plot(data)
plt.show()

3.3 音频文件长度处理

统一音频文件长度,确保训练数据格式一致。

# 音频信息查看
import soundfile as sf
import numpy as np
import librosa

data, samplerate = sf.read('hungry_0.wav')
channels = len(data.shape)
length_s = len(data) / float(samplerate)
format_rate = 16000
print(f"channels: {channels}")
print(f"length_s: {length_s}")
print(f"samplerate: {samplerate}")

四、自定义数据集与模型训练

4.1 自定义数据集

创建自定义数据集类,包含六类婴儿需求的音频文件。

class CustomDataset(AudioClassificationDataset):
    # List all the class labels
    label_list = [
        'awake',
        'diaper',
        'hug',
        'hungry',
        'sleepy',
        'uncomfortable'
    ]

    train_data_dir = './train/'

    def __init__(self, **kwargs):
        files, labels = self._get_data()
        super(CustomDataset, self).__init__(
            files=files, labels=labels, feat_type='raw', **kwargs)

    # 返回音频文件、label值
    def _get_data(self):
        '''
        This method offer information of wave files and labels.
        '''
        files = []
        labels = []

        for i in range(len(self.label_list)):
            single_class_path = os.path.join(self.train_data_dir, self.label_list[i])
            for sound in os.listdir(single_class_path):
                if 'wav' in sound:
                    sound = os.path.join(single_class_path, sound)
                    files.append(sound)
                    labels.append(i)
        return files, labels

4.2 模型训练

选取预训练模型作为特征提取器,构建分类模型进行模型训练。

# 选取cnn14作为 backbone,用于提取音频的特征
from paddlespeech.cls.models import cnn14
backbone = cnn14(pretrained=True, extract_embedding=True)

# 构建分类模型
class SoundClassifier(nn.Layer):
    def __init__(self, backbone, num_class, dropout=0.1):
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.backbone.emb_size, num_class)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.backbone(x)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits

model = SoundClassifier(backbone, num_class=len(train_ds.label_list))

4.3 模型训练

定义优化器和损失函数,进行模型训练。

# 定义优化器和 Loss
optimizer = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters())
criterion = paddle.nn.loss.CrossEntropyLoss()

# 模型训练
epochs = 20
steps_per_epoch = len(train_loader)
log_freq = 10
eval_freq = 10

for epoch in range(1, epochs + 1):
    model.train()

    avg_loss = 0
    num_corrects = 0
    num_samples = 0
    
    for batch_idx, batch in enumerate(train_loader):
        waveforms, labels = batch
        feats = feature_extractor(waveforms)
        feats = paddle.transpose(feats, [0, 2, 1])  
        logits = model(feats)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        if isinstance(optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
            optimizer._learning_rate.step()
        optimizer.clear_grad()

        # 计算损失
        avg_loss += loss.numpy()[0]

        # 计算指标
        preds = paddle.argmax(logits, axis=1)
        num_corrects += (preds == labels).numpy().sum()
        num_samples += feats.shape[0]

        if (batch_idx + 1) % log_freq == 0:
            lr = optimizer.get_lr()
            avg_loss /= log_freq
            avg_acc = num_corrects / num_samples



            print_msg = 'Epoch={}/{}, Step={}/{}'.format(
                epoch, epochs, batch_idx + 1, steps_per_epoch)
            print_msg += ' loss={:.4f}'.format(avg_loss)
            print_msg += ' acc={:.4f}'.format(avg_acc)
            print_msg += ' lr={:.6f}'.format(lr)
            logger.train(print_msg)

            avg_loss = 0
            num_corrects = 0
            num_samples = 0

五、模型测试

通过模型对测试音频进行推理,输出对应的婴儿需求概率。

# 模型测试
top_k = 3
wav_file = 'test/test_0.wav'
n_fft = 1024
win_length = 1024
hop_length = 320
f_min = 50.0
f_max = 16000.0

waveform, sr = load(wav_file, sr=sr)
feature_extractor = LogMelSpectrogram(
    sr=sr, 
    n_fft=n_fft, 
    hop_length=hop_length, 
    win_length=win_length, 
    window='hann', 
    f_min=f_min, 
    f_max=f_max, 
    n_mels=64)
feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))
feats = paddle.transpose(feats, [0, 2, 1])

logits = model(feats)
probs = nn.functional.softmax(logits, axis=1).numpy()

sorted_indices = probs[0].argsort()

msg = f'[{wav_file}]\n'
for idx in sorted_indices[-1:-top_k-1:-1]:
    msg += f'{train_ds.label_list[idx]}: {probs[0][idx]:.5f}\n'
print(msg)

六、注意事项

  1. 自定义数据集格式参考文档;
  2. 统一音频尺寸,确保音频长度和采样频率一致;
  3. 可学习PaddleSpeech课程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/317923.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【JaveWeb教程】(21) MySQL数据库开发之多表设计:一对多、一对一、多对多的表关系 详细代码示例讲解

目录 2. 多表设计2.1 一对多2.1.1 表设计2.1.2 外键约束 2.2 一对一2.3 多对多2.4 案例 2. 多表设计 关于单表的操作(单表的设计、单表的增删改查)我们就已经学习完了。接下来我们就要来学习多表的操作,首先来学习多表的设计。 项目开发中,在进行数据库…

Python基本语法与变量的相关介绍

python基本语法与变量 python语句的缩进 Python代码块使用缩进对齐表示代码逻辑,Python每段代码块缩进的空白数量可以任意,但要确保同段代码块语句必须包含相同的缩进空白数量。建议在代码块的每个缩进层次使用单个制表符或两个空格或四个空格 , 切记不…

Linux系统中的IP地址、主机名、和域名解析

1.IP地址 每一台联网的电脑都会有一个地址,用于和其它计算机进行通讯 IP地址主要有2个版本,V4版本和V6版本(V6很少用,暂不涉及) IPv4版本的地址格式是:a.b.c.d,其中abcd表示0~255的数字&…

pyenv虚拟环境安装和配合pipenv多版本创建

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、下载配置pyenv二、配置多版本虚拟环境总结 前言 最近公司编写了一个自动化用例编写软件,需要适配win7和win10系统,需要同时编译3.8…

leetcode 2114. 句子中的最多单词数

题目: 一个 句子 由一些 单词 以及它们之间的单个空格组成,句子的开头和结尾不会有多余空格。 给你一个字符串数组 sentences ,其中 sentences[i] 表示单个 句子 。 请你返回单个句子里 单词的最多数目 。 解题方法: 1.遍历列表…

JVM,JRE,JDK的区别和联系简洁版

先看图 利用JDK(调用JAVA API)开发JAVA程序后,通过JDK中的编译程序(javac)将我们的文本java文件编译成JAVA字节码,在JRE上运行这些JAVA字节码,JVM解析这些字节码,映射到CPU指令集或…

RAG代码实操之斗气强者萧炎

📑前言 本文主要是【RAG】——RAG代码实操的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄每日一句&#x…

【Linux运维】LVM和RAID学习及实践

LVM和RAID学习及实践 背景LVM简介新加硬盘的操作RAID-磁盘阵列应用场景RAID0RAID1其他结构RAID制作RAID 小结 背景 某台服务器的磁盘管理需要自己动手处理,找了一些资料也踩了一些坑,在这里记录一下,先介绍一下LVM和RAID这两个东西。在计算机…

【爬虫实战】-爬取微博之夜盛典评论,爬取了1.7w条数据

前言: TaoTao之前在前几期推文中发布了一个篇weibo评论的爬虫。主要就是采集评论区的数据,包括评论、评论者ip、评论id、评论者等一些信息。然后有很多的小伙伴对这个代码很感兴趣。TaoTao也都给代码开源了。由于比较匆忙,所以没来得及去讲这…

Open3D 从体素网格构建八叉树(14)

Open3D 从体素网格构建八叉树(14) 一、算法简介二、算法实现1.代码2.效果一、算法简介 上一章介绍从点云构建八叉树,对点云所在体素进行了可视化显示,这里可以对体素构建八叉树,可视化显示八叉树的具体划分结构。 二、算法实现 1.代码 代码如下(示例): import op…

【python】搭配Miniconda使用VSCode

现在的spyder总是运行出错,启动不了,尝试使用VSCode。 一、在VSCode中使用Miniconda管理的Python环境,可以按照以下步骤进行: a. 确保Miniconda环境已经安装并且正确配置。 b. 打开VSCode,安装Python扩展。 打开VS…

用通俗易懂的方式讲解:Stable Diffusion WebUI 从零基础到入门

本文主要介绍 Stable Diffusion WebUI 的实际操作方法,涵盖prompt推导、lora模型、vae模型和controlNet应用等内容,并给出了可操作的文生图、图生图实战示例。适合对Stable Diffusion感兴趣,但又对Stable Diffusion WebUI使用感到困惑的同学。…

GBASE南大通用提问:如果程序检索到 NULL 值,该怎么办?

可在数据库中存储 NULL 值,但编程语言支持的数据类型不识别 NULL 状态。程序必须 采用某种方式来识别 NULL 项,以免将它作为数据来处理。 在 SQL API 中,指示符变量满足此需要。指示符变量是与可能收到 NULL 项的主变量相 关联的一个附加的变…

深度学习笔记(五)——网络优化(1):学习率自调整、激活函数、损失函数、正则化

文中程序以Tensorflow-2.6.0为例 部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。 截图和程序部分引用自北京大学机器学习公开课 通过学习已经掌握了主要的基础函数之后具备了搭建一个网络并使其正常运行的能力,那下一步我们还…

Linux环境之Ubuntu安装Docker流程

今天分享Linux环境之Ubuntu安装docker流程,Docker 是目前非常流行的容器,对其基本掌握很有必要。下面我们通过阿里云镜像的方式安装: 本来今天准备用清华大学镜像安装呢,好像有点问题,于是改成阿里云安装了。清华安装…

《矩阵分析》笔记

来源:【《矩阵分析》期末速成 主讲人:苑长(5小时冲上90)】https://www.bilibili.com/video/BV1A24y1p76q?vd_sourcec4e1c57e5b6ca4824f87e74170ffa64d 这学期考矩阵论,使用教材是《矩阵论简明教程》,因为没…

Linux———ps命令详解

目录 ps 命令("process status" 的缩写。) 常用选项和参数: a:显示所有用户的进程,包括其他用户的进程。​ u:显示详细的进程信息,包括进程的所有者、CPU 使用率、内存使用量等。…

【LabVIEW FPGA入门】模拟输入和模拟输出

1.简单模拟输入和输出测试 1.打开项目,在FPGA终端下面新建一个VI 2.本示例以模拟输入卡和模拟输出卡同时举例。 3.新建一个VI编写程序,同时将卡1的输出连接到卡2的输入使用物理连线。 4.编译并运行程序,观察是否能从通道中采集和输出信号。 5…

【天龙八部】攻略day6

关键字: 灵武、寻宝要求、雁门 1】灵武选择 西凉枫林,锦带,短匕 白溪湖,明镜,双刺 竹海,玉钩,锁甲 2】楼兰寻宝需求 等级80级,40级前6本心法 3】雁门奖励 简单35*4元佑碎金 普…

PyCharm连接服务器 - 1

文章目录 利用PyCharm实现远程开发使用认证代理连接服务器 利用PyCharm实现远程开发 【注】该连接服务器的方法适用于代码在服务器,我们是通过 GateWay 打开远程服务器的代码进行操作。 该功能只有在PyCharm专业版下才可以使用,并且必须是官方的正版许…