NLP|LSTM+Attention文本分类

目录

一、Attention原理简介

二、LSTM+Attention文本分类实战

1、数据读取及预处理

2、文本序列编码

3、LSTM文本分类

三、划重点

少走10年弯路


        LSTM是一种特殊的循环神经网络(RNN),用于处理序列数据和时间序列数据的建模和预测。而在NLP和时间序列领域上Attention-注意力机制也早已有了大量应用,本文将介绍在LSTM基础上如何添加Attention来优化模型效果。

一、Attention原理简介

        注意力机制通过聚焦于重要的信息,忽略不重要的信息,从而有效地处理输入信息。在神经网络中,注意力机制可以帮助模型更好地关注输入中的重要特征,从而提高模型的性能。

        简单而言,在文本处理任务中,self-attention对每一个词会随机初始化q、k、v三个向量,用每个词的q向量和其他k向量做点积、再归一化得到这个词的权重向量w,用w给v向量加权求和得到z向量(该词attention之后的向量)。再延伸一点,其实可以初始化多组q、k、v矩阵,从而得到多组z矩阵拼接起来(类似于CNN中的多个卷积核、来提取不同信息),再乘上一个矩阵压缩回原来的维度,得到最终的embedding。

        细节原理相对繁琐,推荐大家可以去看一下这篇博客的bert介绍,其中self-attention部分详细且清晰。

https://blog.csdn.net/jiaowoshouzi/article/details/89073944

二、LSTM+Attention文本分类实战

1、数据读取及预处理

import re
import os
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve,roc_auc_score
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
import lightgbm as lgb
import matplotlib.pyplot as plt
import gc

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers

# 2、数据读取+预处理
data=pd.read_excel('Inshorts Cleaned Data.xlsx')
 
def data_preprocess(data):
    df=data.drop(['Publish Date','Time ','Headline'],axis=1).copy()
    df.rename(columns={'Source ':'Source'},inplace=True)
    df=df[df.Source.isin(['YouTube','India Today'])].reset_index(drop=True)
    df['y']=np.where(df.Source=='YouTube',1,0)
    df=df.drop(['Source'],axis=1)
    return df
 
df=data.pipe(data_preprocess)
print(df.shape)
df.head()

# 导入英文停用词
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
stop_english=stopwords.words('english')  
stop_spanish=stopwords.words('spanish') 
stop_english

# 4、文本预处理:处理简写、小写化、去除停用词、词性还原
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
import nltk
 
def replace_abbreviation(text):
    
    rep_list=[
        ("it's", "it is"),
        ("i'm", "i am"),
        ("he's", "he is"),
        ("she's", "she is"),
        ("we're", "we are"),
        ("they're", "they are"),
        ("you're", "you are"),
        ("that's", "that is"),
        ("this's", "this is"),
        ("can't", "can not"),
        ("don't", "do not"),
        ("doesn't", "does not"),
        ("we've", "we have"),
        ("i've", " i have"),
        ("isn't", "is not"),
        ("won't", "will not"),
        ("hasn't", "has not"),
        ("wasn't", "was not"),
        ("weren't", "were not"),
        ("let's", "let us"),
        ("didn't", "did not"),
        ("hadn't", "had not"),
        ("waht's", "what is"),
        ("couldn't", "could not"),
        ("you'll", "you will"),
        ("i'll", "i will"),
        ("you've", "you have")
    ]
    result = text.lower()
    for word_replace in rep_list:
        result=result.replace(word_replace[0],word_replace[1])
#     result = result.replace("'s", "")
    
    return result
 
def drop_char(text):
    result=text.lower()
    result=re.sub('[^\w\s]',' ',result) # 去掉标点符号、特殊字符
    result=re.sub('\s+',' ',result) # 多空格处理为单空格
    return result
 
def stemed_words(text,stop_words,lemma):
    
    word_list = [lemma.lemmatize(word, pos='v') for word in text.split() if word not in stop_words]
    result=" ".join(word_list)
    return result
 
def text_preprocess(text_seq):
    stop_words = stopwords.words("english")
    lemma = WordNetLemmatizer()
    
    result=[]
    for text in text_seq:
        if pd.isnull(text):
            result.append(None)
            continue
        text=replace_abbreviation(text)
        text=drop_char(text)
        text=stemed_words(text,stop_words,lemma)
        result.append(text)
    return result
 
df['short']=text_preprocess(df.Short)
df[['Short','short']]

# 5、划分训练、测试集
test_index=list(df.sample(2000).index)
df['label']=np.where(df.index.isin(test_index),'test','train')
df['label'].value_counts()

2、文本序列编码

        按照词频排序,创建长度为6000的高频词词典、来对文本进行序列化编码。

from tensorflow.keras.preprocessing.text import Tokenizer
def word_dict_fit(train_text_list,num_words):
    '''
        train_text_list: ['some thing today ','some thing today2']
    '''
    tok_params={
        'num_words':num_words,  # 词典的长度,仅保留词频top的num_words个词
        'filters':'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
        'lower':True, 
        'split':' ', 
        'char_level':False, 
        'oov_token':None, # 设定词典外的词编码
    }
    tok = Tokenizer(**tok_params) # 分词
    tok.fit_on_texts(train_text_list)
    
    return tok

def word_dict_apply_sequences(tok_model,text_list,len_vec):
    '''
        text_list: ['some thing today ','some thing today2']
    '''
    list_tok = tok_model.texts_to_sequences(text_list) # 编码映射
    
    pad_params={
        'sequences':list_tok,
        'maxlen':len_vec,  # 补全后向量长度
        'padding':'pre', # 'pre' or 'post',在前、在后补全
        'truncating':'pre', # 'pre' or 'post',在前、在后删除长度多余的部分
        'value':0, # 补全0
    }
    seq_tok = pad_sequences(**pad_params) # 补全编码向量,返回二维array
    return seq_tok

num_words,len_vec = 6000,40
tok_model= word_dict_fit(df[df.label=='train'].short,num_words)
tok_train = word_dict_apply_sequences(tok_model,df[df.label=='train'].short,len_vec)
tok_test = word_dict_apply_sequences(tok_model,df[df.label=='test'].short,len_vec)
tok_test

图片

3、LSTM文本分类

        LSTM层的输入是三维张量(batch_size, timesteps, input_dim),所以使用的数据可以是时间序列、也可以是文本数据的embedding;输出设置return_sequences为False,返回尺寸为 (batch_size, units) 的 2D 张量。

'''
LSTM层核心参数
    units:输出维度
    activation:激活函数
    recurrent_activation: RNN循环激活函数
    use_bias: 布尔值,是否使用偏置项
    dropout:0~1之间的浮点数,神经元失活比例
    recurrent_dropout:0~1之间的浮点数,循环状态的神经元失活比例
    return_sequences: True时返回RNN全部输出序列(3D),False时输出序列的最后一个输出(2D)
'''
def init_lstm_model(max_features, embed_size):
    model = Sequential()
    model.add(Embedding(input_dim=max_features, output_dim=embed_size))
    model.add(Bidirectional(LSTM(units=32,activation='relu', recurrent_dropout=0.1)))
    model.add(Dropout(0.25,seed=1))
    model.add(Dense(64))
    model.add(Dropout(0.3,seed=1))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))

embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))

embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()

 

def ks_auc_value(y_value,y_pred):
    fpr,tpr,thresholds= roc_curve(list(y_value),list(y_pred))
    ks=max(tpr-fpr)
    auc= roc_auc_score(list(y_value),list(y_pred))
    return ks,auc

print('train_ks_auc',ks_auc_value(df[df.label=='train'].y,lstm_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,lstm_model.predict(tok_test)))

'''
    train_ks_auc (0.7223217797649937, 0.922939132379851)
    test_ks_auc (0.7046603930606234, 0.9140880065296716)
'''

4、LSTM+Attention文本分类

        在LSTM层之后添加Attention层优化效果。

from tensorflow.keras.models import Model
def init_lstm_model(max_features, embed_size ,embedding_matrix):
    input_=layers.Input(shape=(40,))
    x=Embedding(input_dim=max_features, output_dim=embed_size,weights=[embedding_matrix],trainable=False)(input_)
    x=Bidirectional(layers.LSTM(units=32,activation='relu', recurrent_dropout=0.1,return_sequences=True))(x)
    x=layers.Attention(40)([x,x])
    x=Dropout(0.25)(x)
    x=layers.Flatten()(x)
    x=Dense(64)(x)
    x=Dropout(0.3)(x)
    x=Dense(1,activation='sigmoid')(x)
    model = Model(inputs=input_, outputs=x)

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

def model_fit(model, x, y,test_x,test_y):
    return model.fit(x, y, batch_size=100, epochs=5, validation_data=(test_x,test_y))

num_words,embed_size = 6000,128
lstm_model2=init_lstm_model(num_words, embed_size ,embedding_matrix)
model_train=model_fit(lstm_model2,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))

print('train_ks_auc',ks_auc_value(df[df.label=='train'].y,gru_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,gru_model.predict(tok_test)))
'''
    train_ks_auc (0.7126925954159541, 0.9199721561742299)
    test_ks_auc (0.7239373279559567, 0.917086274086166)
'''

三、划重点

少走10年弯路

        关注威信公众号 Python风控模型与数据分析,回复 文本分类5 获取本篇数据及代码

        还有更多理论、代码分享等你来拿

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/309084.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

N卡可以点亮但是A卡无法点亮故障解决记录

目录 关键词平台说明一、故障现象二、排查过程三、根本原因四、措施3.1进入boot后更改CSM启动为下图所示。3.2更改PCIE自动为3.0 关键词 英伟达、AMD、显卡、无法点亮 平台说明 项目Value主板铭瑄 TZZ_H61ITX 2.5GCPU12400f显卡RX6600 RTX4060附加设备PCIE 延长线–显卡 一…

Vue3+Echarts:柱状图的图例(legend)不显示

一、问题 在使用Echarts绘制堆积柱状图的时候&#xff0c;想给柱状图添加图例&#xff0c;但是添加完后&#xff0c;图例不显示。 二、问题及解决 原因 这里图例不显示&#xff0c;是因为legend里面图例的文字内容跟series里每一项的name的内容不一致&#xff0c;必须得两者…

STM32入门教程-2023版【3-4】按键控制制LED

关注 点赞 不错过精彩内容 大家好&#xff0c;我是硬核王同学&#xff0c;最近在做免费的嵌入式知识分享&#xff0c;帮助对嵌入式感兴趣的同学学习嵌入式、做项目、找工作! 这篇文章以项目代码的形式实现GPIO输入 一、按键控制LED &#xff08;1&#xff09;搭建面包板电…

【Spring 篇】JdbcTemplate:轻松驾驭数据库的魔法工具

欢迎来到数据库的奇妙世界&#xff0c;在这里&#xff0c;我们将一同揭开Spring框架中JdbcTemplate的神秘面纱。JdbcTemplate是Spring提供的一个简化数据库操作的工具&#xff0c;它为我们提供了一种轻松驾驭数据库的魔法。本篇博客将详细解释JdbcTemplate的基本使用&#xff0…

跑代码相关 初始环境配置

是看了这个视频&#xff1a;深度学习python环境配置_哔哩哔哩_bilibili 总结的个人笔记 这个是从零开始配python环境的比较好的经验教程&#xff1a; 深度学习python的配置&#xff08;Windows&#xff09; - m1racle - 博客园 (cnblogs.com) 然后关于CUDA和cuDNN&#xff…

基于NFC(215芯片)和酷狗音乐实现NFC音乐墙

前言&#xff1a; 本文方案可以实现直接调起酷狗音乐app自动播放&#xff0c;而非跳转网址 准备工作&#xff1a; nfc toolsnfc task酷狗音乐APPalook浏览器APP 步骤 1.选一首歌 2.右上角选择分享&#xff0c;选择复制链接 复制内容为&#xff1a; 分享胡夏的单曲《爱夏…

U-Net——第一课

一.论文研究背景、成果及意义二、unet论文结构三、算法架构 一.论文研究背景、成果及意义 医学图像分割是医学图像处理与分析领域的复杂而关键的步骤&#xff0c;目的是将医学图像中具有某些特殊含义的部分分割出来&#xff0c;并提取相关特征&#xff0c;为临床诊疗和病理学研…

算法通关村番外篇-LeetCode编程从0到1系列一

大家好我是苏麟 , 今天开始带来LeetCode编程从0到1系列 . 编程基础 0 到 1 , 50 题掌握基础编程能力 大纲 1768.交替合并字符串389. 找不同28. 找出字符串中第一个匹配项的下标283. 移动零66. 加一 1768.交替合并字符串 描述 : 给你两个字符串 word1 和 word2 。请你从 word1…

外汇天眼:CQG 与 TradeStation Securities 的经纪服务集成

TradeStation Securities, Inc.&#xff0c;一家自营的在线股票、ETF、期权和期货交易经纪公司&#xff0c;宣布与CQG合作&#xff0c;CQG是一家为交易员、经纪商、商业套保者和交易所提供高性能技术解决方案的全球供应商&#xff0c;已与TradeStation Securities的经纪服务集成…

MySql -数据库进阶

一、约束 1.外键约束 外键约束概念 让表和表之间产生关系&#xff0c;从而保证数据的准确性&#xff01; 建表时添加外键约束 为什么要有外键约束 -- 创建db2数据库 CREATE DATABASE db2; -- 使用db2数据库 USE db2;-- 创建user用户表 CREATE TABLE USER(id INT PRIMARY KEY …

linux内存浅析

内存的基本概念 操作系统内存非常重要且比较复杂&#xff0c;其中有许多知识点仍然需要掌握才能更进一步分析程序问题。由于是初次全面系统地接触OS内存&#xff0c;目的是为了全面且低层次地理解linux内存相关概念&#xff0c;不会深入其中原理&#xff0c;所以本章也会尽量避…

Java实现CR-图片文字识别功能(超简单)

一.什么是OCR OCR &#xff08;Optical Character Recognition&#xff0c;光学字符识别&#xff09;是指电子设备&#xff08;例如扫描仪或数码相机&#xff09;检查纸上打印的字符&#xff0c;通过检测暗、亮的模式确定其形状&#xff0c;然后用字符识别方法将形状翻译成计算…

【信息论与编码】【北京航空航天大学】实验一、哈夫曼编码【C语言实现】(上)

信息论与编码 实验1 哈夫曼编码 实验报告 一、运行源代码所需要的依赖&#xff1a; 1、硬件支持 Windows 10&#xff0c;64位系统 2、编译器 DEV-Redpanda IDE&#xff0c;小熊猫C 二、算法实现及测试 1、C语言源程序 # define _CRT_SECURE_NO_WARNINGS # include <std…

Qt QCheckBox复选按钮控件

文章目录 1 属性和方法1.1 文本1.2 三态1.3 自动排他1.4 信号和槽 2 实例2.1 布局2.2 代码实现 Qt中的复选按钮类是QCheckBox它和单选按钮很相似&#xff0c;单选按钮常用在“多选一”的场景&#xff0c;而复选按钮常用在"多选多"的场景比如喜欢的水果选项中&#xf…

知识】分享几个摄像头的选型相关知识

【知识】分享几个摄像头的选型相关知识 目录 【知识】分享几个摄像头的选型相关知识一、前言二、正文1、先了解一下监控摄像头的种类1.1、云台型&#xff08;云台型一体摄像机&#xff09;1.2、枪机型&#xff08;枪型摄像机&#xff09;1.3、球机型&#xff08;球型摄像机&…

LeetCode-字符串转换整数atoi(8)

题目描述&#xff1a; 请你来实现一个 myAtoi(string s) 函数&#xff0c;使其能将字符串转换成一个 32 位有符号整数&#xff08;类似 C/C 中的 atoi 函数&#xff09;。 函数 myAtoi(string s) 的算法如下&#xff1a; 读入字符串并丢弃无用的前导空格 检查下一个字符&…

【MySQL】表设计与范式设计

文章目录 一、数据库表设计一对一一对多多对多 二、范式设计第一范式第二范式第三范式BC范式第四范式 一、数据库表设计 一对一 举个例子&#xff0c;比如这里有两张表&#xff0c;用户User表 和 身份信息Info表。 因为一个用户只能有一个身份信息&#xff0c;所以User表和In…

【数学建模】美赛备战笔记 01 美赛指南与竞赛全流程

美赛指南 整篇论文需要在25页内。 六道赛题特点&#xff1a; A、B题涉及到微分方程和物理概念较多&#xff0c;需要一定的专业知识&#xff1b; C题常常涉及到时间序列、机器学习&#xff1b; D题一般是运筹学/网络科学&#xff0c;图论、优化问题&#xff0c;涉及到的概念多…

Day3Qt

1. &#xff08;1&#xff09;完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 …

NAS使用的一些常见命令 ssh sftp 上传 下载 ALL in one

目录 登陆上传/下载内网穿透 登陆 ssh 登陆 ssh usernameserverIP -p portNumsftp 登陆 sftp -P portNum usernameserverIP上传/下载 如ls等&#xff0c;远程服务器操作 如lls等&#xff0c;本机操作&#xff0c;前缀为l 文件 put **** 将本机上文件上传到远程服务器上当…