数据挖掘实战-基于Catboost算法的艾滋病数据可视化与建模分析

 

🤵‍♂️ 个人主页:@艾派森的个人主页

✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+


目录

1.项目背景

2.数据集介绍

3.技术工具

4.实验过程

4.1导入数据

4.2数据预处理

4.3数据可视化

4.4特征工程

4.5模型构建

源代码


1.项目背景

        艾滋病(Acquired Immunodeficiency Syndrome,AIDS)是一种由人类免疫缺陷病毒(Human Immunodeficiency Virus,HIV)引起的免疫系统功能受损的严重疾病。艾滋病的流行给全球卫生健康带来了严重挑战,特别是在一些发展中国家和弱势群体中。

        艾滋病的研究和管理需要综合多方面的信息,包括患者的个人特征、病毒的特性、医疗历史等。利用机器学习算法对艾滋病数据进行分析和建模,有助于更好地理解该疾病的传播规律、风险因素以及预测患者的病情发展。Catboost算法作为一种擅长处理类别型特征的梯度提升树算法,在艾滋病数据的分析与建模中具有一定的优势。

        本研究旨在利用Catboost算法对艾滋病数据进行分析与建模,并结合可视化技术,探索艾滋病患者的特征与疾病发展之间的关系。通过这一研究,可以为艾滋病的预防、诊断和治疗提供更加科学有效的支持和指导。

2.数据集介绍

本数据集来源于Kaggle,数据集包含有关被诊断患有艾滋病的患者的医疗保健统计数据和分类信息。该数据集最初于 1996 年发布。

属性信息:

time:失败或审查的时间

trt:治疗指标(0 = 仅 ZDV;1 = ZDV + ddI,2 = ZDV + Zal,3 = 仅 ddI)

age:基线年龄(岁)

wtkg:基线时的体重(公斤)

hemo:血友病(0=否,1=是)

homo:同性恋活动(0=否,1=是)

drugs:静脉注射药物使用史(0=否,1=是)

karnof:卡诺夫斯基分数(范围为 0-100)

oprior:175 年前非 ZDV 抗逆转录病毒治疗(0=否,1=是)

z30:175之前30天的ZDV(0=否,1=是)

preanti:抗逆转录病毒治疗前 175 天

race:种族(0=白人,1=非白人)

gender:性别(0=女,1=男)

str2:抗逆转录病毒史(0=未接触过,1=有经验)

strat:抗逆转录病毒病史分层(1='未接受过抗逆转录病毒治疗',2='> 1 但<= 52周既往抗逆转录病毒治疗',3='> 52周)

symptom:症状指标(0=无症状,1=症状)

treat:治疗指标(0=仅ZDV,1=其他)

offrtrt:96+/-5周之前off-trt的指标(0=否,1=是)

cd40:基线处的 CD4

cd420:20+/-5 周时的 CD4

cd80:基线处的 CD8

cd820:20+/-5 周时的 CD8

infected:感染艾滋病(0=否,1=是)

3.技术工具

Python版本:3.9

代码编辑器:jupyter notebook

4.实验过程

4.1导入数据

首先导入本次实验用到的第三方库并加载数据集

查看数据大小

查看数据基本信息

查看数据描述性统计 

4.2数据预处理

统计数据缺失值情况

可以发现原始数据集并不存在缺失值,故不需要处理

统计重复值情况

可以发现原始数据集并存在重复值,故不需要处理

4.3数据可视化

为了方便后面作图,这里我们自定义一个画图函数

def mPlotter(r, c, size, _targets, text):
    
    bg = '#010108'
    
    palette = ['#df5337', '#d24644', '#f7d340', '#3339FF', '#440a68', '#84206b', '#f1ef75', '#fbbe23', '#400a67']
    
    font = 'ubuntu'
    
    fig = plt.figure(figsize=size)
    
    fig.patch.set_facecolor(bg)
    
    grid = fig.add_gridspec(r, c)
    
    grid.update(wspace=0.5, hspace=0.25)
    
    __empty_diff = ((r * c) - 1) - len(_targets)
        
    axes = []
    
    for i in range(r):
        for j in range(c):
            axes.append(fig.add_subplot(grid[i, j]))
    
    for idx, ax in enumerate(axes):
        ax.set_facecolor(bg) 
        
        if idx == 0:
            ax.spines["bottom"].set_visible(False)
            ax.tick_params(left=False, bottom=False)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.text(0.5, 0.5,
                 f'{text}',
                 horizontalalignment='center',
                 verticalalignment='center',
                 fontsize=18, 
                 fontweight='bold',
                 fontfamily=font,
                 color="#fff")
        else:
            if (idx - 1) < len(_targets):
                ax.set_title(_targets[idx - 1].capitalize(), fontsize=14, fontweight='bold', fontfamily=font, color="#fff")
                ax.grid(color='#fff', linestyle=':', axis='y', zorder=0,  dashes=(1,5))
                ax.set_xlabel("")
                ax.set_ylabel("")
            else:
                ax.spines["bottom"].set_visible(False)
                ax.tick_params(left=False, bottom=False)
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                
        ax.spines["left"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        
    def cb(ax):
        ax.set_xlabel("")
        ax.set_ylabel("")
        
    if __empty_diff > 0:
        axes = axes[:-1*__empty_diff]
        
    return axes, palette, cb

开始作图 

4.4特征工程

拆分数据集为训练集和测试集

平衡数据集

数据标准化处理

4.5模型构建

首先找到catboost的最佳超参数!

使用超参数构建并训练模型,打印模型的准确率和分类报告 

将混淆矩阵可视化

最后再作出ROC曲线

源代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings,random,optuna
import plotly.express as px
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score,auc,roc_curve
from sklearn.preprocessing import MinMaxScaler
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier

plt.style.use('dark_background')
warnings.simplefilter('ignore', category=FutureWarning)
ds = pd.read_csv('AIDS_Classification.csv')
ds.head()
ds.shape
ds.info()
ds.describe(percentiles=[0, .25, .30, .50, .75, .80, 1]).T.style.background_gradient(cmap = 'inferno')
ds.isnull().sum()
ds.duplicated().sum()
def mPlotter(r, c, size, _targets, text):
    
    bg = '#010108'
    
    palette = ['#df5337', '#d24644', '#f7d340', '#3339FF', '#440a68', '#84206b', '#f1ef75', '#fbbe23', '#400a67']
    
    font = 'ubuntu'
    
    fig = plt.figure(figsize=size)
    
    fig.patch.set_facecolor(bg)
    
    grid = fig.add_gridspec(r, c)
    
    grid.update(wspace=0.5, hspace=0.25)
    
    __empty_diff = ((r * c) - 1) - len(_targets)
        
    axes = []
    
    for i in range(r):
        for j in range(c):
            axes.append(fig.add_subplot(grid[i, j]))
    
    for idx, ax in enumerate(axes):
        ax.set_facecolor(bg) 
        
        if idx == 0:
            ax.spines["bottom"].set_visible(False)
            ax.tick_params(left=False, bottom=False)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.text(0.5, 0.5,
                 f'{text}',
                 horizontalalignment='center',
                 verticalalignment='center',
                 fontsize=18, 
                 fontweight='bold',
                 fontfamily=font,
                 color="#fff")
        else:
            if (idx - 1) < len(_targets):
                ax.set_title(_targets[idx - 1].capitalize(), fontsize=14, fontweight='bold', fontfamily=font, color="#fff")
                ax.grid(color='#fff', linestyle=':', axis='y', zorder=0,  dashes=(1,5))
                ax.set_xlabel("")
                ax.set_ylabel("")
            else:
                ax.spines["bottom"].set_visible(False)
                ax.tick_params(left=False, bottom=False)
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                
        ax.spines["left"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        
    def cb(ax):
        ax.set_xlabel("")
        ax.set_ylabel("")
        
    if __empty_diff > 0:
        axes = axes[:-1*__empty_diff]
        
    return axes, palette, cb
target = 'infected'
cont_cols = ['time', 'age', 'wtkg', 'preanti', 'cd40', 'cd420', 'cd80', 'cd820']
dis_cols = list(set(ds.columns) - set([*cont_cols, target]))
len(cont_cols), len(dis_cols)
axes, palette, cb = mPlotter(1, 2, (20, 5), [target], 'Count Of\nInfected Variable\n______________')
sns.countplot(x=ds[target], ax = axes[1], color=palette[0])
cb(axes[1])
axes, palette, cb = mPlotter(3, 3, (20, 20), cont_cols, 'KDE Plot of\nContinuous Variables\n________________')
for col, ax in zip(cont_cols, axes[1:]):
    sns.kdeplot(data=ds, x=col, ax=ax, hue=target, palette=palette[1:3], alpha=.5, linewidth=0, fill=True)
    cb(ax)
axes, palette, cb = mPlotter(3, 3, (20, 20), cont_cols, 'Boxen Plot of\nContinuous Variables\n________________')
for col, ax in zip(cont_cols, axes[1:]):
    sns.boxenplot(data=ds, y=col, ax=ax, palette=[palette[random.randint(0, len(palette)-1)]])
    cb(ax)
axes, palette, cb = mPlotter(5, 3, (20, 20), dis_cols, 'Countplot of\nDiscrete Variables\n________________')
for col, ax in zip(dis_cols, axes[1:]):
    sns.countplot(x=ds[col], ax = ax, hue=ds[target], palette=palette[6:8])
    cb(ax)
ax = px.scatter_3d(ds, x="age", y="wtkg", z="time", template= "plotly_dark", color="infected")
ax.show()
ax = px.scatter_3d(ds, x="preanti", y="cd40", z="cd420", template= "plotly_dark", color="infected")
ax.show()
ax = px.scatter_3d(ds, x="preanti", y="cd80", z="cd820", template= "plotly_dark", color="infected")
ax.show()
fig = plt.figure(figsize=(25, 8))
gs = fig.add_gridspec(1, 1)
gs.update(wspace=0.3, hspace=0.15)
ax = fig.add_subplot(gs[0, 0])
ax.set_title("Correlation Matrix", fontsize=28, fontweight='bold', fontfamily='serif', color="#fff")
sns.heatmap(ds[cont_cols].corr().transpose(), mask=np.triu(np.ones_like(ds[cont_cols].corr().transpose())), fmt=".1f", annot=True, cmap='Blues')
plt.show()
# 拆分数据集为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(ds.iloc[:,:-1], ds.iloc[:, -1], random_state=3, train_size=.7)
x_train.shape, y_train.shape, x_test.shape, y_test.shape
# 平衡数据集
smote = SMOTE(random_state = 14)
x_train, y_train = smote.fit_resample(x_train, y_train)
x_train.shape, y_train.shape, x_test.shape, y_test.shape
# 数据标准化处理
x_train = MinMaxScaler().fit_transform(x_train)
x_test = MinMaxScaler().fit_transform(x_test)
# 找到catboost的最佳超参数!
def objective(trial):
    params = {
        'iterations': trial.suggest_int('iterations', 100, 1000),
        'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 0.5),
        'depth': trial.suggest_int('depth', 1, 12),
        'l2_leaf_reg': trial.suggest_loguniform('l2_leaf_reg', 1e-3, 10.0),
        'border_count': trial.suggest_int('border_count', 1, 255),
        'thread_count': -1,
        'loss_function': 'MultiClass',
        'eval_metric': 'Accuracy',
        'verbose': False
    }
    
    model = CatBoostClassifier(**params)
    model.fit(x_train, y_train, eval_set=(x_test, y_test), verbose=False, early_stopping_rounds=20)
    y_pred = model.predict(x_test)
    accuracy = accuracy_score(y_test, y_pred)
    return accuracy

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50, show_progress_bar=True)
# 初始化模型并使用前面的最佳超参数
model = CatBoostClassifier(
    verbose=0, 
    random_state=3,
    **study.best_params
)
# 训练模型
model.fit(x_train, y_train)
# 预测
y_pred = model.predict(x_test)
# 打印模型评估指标
print('模型准确率:',accuracy_score(y_test,y_pred))
print (classification_report(y_pred, y_test))
plt.subplots(figsize=(20, 6))
sns.heatmap(confusion_matrix(y_pred, y_test), annot = True, fmt="d", cmap="Blues", linewidths=.5)
plt.show()
# 画出ROC曲线
y_prob = model.predict_proba(x_test)[:,1]
false_positive_rate, true_positive_rate, thresholds = roc_curve(y_test, y_prob) 
roc = auc(false_positive_rate, true_positive_rate)
plt.title('ROC')
plt.plot(false_positive_rate,true_positive_rate, color='red',label = 'AUC = %0.2f' % roc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],linestyle='--')
plt.axis('tight')
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
# 模型预测
res = pd.DataFrame()
res['真实值'] = y_test
res['预测值'] = y_pred
res.sample(10)

资料获取,更多粉丝福利,关注下方公众号获取

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/688367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WindowServer2022配置iSCSI磁盘CHAP认证

实验环境: WindowsServer2022 一、实验环境 1、实验拓扑 为存储服务器添加4块100G的磁盘新建1个100G的iSCSI虚拟磁盘&#xff0c;发起目标选择IQN为保证连接安全&#xff0c;采用单向CHAP认证新建1个200G的iSCSI虚拟磁盘&#xff0c;发起目标选择IP地址为保证连接安全&#x…

DeepDriving | 多目标跟踪算法之SORT

本文来源公众号“DeepDriving”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;多目标跟踪算法之SORT 1 简介 SORT是2016年发表的一篇文章《Simple Online and Realtime Tracking》中提出的一个经典的多目标跟踪算法&#xff0c;…

【数据结构】栈和队列-->理解和实现(赋源码)

Toc 欢迎光临我的Blog&#xff0c;喜欢就点歌关注吧♥ 前面介绍了顺序表、单链表、双向循环链表&#xff0c;基本上已经结束了链表的讲解&#xff0c;今天谈一下栈、队列。可以简单的说是前面学习的一特殊化实现&#xff0c;但是总体是相似的。 前言 栈是一种特殊的线性表&…

flask_sqlalchemy时间缓存导致datetime.now()时间不变问题

问题是这样的&#xff0c;项目在本地没什么问题&#xff0c;但是部署到服务器过一阵子发现&#xff0c;这个时间会在某一刻定死不变。 重启uwsgi后&#xff0c;发现第一条数据更新到了目前最新时间&#xff0c;过了一会儿再次发送也变了时间&#xff0c;但是再过几分钟再发就会…

【全开源】JAVA打车小程序APP打车顺风车滴滴车跑腿源码微信小程序打车源码

&#xff1a;构建便捷出行新体验 一、引言&#xff1a;探索打车系统小程序源码的重要性 在数字化快速发展的今天&#xff0c;打车系统小程序已成为我们日常生活中不可或缺的一部分。它以其便捷、高效的特点&#xff0c;极大地改变了我们的出行方式。而背后的关键&#xff0c;…

啵啵啵啵啵啵啵啵啵啵啵啵啵啵啵

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术。关…

swaggerHole:针对swaggerHub的公共API安全扫描工具

关于swaggerHole swaggerHole是一款针对swaggerHub的API安全扫描工具&#xff0c;该工具基于纯Python 3开发&#xff0c;可以帮助广大研究人员检索swaggerHub上公共API的相关敏感信息&#xff0c;整个任务过程均以自动化形式实现&#xff0c;且具备多线程特性和管道模式。 工具…

增加强制索引依然慢

版本: 阿里云RDS MySQL 8.0.25 线上数据库CPU达到100%, 定位到如下SQL EXPLAIN SELECT ssd.goods_no,ssd.goods_name,ssd.goods_spec,ssd.goods_unit,ssd.create_time,w.warehouse_name,sb.batch_no,swl.warehouse_region_location_name,sc.customer_name AS goodsOwnerName,s…

如何在MySQL中实现upsert:如果不存在则插入?

目录 1 使用 REPLACE 2 使用 INSERT ... ON DUPLICATE KEY UPDATE 使用 INSERT IGNORE 有效会导致 MySQL 在尝试执行语句时忽略执行错误 INSERT 。这意味着 包含 索引或 字段 INSERT IGNORE 中重复值的语句 不会 产生错误&#xff0c;而只是完全忽略该特定 命令。其明显目的是…

2048小游戏的菜鸡实现方法

# 2048小游戏的实现与分析 2048是一款非常受欢迎的数字滑块游戏&#xff0c;其目标是通过滑动和合并相同数字的方块来创建一个值为2048的方块。下面&#xff0c;我们将通过分析一个C语言实现的2048小游戏的源代码&#xff0c;来探索如何用编程实现这款游戏。 ## 游戏概述 20…

指针(初阶1)

一.指针是什么 通俗的讲&#xff0c;指针就是地址&#xff0c;其存在的意义就像宾馆房间的序号一样是为了更好的管理空间。 如下图&#xff1a; 如上图所示&#xff0c;指针就是指向内存中的一块空间&#xff0c;也就相当于地址 二.一个指针的大小是多少 之前我们学习过&#x…

Springboot注意点

1.Usermapper里加param注解 2.RequestParam 和 RequestBody的区别&#xff1a; RequestParam 和 RequestBody的区别&#xff1a; RequestParam 和 RequestBody 是Spring框架中用于处理HTTP请求的两个不同的注 get请求一般用url传参数&#xff0c;所以参数名和参数的值就在ur…

LCM — Least Common Multiple 最小公倍数

因为任何一个数都可以表示为若干个质数幂的乘积。 比如75 3*5*5&#xff0c;即 2^0 * 3^1 * 5^2 * 7^0 ... 那么对于两个数来说&#xff0c;gcd就是他们取每个质数的较小幂的乘积&#xff0c;lcm则相反。显然&#xff0c;这些幂加起来就是他们乘积。 gcd(a,b) * lcm(a,b) a…

立创·天空星开发板-GD32F407VE-USART

本文以 立创天空星开发板-GD32F407VET6-青春版 作为学习的板子&#xff0c;记录学习笔记。 立创天空星开发板-GD32F407VE-USART 基础通信概念同步通信 & 异步通信串行通信 & 并行通信双工 & 单工通讯速率码元 串口通信数据帧 串口封装 基础通信概念 通信协议是网络…

本地运行ChatTTS

TTS 是将文字转为语音的模型&#xff0c;最近很火的开源 TTS 项目&#xff0c;本地可以运行&#xff0c;运行环境 M2 Max&#xff0c;差不多每秒钟 4&#xff5e;&#xff5e;5 个字。本文将介绍如何在本地运行 ChatTTS。 下载源码 首先下载源代码 git clone https://github…

WPF中读取Excel文件的内容

演示效果 实现方案 1.首先导入需要的Dll(这部分可能需要你自己搜一下) Epplus.dll Excel.dll ICSharpCode.SharpZipLib.dll 2.在你的解决方案的的依赖项->添加引用->浏览->选择1中的这几个Dll点击确定。(添加依赖) 3.然后看代码内容 附上源码 using Excel; usi…

TypeScript环境安装与VScode编辑器的使用

说明大背景环境&#xff0c;我用的是window10系统。 1.安装node.js 。 去官网下载安装包。 虽然我去的是官网&#xff0c;但是不知为何下载了个不知名的东西&#xff0c;后来又找了个链接才下载正确了。 实际上就是一个.msi的文件。我用的版本&#xff1a;node-v18.19.0-x6…

【第四节】C/C++数据结构之树与二叉树

目录 一、基本概念与术语 二、树的ADT 三、二叉树的定义和术语 四、平衡二叉树 4.1 解释 4.2 相关经典操作 4.3 代码展示 一、基本概念与术语 树(Tree)是由一个或多个结点组成的有限集合T。其中: 1 有一个特定的结点&#xff0c;称为该树的根(root)结点&#xff1b; 2 …

GPT-4o:突出优势 和 应用场景

还是大剑师兰特&#xff1a;曾是美国某知名大学计算机专业研究生&#xff0c;现为航空航海领域高级前端工程师&#xff1b;CSDN知名博主&#xff0c;GIS领域优质创作者&#xff0c;深耕openlayers、leaflet、mapbox、cesium&#xff0c;canvas&#xff0c;webgl&#xff0c;ech…

centos官方yum源不可用 解决方案(随手记)

昨天用yum安装软件的时候&#xff0c;就报错了 [rootop01 ~]# yum install -y net-tools CentOS Stream 8 - AppStream 73 B/s | 38 B 00:00 Error: Failed to download metadata for repo appstream: Cannot prepare internal mirrorlis…