基于BILSTM及其他RNN序列模型的人名分类器

数据集Kaggle链接

NameNationalLanguage | Kaggle

数据集分布:

第一列为人名,第二列为国家标签

代码开源地址

Kaggle代码链接

https://www.kaggle.com/code/houjijin/name-nationality-classification

Gitee码云链接

人名国籍分类 Name Nation classification: using BILSTM to predict individual's nationality by their name

github链接

GitHub - Foxbabe1q/Name-Nation-classification: Use BILSTM to do the classification of individuals by their names

RNN序列模型类编写

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F



device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.RNN(input_size, hidden_size, num_layers = num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden):
        output, hidden = self.rnn(x, hidden)
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden

class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden, c):
        output, (hidden, c) = self.rnn(x, (hidden, c))
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden, c

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden, c0


class SimpleBILSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleBILSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size*2, self.output_size)

    def forward(self, x, hidden, c):
        output, (hidden, c) = self.rnn(x, (hidden, c))
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden, c

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers*2, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=device)
        return hidden, c0



class SimpleGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(SimpleGRU, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        self.output_size = 18
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x, hidden):
        output, hidden = self.rnn(x, hidden)
        output = output[:, -1, :]
        output = self.fc(output)
        return output, hidden

    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return hidden

注意这里BILSTM类中,由于双向lstm会使用两个lstm模型分别处理前向序列和反向序列,所以在初始化隐藏层和记忆细胞层的时候要设置num_layers为2.

导包

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from RNN_Series1 import SimpleRNN, SimpleLSTM, SimpleGRU, SimpleBILSTM
from torch.utils.data import Dataset, DataLoader
import string
from sklearn.preprocessing import LabelEncoder
import time

字符序列及device定义

letters = string.ascii_letters + " .,;'"
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

数据读取及标签列编码

def load_data():
    data = pd.read_csv('name_classfication.txt', sep='\t', names = ['name', 'country'])
    X = data[['name']]
    lb = LabelEncoder()
    y = data['country']
    y = lb.fit_transform(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
    return X_train, X_test, y_train, y_test

数据集定义

class create_dataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.length = len(self.X)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data = torch.zeros(10, len(letters), dtype = torch.float, device=device)
        for i, letter in enumerate(self.X.iloc[idx,0]):
            if i==10:
                break
            data[i,letters.index(letter)] = 1
        label = torch.tensor(self.y[idx], dtype = torch.long, device=device)
        return data, label

这里使用字符序列进行独热编码,并且由于名字长度不一,所以经过序列长度分布,选取了10作为截断长度.

使用RNN训练

def train_rnn():
    X_train, X_test, y_train, y_test = load_data()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    loss_list = []
    acc_list = []
    val_acc_list = []
    val_loss_list = []
    epochs = 10
    my_dataset = create_dataset(X_train, y_train)
    val_dataset = create_dataset(X_test, y_test)
    my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)
    my_rnn = SimpleRNN(len(letters), 128,2)
    my_rnn.to(device)
    optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)
    start_time = time.time()

    for epoch in range(epochs):
        my_rnn.train()
        total_loss = 0
        total_acc = 0
        total_sample = 0
        for i, (X,y) in enumerate(my_dataloader):
            output, hidden = my_rnn(X, my_rnn.init_hidden(batch_size=len(y)))
            total_sample += len(y)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            prediction = output.argmax(dim=1)
            acc_num = torch.sum(prediction == y).item()
            total_acc += acc_num
        loss_list.append(total_loss/total_sample)
        acc_list.append(total_acc/total_sample)

        my_rnn.eval()
        with torch.no_grad():
            for i, (X_val, y_val) in enumerate(val_dataloader):
                output, hidden = my_rnn(X_val, my_rnn.init_hidden(batch_size=len(y_test)))
                loss = criterion(output, y_val)
                prediction = output.argmax(dim=1)
                acc_num = torch.sum(prediction == y_val).item()
                val_acc_list.append(acc_num/len(y_val))
                val_loss_list.append(loss.item()/len(y_val))
                print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')
    torch.save(my_rnn.state_dict(), 'rnn.pt')
    plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')
    plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xticks(np.arange(1,11))
    plt.title('Loss')
    plt.legend()
    plt.savefig('logg.png')
    plt.show()
    plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')
    plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(1,11))
    plt.title('Accuracy')
    plt.legend()
    plt.savefig('accuracy.png')
    plt.show()

使用BILSTM训练

def train_bilstm():
    X_train, X_test, y_train, y_test = load_data()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    loss_list = []
    acc_list = []
    val_acc_list = []
    val_loss_list = []
    epochs = 10
    my_dataset = create_dataset(X_train, y_train)
    val_dataset = create_dataset(X_test, y_test)
    my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=len(y_test), shuffle=True)
    my_rnn = SimpleBILSTM(len(letters), 128,2)
    my_rnn.to(device)
    optimizer = torch.optim.Adam(my_rnn.parameters(), lr=0.001)
    start_time = time.time()

    for epoch in range(epochs):
        my_rnn.train()
        total_loss = 0
        total_acc = 0
        total_sample = 0
        for i, (X,y) in enumerate(my_dataloader):
            hidden,c0 = my_rnn.init_hidden(batch_size=len(y))
            output, hidden,c = my_rnn(X, hidden,c0)
            total_sample += len(y)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            prediction = output.argmax(dim=1)
            acc_num = torch.sum(prediction == y).item()
            total_acc += acc_num
        loss_list.append(total_loss/total_sample)
        acc_list.append(total_acc/total_sample)


        my_rnn.eval()
        with torch.no_grad():
            for i, (X_val, y_val) in enumerate(val_dataloader):
                hidden, c0 = my_rnn.init_hidden(batch_size=len(y_val))
                output, hidden ,c= my_rnn(X_val, hidden,c0)
                loss = criterion(output, y_val)
                prediction = output.argmax(dim=1)
                acc_num = torch.sum(prediction == y_val).item()
                val_acc_list.append(acc_num/len(y_val))
                val_loss_list.append(loss.item()/len(y_val))
                print(f'epoch: {epoch+1}, train_loss: {total_loss/total_sample:.2f}, train_acc: {total_acc/total_sample:.2f}, val_loss: {loss.item()/len(y_val):.2f}, val_acc: {acc_num/len(y_val):.2f}, time: {time.time() - start_time : .2f}')

    torch.save(my_rnn.state_dict(), 'bilstm.pt')
    plt.plot(np.arange(1,11),loss_list,label = 'Training Loss')
    plt.plot(np.arange(1,11),val_loss_list,label = 'Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xticks(np.arange(1,11))
    plt.title('Loss')
    plt.legend()
    plt.savefig('loss.png')
    plt.show()
    plt.plot(np.arange(1,11),acc_list,label = 'Training Accuracy')
    plt.plot(np.arange(1,11),val_acc_list,label = 'Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(1,11))
    plt.title('Accuracy')
    plt.legend()
    plt.savefig('accuracy.png')
    plt.show()

这里超参数设置为epochs:10,lr:1e-3,Adam优化器

epoch: 1, train_loss: 1.70, train_acc: 0.51, val_loss: 1.50, val_acc: 0.56, time:  11.83
epoch: 2, train_loss: 1.36, train_acc: 0.60, val_loss: 1.25, val_acc: 0.64, time:  22.84
epoch: 3, train_loss: 1.19, train_acc: 0.65, val_loss: 1.10, val_acc: 0.69, time:  33.76
epoch: 4, train_loss: 1.05, train_acc: 0.69, val_loss: 0.97, val_acc: 0.72, time:  44.63
epoch: 5, train_loss: 0.93, train_acc: 0.73, val_loss: 0.91, val_acc: 0.74, time:  55.49
epoch: 6, train_loss: 0.85, train_acc: 0.75, val_loss: 0.85, val_acc: 0.75, time:  66.38
epoch: 7, train_loss: 0.78, train_acc: 0.77, val_loss: 0.78, val_acc: 0.77, time:  77.38
epoch: 8, train_loss: 0.73, train_acc: 0.78, val_loss: 0.75, val_acc: 0.77, time:  88.27
epoch: 9, train_loss: 0.68, train_acc: 0.79, val_loss: 0.71, val_acc: 0.78, time:  99.44
epoch: 10, train_loss: 0.64, train_acc: 0.80, val_loss: 0.72, val_acc: 0.78, time:  110.43

完整代码的开源链接可以查询kaggle,gitee,github链接,其中gitee和github仓库中有训练好的模型权重,有需要可以在模型实例化后直接使用.

如需使用其他rnn序列模型如lstm和gru也可以直接实例化这里对应的模型类进行训练即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/914636.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

hive中windows子句的使用

概述 1,windows子句是对窗口的结果做更细粒度的划分 2、windows子句中有两种方式 rows :按照相邻的几行进行开窗 range:按照某个值的范围进行开窗 使用方式 (rows | range) between (UNBOUNDED | [num]) PRECEDING AND ([num] PRECEDING…

云渲染:服务器机房与物理机房两者有什么区别

云渲染选择服务器机房与物理机房两者主要区别在哪里呢? 服务器机房和物理机房作为云渲染的基础设施,各自扮演着不同的角色。 服务器机房的特点 服务器机房,通常指的是那些专门用于托管服务器的设施,它们可能位于云端&#xff0c…

零基础Java第十四期:继承与多态(二)

目录 一、继承 1.1. 继承的方式 1.2. final关键字 1.3. 继承与组合 1.4. protected关键字 二、多态 2.1. 多态的概念 2.2. 向上转型 2.3. 重写 2.4. 向下转型 2.5. 多态的优缺点 一、继承 1.1. 继承的方式 猫类可以继承动物类,中华田园猫类可以继承猫类…

电销系统:业绩翻倍的秘密武器

在当今竞争激烈的商业环境中,企业都在寻求各种方法来提升业绩。而电销系统正以其强大的功能和优势,成为众多企业实现业绩翻倍的有力工具。 一、高效的客户管理 电销系统能够对客户信息进行全面、系统的管理。从客户的基本资料、联系方式到历史沟通记录、…

CTFhub靶场RCE学习

靶场 eval执行 <?php if (isset($_REQUEST[cmd])) {eval($_REQUEST["cmd"]); } else {highlight_file(__FILE__); } ?> PHP代码显示&#xff0c;要求将命令赋值给cmd然后执行 先查看一下根目录文件 ?cmdsystem("ls");&#xff01;切记最后的分…

软件架构技术深入解析:AOP、系统安全架构、企业集成平台与微服务架构

目录 试题一 论面向方面的编程技术及其应用 解析 试题二 论系统安全架构设计及其应用 解析 试题三 论企业集成平台的理解与应用 解析 &#xff08;1&#xff09;通信服务 &#xff08;2&#xff09;信息集成服务 &#xff08;3&#xff09;应用集成服务 &#xff08;…

计算机网络基础:从IP地址到分层模型

计算机网络 1.计算机网络概述 概述 ​ 计算机网络是指两台或更多的计算机组成的网络&#xff0c;在同一个网络中&#xff0c;任意两台计算机都可以直接通信。互联网是网络的网络&#xff08;Internet&#xff09;&#xff0c;即把很多计算机网络连接起来&#xff0c;形成一个…

SpringCloud篇(服务提供者/消费者)(持续更新迭代)

在服务调用关系中&#xff0c;会有两个不同的角色&#xff1a; 服务提供者&#xff1a;一次业务中&#xff0c;被其它微服务调用的服务。&#xff08;提供接口给其它微服务&#xff09; 服务消费者&#xff1a;一次业务中&#xff0c;调用其它微服务的服务。&#xff08;调用…

【UML】类图及其六种关系,超详细介绍,细节满满

目录 一、概念 推荐一个画UML相关图的软件&#xff1a;ProcessOn 二、快速介绍类图 1、普通类的类图 2、抽象类的类图 3、接口的类图 三、类与类之间的关系 &#xff08;一&#xff09;、关联关系 1、单向关联 2、双向关联 3、自关联 &#xff08;二&#xff09;、聚…

HTML文件中引入jQuery的库文件

方法一&#xff1a; 1. 首先&#xff0c;在官方网站(https://jquery.com/)上下载最新版本的jQuery库文件&#xff0c;通常是一个名为jquery-x.x.x.min.js的文件。 2. 将下载的jquery-x.x.x.min.js文件保存到你的项目目录中的一个合适的文件夹中&#xff0c;比如将它保存在你的项…

使用Wireshark获取USB HID(Human Interface Device)报告描述符

使用Wireshark选择需要获取的USB进行抓取数据&#xff0c;找到设备&#xff08;host&#xff09;接收信息的数据 第二栏出现hid报告&#xff0c;右击选择复制流 将复制的内容粘贴到USB标准请求及描述符在线分析工具 - USB中文网 进行解析 以图中获取手写板的数据为例&#xff…

ReactPress与WordPress:一场内容管理系统的较量

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress WordPress官网&#xff1a;https://wordpress.org/ ReactPress与WordPress&#xff1a;一场内容管理系统的较量 在当今数字化时代&#xff0c;内容管理系统&#xff08;CMS&#xff09;已成为…

DevExpress WinForms中文教程:Data Grid - 如何绑定到实体框架数据源?

在本教程中&#xff0c;您将学习如何将DevExpress WinForms的网格控件绑定到实体框架数据源、如何使用数据注释属性来更改网格显示和管理数据的方式&#xff0c;以及如何将单元格值更改发送回数据源。 P.S&#xff1a;DevExpress WinForms拥有180组件和UI库&#xff0c;能为Wi…

使用多种机器学习调参模型进行二分类建模的全流程,代做分析辅导

使用多种机器学习调参模型进行二分类建模的全流程教程 机器学习全流程分析各个模块用到的总的参数文件 0. 分析参数文件 参数文件名称&#xff1a;total_analysis_params_demo.xlsx &#xff0c;很多分析模块都是这个总的参数文件&#xff0c;我的这个总的参数文件如果有更新…

材质(一)

描述&#xff1a; 材质蓝图&#xff0c;蓝图可以这么定义&#xff0c;是一种数据结构&#xff0c;是一种带有流水线的模糊的数据结构&#xff0c; 材质蓝图也是一种蓝图。 示例操作:

SCI论文数据可视化的在线网址

目录 SCI论文数据可视化的在线网址 EVenn(Evenn):免费 SCI论文数据可视化的在线网址 数据可视化的在线网址,以下是一些值得推荐的资源: ImageGP(ImageGP | ImageGP):该平台可以在线生成常见的线图、柱状图、散点图、箱线图、集合图、热图和直方图等。用户只需粘贴数…

电子应用产品设计方案-4:基于物联网和人工智能的温度控制器设计方案

一、概述 本温度控制器旨在提供高精度、智能化、远程可控的温度调节解决方案&#xff0c;适用于各种工业和民用场景。 二、系统组成 1. 传感器模块 - 采用高精度的数字式温度传感器&#xff0c;如 TMP117&#xff0c;能够提供精确到 0.01C 的温度测量。 - 配置多个传感器分布在…

5G的发展演进

5G发展的驱动力 什么是5G [远程会议&#xff0c;2020年7月10日] 在来自世界各地的政府主管部门、电信制造及运营企业、研究机构约200多名会议代表和专家们的共同见证下&#xff0c;ITU-R WP 5D#35e远程会议宣布3GPP 5G技术&#xff08;含NB-IoT&#xff09;满足IMT-2020 5G技…

人工智能--自然语言处理简介

上一篇&#xff1a;《人工智能模型训练中的数据之美——探索TFRecord》 序言&#xff1a;自然语言处理&#xff08;NLP&#xff09;是人工智能中的一种技术&#xff0c;专注于理解基于人类语言的内容。它包含了编程技术&#xff0c;用于创建可以理解语言、分类内容&#xff0c…

第8章 利用CSS制作导航菜单

8.1 水平顶部导航栏 水平莱单导航栏是网站设计中应用范围最广的导航设计&#xff0c;一般放置在页面的顶部。水平 导航适用性强&#xff0c;几乎所有类型的网站都可以使用&#xff0c;设计难度较低。 如果导航过于普通&#xff0c;无法容纳复杂的信息结构&#xff0c;就需要在…