第100+12步 ChatGPT学习:R实现KNN分类

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现KNN分类

(1)导入数据

我习惯用RStudio自带的导入功能:

(2)建立KNN模型

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)

# Fit KNN model on the training set
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl, preProcess = "scale")

# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]

# Convert true values to factor for ROC analysis
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}
# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")

# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,caret包提供了一个通用的接口来训练KNN模型。使用caret的train函数来训练KNN模型时,可以调整多种参数来优化模型的性能:

基本参数:

①formula: 指定模型的公式,如Y ~ .,表示使用数据框中的所有其他变量来预测Y。

②data: 提供包含训练数据的数据框。

③method: 对于KNN模型,这个参数应设置为"knn"。

④preProcess: 预处理步骤,常用的包括标准化("scale")和中心化("center"),对于KNN这一步非常重要因为KNN依赖于变量的距离度量。

⑤trControl: 一个trainControl对象,定义了模型训练的各种控制策略,如交叉验证的类型和重复次数。

trainControl 函数的参数:

①method: 训练的方法,如交叉验证("cv"),重复交叉验证("repeatedcv"),留一交叉验证("LOOCV")等。

②number: 对于"cv"和"repeatedcv",这个参数定义了折数。

③repeats: 当使用"repeatedcv"时,定义重复的次数。

④search: 参数搜索方法,默认为"grid"。也可以设置为"random"进行随机搜索。

⑤savePredictions: 是否保存预测结果,通常用于后续分析。

模型性能调整参数:

使用KNN时,最关键的参数之一是邻居的数量(K值)。这可以通过train函数的以下参数来调整:

①tuneLength: 这个参数决定了在参数搜索中考虑多少个不同的K值。

②tuneGrid: 这是一个数据框,可以自定义K值的具体范围,例如expand.grid(k = c(1, 5, 10))

结果输出(默认参数):

三、KNN调参方法

如前所述,KNN的关键参数就是K值,所以可以对其进行一个暴力测试,比如取值1到10:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:10)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,
               tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

解读:

定义交叉验证的控制方法:使用trainControl函数设定交叉验证的详细参数。

定义K值的网格:使用tuneGrid参数在train函数中指定K值的范围。

拟合模型:使用train函数训练模型,同时应用预处理步骤(比如标准化数据),以确保每个特征在距离计算中具有等同的权重。

结果输出:

注意:用了caret包的train函数,并且通过网格搜索指定了一系列的参数(如K值的范围),那么这个函数会自动选择表现最好的参数配置来训练最终的模型。train函数的输出即是基于你提供的训练数据和参数搜索范围内表现最优的模型。因此,当你调用predict函数进行预测时,使用的就是这个最优化的模型。所以,下面的代码不变。

结果吧,跟之前的完全一样:

因为caret包对于KNN模型默认进行一系列的K值尝试,通常这个范围是1到最多的邻居数,但具体的最大K值依赖于caret的内部设置。在大多数情况下,它会尝试如1, 5, 7, 9等常用的K值。所以,我们默认参数的时候,其实软件自动给我们寻找最优K值了。可以用这个代码输出最有K值:

# Print the best K value used by the model
best_k <- model$bestTune$k
cat("The best K value found is:", best_k, "\n")

K值就是9,跟我们自行调参的一致。

那我们猛点,把K的范围设置的宽一些:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:20)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,
               tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

结果:

K=19,性能指标如下,似乎大同小异:

四、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/722524.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

API 设计技巧:基础知识与实践的方法

在这篇深入探讨中&#xff0c;我们将从基础开始&#xff0c;逐步介绍 API 设计&#xff0c;并探讨定义卓越API的最佳实践。 作为一名开发者&#xff0c;你可能已经熟悉了许多这些概念&#xff0c;但我将提供详细解释&#xff0c;以加深你的理解。 API 设计&#xff1a;电子商…

图纸管理最佳实践:从混乱到有序的转变

图纸管理最佳实践&#xff1a;从混乱到有序的转变 在工程项目中&#xff0c;图纸是不可或缺的重要资料&#xff0c;它们承载着设计者的智慧与心血。然而&#xff0c;随着项目的推进和时间的推移&#xff0c;图纸管理往往变得混乱不堪&#xff0c;给项目的顺利进行带来诸多不便。…

ROS(四)

write in advance 实验四&#xff0c;在经历了实验三的失败后&#xff0c;卷土重来。 幸运的是&#xff0c;此次实验很简单&#xff0c;很快就能搞定。 Part one 使用指令查看自己摄像头的设备号&#xff0c;如果报错&#xff0c;且你为虚拟机&#xff0c;可能是未在虚拟机处…

Java学习 (二)关键字、标识符、数组

一、关键字 我们第一章案例中有很多关键字&#xff0c;比如class、public、static、void等&#xff0c;这些关键字依旧被java定义好了&#xff0c;可以拿来用&#xff0c;不需要死记硬背&#xff0c;按照官方文档查询即可 #官方文档 https://docs.oracle.com/javase/tutorial/j…

docker-compose设置永久启动、自动重启

步骤一 找到 docker-compose.yml 文件 步骤二 vim 打开文件 找到 image: PS&#xff1a;就是为了对齐格式 步骤三 在其下方添加&#xff1a; restart: always而后保存即可

打开nginx连接的php页面报错502

目录 问题描述&#xff1a; 原因&#xff1a; 1. 使用 Unix 域套接字&#xff08;Unix Socket&#xff09; 区别和优势&#xff1a; 2. 使用 TCP/IP 套接字 区别和优势&#xff1a; 如何选择 扩展&#xff1a;Rocky_Linux9.4安装PHP的步骤&#xff1a; 使用Remi存储库…

Wifi通信协议:WEP,WPA,WPA2,WPA3,WPS

前言 无线安全性是保护互联网安全的重要因素。连接到安全性低的无线网络可能会带来安全风险&#xff0c;包括数据泄露、账号被盗以及恶意软件的安装。因此&#xff0c;利用合适的Wi-Fi安全措施是非常重要的&#xff0c;了解WEP、WPA、WPA2和WPA3等各种无线加密标准的区别也是至…

通过防抖动代码解决ResizeObserver loop completed with undelivered notifications.

通过防抖动代码解决ResizeObserver loop completed with undelivered notifications. 一、报错内容二、解决方案解释&#xff1a; 一、报错内容 我通过el-tabs下的el-tab-pane切换到el-table出现的报错&#xff0c;大致是渲染宽度出现了问题 二、解决方案 扩展原生的 Resiz…

Android 应用加固与重签名—使用AndroidStudio自带工具 apksigner

由 AndroidStudio 生成的release版本的app有自己的签名&#xff0c;但当应用加固后会删除原签名&#xff0c;需要重新签名。 一、加固方式&#xff1a; 使用基础版的腾讯云&#xff08;乐固&#xff09;进行免费加固&#xff0c;上传软件后等待在线加固完成后下载 apk 即可。…

vue3+ts+vite集成eslint

项目中安装eslint yarn add eslint -Deslint初始化 npx eslint --init按照下方操作即可 安装typescript-eslint/parser yarn add typescript-eslint/parser -D安装vite-plugin-eslint2 yarn add vite-plugin-eslint2 -D配置vite-plugin-eslint2 // vite.config.ts import …

钡铼技术BL104在环境监测站多协议采集保障数据全面准确

随着工业化和城市化进程的加快&#xff0c;环境污染问题日益严重&#xff0c;环境监测站在保护生态环境、保障公众健康中的作用变得越来越重要。钡铼技术PLC物联网关BL104&#xff0c;为环境监测站提供了一种高效、可靠的多协议数据采集解决方案&#xff0c;保障了监测数据的全…

如何实现对文件发送全生命周期的外发管理?

在日常工作中&#xff0c;我们需要经常和企业外部的机构或者个人&#xff0c;发送一些企业内部的文档或者图纸等资料。但企业在文件外发管理上&#xff0c;仍存在一定漏洞&#xff0c;有些员工会通过一些手段&#xff0c;将重要核心文件数据发过去&#xff0c;包括但不仅限于以…

【数据结构(邓俊辉)学习笔记】二叉搜索树02——查找、插入和删除

文章目录 1.概述2. 查找2.1 查找&#xff1a;算法2.2 查找&#xff1a;理解2.3 查找&#xff1a;实现2.4 查找&#xff1a;语义 3. 插入3.1 插入&#xff1a;算法3.2 插入&#xff1a;实现 4. 删除4.1 删除&#xff1a;框架4.2 删除&#xff1a;单分支4.3 删除&#xff1a;双分…

MySQL----redo log重做日志原理及流程

重做日志 redo log&#xff1a;重做日志&#xff0c;用于记录事务操作的变化&#xff0c;确保事务的持久性。redo log是在事务开始后就开始记录&#xff0c;不管事务是否提交都会记录下来&#xff0c;在异常发生时&#xff08;如数据持久化过程中掉电&#xff09;&#xff0c;…

笔记-python里面的xlrd模块详解

那我就一下面积个问题对xlrd模块进行学习一下&#xff1a; 1.什么是xlrd模块&#xff1f; 2.为什么使用xlrd模块&#xff1f; 3.怎样使用xlrd模块&#xff1f; 1.什么是xlrd模块&#xff1f; ♦python操作excel主要用到xlrd和xlwt这两个库&#xff0c;即xlrd是读excel&…

成都百洲文化传媒有限公司助力商家扬帆远航

在数字经济的浪潮中&#xff0c;电商行业如日中天&#xff0c;成都百洲文化传媒有限公司正是这一领域的佼佼者。作为一家专注于电商服务的传媒公司&#xff0c;百洲文化以其专业的服务、创新的理念和卓越的成果&#xff0c;在业内树立了良好的口碑&#xff0c;成为众多商家信赖…

大数据实训室建设可行性报告

一、建设大数据实训室的背景与意义 随着信息技术的飞速发展&#xff0c;大数据已成为推动社会进步和经济发展的重要力量。中高职院校作为技能型人才培养的摇篮&#xff0c;承担着为社会输送大数据领域高素质、高技能人才的重要任务。因此&#xff0c;建设大数据实训室&#xf…

OKR:2024年目标和关键成果常见问题

什么是目标和关键结果&#xff08;OKR&#xff09;&#xff1f; 目标和关键结果&#xff08;#OKR#&#xff09;是一种由结果驱动的目标制定方法。在企业中&#xff0c;OKR经常被用来指导基于结果的成功。使用结果而不是任务作为驱动力&#xff0c;OKRs 鼓励通过度量指标对实现成…

Android 查询及获取应用程序 Package 及 Acticity 名称的方法

一、通过命令查询和获取应用的Package及Acticity &#xff08;一&#xff09;通过命令查询应用包名及安装信息 以下是常用命令&#xff1a; 命令形式作用adb shell pm list packages查询系统中所有应用的包名adb shell pm list packages -s查询系统应用包名adb shell pm lis…

onnx进阶算子优化

一、定义 如何保证pytorch 模型顺利转为onnx. 前言pytorch 算子是如何与onnx 算子对齐的&#xff1f;Asinh 算子出现于第 9 个 ONNX 算子集。PyTorch 在 9 号版本的符号表文件中是怎样支持这个算子的&#xff1f;BitShift 算子出现于第11个 ONNX 算子集。PyTorch 在 11 号版本…