【课程总结】Day6(下):机器学习项目实战–成人收入预测

机器学习项目实战:成人收入预测

项目目的

基于个人收入数据(包括教育程度、年龄、性别等)的数据集,通过机器学习算法,预测一个人的年收入是否超过5万美金。

数据集

  • 地址:http://idatascience.cn/dataset-detail?table_id=100368

  • 数据集字段

    字段名称字段类型字段说明
    age数值型年龄
    workclass字符型工作类型
    fnlwgt字符型序号
    education字符型教育程度
    education-num数值型受教育时长
    marital-status字符型婚姻状况
    occupation字符型职业
    relationship字符型关系
    race字符型种族
    sex字符型性别
    capital-gain数值型资本收益
    capital-loss数值型资本损失
    hours-per-week数值型每周工作小时数
    native-country字符型原籍
    salary字符型收入
  • 数据集样例:
    file

解决思路

分析输入/输出

通过分析,本次项目我们要解决的问题:给定一个人的相关信息(包括年龄、教育程度、受教育时长等),预测其收入是否超过5万。

该问题如果是预测其收入是多少,那么就属于线性回归问题;但数据集中在收入标签是>=5万或<5万,所以这应该是一个分类问题。

分析相关输入/输出如下:

  • 输入:一个人的信息(包括年龄、教育程度、受教育时长等)
  • 输出:0为<=50K,1为>50K
构建数据集
分析数据类型

首先,我们先分析一下如何将数据集向量化,通过以下代码来查看数据集每列的内容:

import csv
import numpy as np
def read_file(file_path, skip_header=True):
  """
  读取CSV文件的内容。
  参数:
      file_path (str): CSV文件的路径。
      skip_header (bool): 是否跳过表头数据,默认为True。
  返回:
      list: 包含CSV文件内容的列表。
  """
  print(f'读取原始数据集文件: {file_path}')
  with open(file_path, 'r', encoding='utf-8') as f:
      if skip_header:
          # 跳过表头数据
          f.readline()
      reader = csv.reader(f)
      return [row for row in reader]  # 读取csv文件的内容

    
def print_unique_columns(data):
    """
    打印表格数据中每一列去重后的数据。
    参数:
        data (list): 包含表格数据的列表。
    """
    if not data:
        print("没有读取到任何数据。")
        return

    # 获取每一列的数据
    columns = zip(*data)

    # 打印每一列去重后的数据
    for i, column in enumerate(columns):
        unique_values = set(column)
        unique_count = len(unique_values)
        print(f"第{i}列去重后的数据(最多前20个):")
        
        # 如果数据大于20个,则只打印前20个
        if unique_count > 20:
            for value in list(unique_values)[:20]:
                print(value)
            print(f"共 {unique_count} 个不同的值")
        else:
            for value in unique_values:
                print(value)
            print(f"共 {unique_count} 个不同的值")
        print()
        
data = read_file('./成人收入预测数据集.csv')
print_unique_columns(data)

运行结果:

对上述每一列内容进行梳理,梳理结果如下:

列的序号(第几列)原始数据表头名称表头名称中文解释对应列取值举例取值个数数据类型数据处理办法
0age年龄69, 62, 37, 80, 3473连续量保留
1workclass工作类型Without-pay, State-gov, Federal-gov9离散量保留
2fnlwgt疑似邮编编号226092, 209770, 3184021648连续量舍弃
3education教育程度Masters, Bachelors, Some-college16离散量保留
4education-num教育年限3, 4, 16, 2, 1116连续量保留
5marital-status婚姻状况Divorced, Separated, Never-married7离散量保留
6occupation职业类型Exec-managerial, Adm-clerical, Handlers-cleaners15离散量保留
7relationship家庭关系Wife, Not-in-family, Other-relative6离散量保留
8race种族Amer-Indian-Eskimo, White, Black5离散量保留
9sex性别Female, Male2离散量保留
10capital-gain投资利得3781, 2062, 401119连续量保留
11capital-loss投资损失1816, 1876, 176292连续量保留
12hours-per-week每周工作时长37, 80, 3494连续量保留
13native-country出生国家Japan, England, Guatemala42离散量保留

通过分析上述每列内容可知:

  • 数据集中,age、education-num、capital-gain、capital-loss、hours-per-week等字段的内容一般都是数字,可以视为连续量。这类数据直接使用即可,不需要做向量化处理(因为是数字,机器能够处理)。
  • 数据集中,workclass、education、marital-status、occupation、relationship、race、sex、native-country等字段内容一般都是表示状态,属于离散量;由于其内容不是数字,为了让机器能够处理,我们后续需要进行向量化处理。
  • 数据集中,fnlwgt疑似是连续量,但是查看取值个数(21648)与总样本个数(32561个)不一致,所以确定该列不是类似身份证号的唯一编号数据(如果身份证可以舍弃,因为身份证跟收入没啥关系);该列可能是类似邮编的数据,虽然邮编可能与收入有一定关系,但是因为其取值数量比较大(21648个),这会导致特征数据过去庞大,所以选择舍弃。
离散量的编码

通过与chatGPT交流,我们了解到,离散量的编码常见方式有One-Hot编码标签编码序数编码Target编码哈希编码实体嵌入

  1. One-Hot编码:
    • 编码方式:将每个离散特征转换为多个二进制特征。
    • 优点:能够很好地表示分类特征,不会引入大小关系。
    • 缺点:会增加特征维度。
    • 适用场景:适用于没有大小关系的分类特征,如性别、产品类别等。
  2. 标签编码(Label Encoding):
    • 编码方式:将每个离散特征的取值映射为一个整数值。
    • 优点:简单直观。
    • 缺点:可能会让算法认为特征之间存在大小关系。
    • 适用场景:适用于没有明确大小关系的分类特征,但要注意可能带来的影响。
  3. 序数编码(Ordinal Encoding):
    • 编码方式:将每个离散特征的取值映射为一个有序的整数值。
    • 优点:能够保留特征之间的大小关系。
    • 缺点:对于没有明确大小关系的离散特征可能不太合适。
    • 适用场景:适用于有明确大小关系的有序离散特征,如学历、星级等。
  4. Target编码:
    • 编码方式:将每个离散特征的取值映射为目标变量的平均值或中位数。
    • 优点:能够捕捉特征取值与目标变量之间的关系,通常能提高模型性能。
    • 缺点:需要提前知道目标变量,不适用于无监督学习。
    • 适用场景:适用于有监督学习问题中,当目标变量与离散特征存在相关性时。
  5. 哈希编码:
    • 编码方式:将每个离散特征的取值通过哈希函数映射为一个整数值。
    • 优点:在处理高基数特征时较为有效,能减少内存占用。
    • 缺点:可能会产生冲突,导致信息丢失。
    • 适用场景:适用于高基数离散特征,且内存受限的情况下。
  6. 实体嵌入(Entity Embedding):
    • 编码方式:将每个离散特征的取值映射为一个低维的稠密向量。
    • 优点:能够学习特征之间的潜在关系,通常能提高模型性能。
    • 缺点:需要额外的训练过程来学习嵌入向量。
    • 适用场景:适用于复杂的机器学习问题,如自然语言处理、推荐系统等,能够更好地捕捉特征之间的潜在关系。

分析上述编码内容,其中

  • Target编码哈希编码实体嵌入暂未学习到,本次暂不做考虑。

  • 标签编码序数编码由于其存在潜在的数字大小对比,让算法认为特征之间存在大小关系,所以并不适用于上述的婚姻状况、职业类型、国家等。

所以,综上所述本次练习使用One-Hot编码,其编码示意图如下:

基于上面的思想,我们通过与GPT沟通了解到,sklearn的库函数中有OneHotEncoder,使用方法如下:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

# 示例数据
data = np.array([
    ['A', 'X', 'P'],
    ['B', 'Y', 'Q'],
    ['A', 'X', 'R'],
    ['B', 'Z', 'Q']
])

# 创建 OneHotEncoder 实例
encoder = OneHotEncoder(sparse_output=False)

# 对数据进行one-hot编码
encoded_data = encoder.fit_transform(data)

print("原始数据:")
print(data)
print("\n编码后的数据:")
print(encoded_data)

执行结果如下:

分离特征和标签
def split_data(data):
    """
    将数据分割为标签和数据。
    参数:
        data (list): 数据行的列表,第一个元素是标签。
    返回:
        numpy.ndarray: 标签数组。
        numpy.ndarray: 连接元素后的数据数组。
    """

    # 去除每个元素的前后空格
    data = [[col.strip() for col in row] for row in data]

    # 分离数据和标签
    n_label = np.array([row[-1] for row in data])
    n_data = np.array([row[:-1] for row in data])

    return n_label, n_data


csv_data = read_file('./成人收入预测数据集.csv')
label, data = split_data(csv_data)
label, data

运行结果:

处理特征列

这部分的处理较为麻烦,整体思路是这样:

1、实现一个函数,传入三个参数:分别是离散量列的序号、连续量列序号和丢弃列序号

2、函数根据传入的列序号,分别进行如下处理:

  • 如果是连续量列,取出对应的列,不用做处理;

  • 如果是离散量列,取出对应的列,使用OneHotEncoder进行编码

  • 如果是丢弃列,则在矩阵中删除对应的列

    最后,在去除丢弃列之后,将连续量列和离散量列按照列方向堆叠为一个新的矩阵

import numpy as np
from sklearn.preprocessing import OneHotEncoder, StandardScaler

def vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols=None):
    """
    使用scikit-learn将给定的NumPy数组中的数据进行one-hot编码和标准化处理。

    参数:
    data (np.ndarray): 输入的NumPy数组
    onehot_cols (list): 需要进行one-hot编码的列索引
    continuous_cols (list): 不需要one-hot编码的连续量列索引
    exclude_cols (list, optional): 需要排除的列索引

    返回:
    np.ndarray: 经过one-hot编码和标准化的向量化数据
    """
    # 排除不需要处理的列
    if exclude_cols:
        data = np.delete(data, exclude_cols, axis=1)
        onehot_cols = [col - sum(col > exc for exc in exclude_cols) for col in onehot_cols if col not in exclude_cols]
        # 解释过程:
        # 1. 遍历 onehot_cols 中的每个索引 col
        # 2. 检查 col 是否在 exclude_cols 中
        #    - 如果在,计算 col 在 exclude_cols 中的位置 sum(col > exc for exc in exclude_cols)
        #    - 并从 col 中减去这个值,更新 col 的索引
        #    - 例如: col = 1, 在 exclude_cols 中的位置为 0, 则更新后 col = 1 - 0 = 1
        #    - ⭐ 这样可以确保 onehot_cols 中的索引能正确对应到数据的列
        # 3. 如果 col 不在 exclude_cols 中,则保留原始索引

        continuous_cols = [col - sum(col > exc for exc in exclude_cols) for col in continuous_cols if col not in exclude_cols]
        # 解释过程:
        # 1. 遍历 continuous_cols 中的每个索引 col
        # 2. 检查 col 是否在 exclude_cols 中
        #    - 如果在,计算 col 在 exclude_cols 中的位置 sum(col > exc for exc in exclude_cols)
        #    - 并从 col 中减去这个值,更新 col 的索引
        #    - 例如: col = 2, 在 exclude_cols 中的位置为 0, 则更新后 col = 2 - 0 = 2
        #    - ⭐ 这样可以确保 continuous_cols 中的索引能正确对应到数据的列
        # 3. 如果 col 不在 exclude_cols 中,则保留原始索引

    else:
        onehot_cols = onehot_cols[:]
        continuous_cols = continuous_cols[:]

    # 对离散量列进行one-hot编码
    onehot_encoder = OneHotEncoder(sparse_output=False)
    one_hot_data = onehot_encoder.fit_transform(data[:, onehot_cols])

    # 对连续量列进行标准化
    scaler = StandardScaler()
    continuous_data = scaler.fit_transform(data[:, continuous_cols])

    # 将one-hot编码结果和标准化后的连续量列拼接起来
    final_data = np.hstack((one_hot_data, continuous_data))

    return final_data
  
  
csv_data = read_file('./成人收入预测数据集.csv')
label, data = split_data(csv_data)

# 对数据进行切分
onehot_cols = [1, 3, 5, 6, 7, 8, 9, 13]
continuous_cols = [0, 2, 4, 10, 11, 12]

# 排除列增加最后一列
exclude_cols = [2, 13]
vectorized_data = vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols)
vectorized_data

运行结果:

处理标签列

因为标签列的内容只有两种情况:‘<=50K’ 和’>50K’,所以只需要将这一列中’<=50K’替换为0,'>50K’替换为1即可。

import numpy as np
from sklearn.preprocessing import LabelBinarizer

def binarize_labels(labels):
    """
    将标签二值化。
    参数:
        labels (numpy.ndarray): 原始标签数组。
    返回:
        numpy.ndarray: 二值化后的标签数组。
    """
    lb = LabelBinarizer()
    binarized_labels = lb.fit_transform(labels)
    return binarized_labels

vectorized_label = binarize_labels(label)
vectorized_label
试验算法

为了能够将整个流程跑通,我们仍然选择决策树算法跑通流程。


from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

def decision_tree(label, data, test_size=0.2):
    """
    决策树模型的训练和评估。
    参数:
        train_data_file_path (str): 向量化数据的文件路径。
        test_size (float): 测试集的比例,默认为0.2。
    """
    print('开始加载训练数据...')

    # 训练集和测试集切分
    X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=test_size)

    print('开始训练决策树模型...')
    # 数据预测
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # 评估
    print('开始决策树预测...')
    accuracy = np.mean(y_pred == y_test)
    print(f'预测准确率:{accuracy}')

# 读取文件
listdata = read_file('./成人收入预测数据集.csv')
# 对数据进行切分
label, data = split_data(listdata)

# 对数据进行切分
onehot_cols = [1, 3, 5, 6, 7, 8, 9, 13]
continuous_cols = [0, 2, 4, 10, 11, 12]

# 排除列增加最后一列
exclude_cols = [2, 13]    
vectorized_data = vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols)
vectorized_label = binarize_labels(label)

decision_tree(vectorized_label, vectorized_data)

运行结果:

工程优化

通过以上的工作,整体构建数据集→训练模型→预测模型流程已经跑通,接下来进行代码重构优化。

1、将整体代码使用面向对象封装为类实现

2、在模型预测部分加入KNN、贝叶斯、线性回归、随机森林、SVC向量机的方式

3、将预测结果使用matplotlib绘制出来

4、给关键代码处增加带有时间戳的日志

以上工作就交给GPT来完成了,最后重构的代码请见Github仓库

遴选算法

通过运行上述工程优化后的代码,执行结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/700227.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HJS-DE1/3断电延时时间继电器 导轨安装 约瑟JOSEF

​HJS-DE断电延时时间继电器 HJS-DE1/3断电延时时间继电器 HJS-DE1/2断电延时时间继电器 HJS-DE断电延时时间继电器 用途 时间继电器HJS-DE1/3作为通电立即动作断电延时返回的元件&#xff0c;用于交流或直流保护和自动控制装置中&#xff0c;使被控元件得到所需延时。本继电…

MySQL系列-语法说明以及基本操作(一)

1、前言 主要讲解MySQL的基本语法 官网文档 https://docs.oracle.com/en-us/iaas/mysql-database/doc/getting-started.html 关于MySQL的基本语法&#xff0c;关于数据类型、表的操作、数据操作、事务、备份等&#xff0c;可参考 http://www.voidme.com/mysql 2、数据类型 数…

【PX4-AutoPilot教程-TIPS】离线安装Flight Review PX4日志分析工具

离线安装Flight Review PX4日志分析工具 安装方法 安装方法 使用Flight Review在线分析日志&#xff0c;有时会因为网络原因无法使用。 使用离线安装的方式使用Flight Review&#xff0c;可以在无需网络的情况下使用Flight Review网页。 安装环境依赖。 sudo apt-get insta…

在客户自己的smarteye server上如何测试使用公网对讲PoC集群通话功能

创建用户对讲群组&#xff0c;可以在WEB调度台上操作&#xff0c;也可以在执法记录仪或者手机上操作&#xff0c;当然&#xff0c;前提是当前登录的用户有创建集群通话对讲组的权限&#xff0c;一般的推荐用管理员帐号去创建这些对讲组&#xff0c;这样最简单。 总体说明 用管…

echarts的toolbox自定义feature标签及事件

1. 需求 在使用echarts图时希望toolbox扩展一些自定义icon和点击事件&#xff0c;而不只是图中这些echarts提供的事件。 2. 文档 属性名类型描述toolbox.featureObject各工具配置项。 feature中除了echarts提供的各个内置的工具按钮外&#xff0c;可以自定义工具按钮。 除…

经典文献阅读之--FlashOcc(快速且内存高效的占用预测模块)

Tip: 如果你在进行深度学习、自动驾驶、模型推理、微调或AI绘画出图等任务&#xff0c;并且需要GPU资源&#xff0c;可以考虑使用UCloud云计算旗下的Compshare的GPU算力云平台。他们提供高性价比的4090 GPU&#xff0c;按时收费每卡2.6元&#xff0c;月卡只需要1.7元每小时&…

windows下编译linux的go程序报错gcc_linux_amd64.c: In function ‘_cgo_sys_thread_start‘:

报错解决 执行报错 报错解决 其他补充 windows下编译linux 64位程序 set GOARCHamd64 set GOOSlinux set CGO_ENABLED0 go build手把手教你实现Golang跨平台编译 报错解决其他补充windows下编译linux 64位程序Windows下编译Mac平台64位可执行程序&#xff1a;Linux 下编译…

webpack--压缩,代码的拆分,tree shinking

Terser 对代码进行压缩、丑化 const TerserPlugin require("terser-webpack-plugin");optimization: { minimize: true, //在开发环境下启用 CSS 优化minimizer: [new TerserPlugin({extractComments: false,terserOptions: {compress: {arguments: true,// 将函数…

Cursor是什么?基于ChatGPT代码编辑器的cursor如何使用?VS Code如何迁移到Cursor的步骤

Cursor是什么 Cursor 是一个基于 Visual Studio Code&#xff08;VS Code&#xff09;技术构建的高级代码编辑器&#xff0c;专为提高编程效率并更深度地整合 AI 功能而设计。它不仅继承了 VS Code 的强大功能和用户界面&#xff0c;还增加了专门针对 AI 支持的特色功能。 Cur…

UE5 渲染性能优化 学习笔记

主要考虑三个点&#xff1a; 1、灯光 2、半透明物体 3、后处理 1、Game&#xff1a;CPU对游戏代码的处理工作 2、Draw&#xff1a;CPU为GPU准备数据所做的工作 3、GPU Time&#xff1a;就是GPU所渲染需要花的时间 UE5的命令行指令 里面说明了某个指令有什么用处 以及启动…

英伟达SSD视觉算法分类代码解析

一、官方原代码 #!/usr/bin/env python3 # # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Softwa…

【电路笔记】-电子放大器介绍

电子放大器介绍 文章目录 电子放大器介绍1、概述2、四极表示法3、理想模型4、真实放大器的限制5、噪音考虑因素6、电子放大器的类型1、概述 放大器是一种电子模块,可放大电位信号(电压放大器)、强度信号(电流放大器)或两者(功率放大器)。 放大器由两个输入组成,分别是…

开门预警系统技术规范(简化版)

开门预警系统技术规范(简化版) 1 系统概述2 预警区域3 预警目标4 功能需求5 功能条件6 显示需求7 指标需求1 系统概述 开门预警系统(DOW),在自车停止开门过程中,安装在车辆的传感器(如安装在车辆后保险杆两个角雷达)检测从自车后方接近的目标车(汽车、摩托车等)的相对…

Django面试题

1. 什么是wsgi&#xff1f; WSGI 是 “Web Server Gateway Interface” 的缩写&#xff0c;它是一种用于 Python Web 应用程序和 Web 服务器之间通信的标准接口。它定义了一组规则和约定&#xff0c;使 Web 服务器能够与任何符合 WSGI 规范的 Python Web 应用程序进行交互。 #…

2024年中级会计报名失败原因汇总❗

2024年中级会计报名失败原因汇总❗ ❌这四类考生不能报考24年中级⇩⇩⇩ 1️⃣不参加会计信息采集的同学 2️⃣未按规定完成继续教育的同学 3️⃣不符合会计工作年限要求的同学 4️⃣报名前未做好材料准备 需要准备有效期内身份证、本人学历或学位证书、户籍证或者居住证明、符…

翻转链表-链表题

LCR 141. 训练计划 III - 力扣&#xff08;LeetCode&#xff09; 非递归 class Solution { public:ListNode* trainningPlan(ListNode* head) {if(head ! nullptr && head->next ! nullptr){ListNode* former nullptr;ListNode* mid head;ListNode* laster nul…

C++ PDF转图片

C PDF转图片#include "include/fpdfview.h" #include <fstream> #include <include/core/SkImage.h>sk_sp<SkImage> pdfToImg(sk_sp<SkData> pdfData) {sk_sp<SkImage> img;FPDF_InitLibrary(nullptr);FPDF_DOCUMENT doc;FPDF_PAGE …

Character Region Awareness for Text Detection论文学习

​1.首先将模型在Synth80k数据集上训练 Synth80k数据集是合成数据集&#xff0c;里面标注是使用单个字符的标注的&#xff0c;也就是这篇文章作者想要的标注的样子&#xff0c;但是大多数数据集是成堆标注的&#xff0c;也就是每行或者一堆字体被整体标注出来&#xff0c;作者…

人工智能ChatGPT的多种应用:提示词工程

简介 ChatGPT 的主要优点之一是它能够理解和响应自然语言输入。在日常生活中&#xff0c;沟通本来就是很重要的一门课程&#xff0c;沟通的过程中表达的越清晰&#xff0c;给到的信息越多&#xff0c;那么沟通就越顺畅。 和 ChatGPT 沟通也是同样的道理&#xff0c;如果想要 …

33.星号三角阵(二)

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/742 题目描述 给定一个整数 𝑛,输出一个…