介绍
clustlasso是结合lasso和cluster-lasso策略的R包,并发表在Interpreting k-mer based signatures for antibiotic resistance prediction。
标准交叉验证lasso分类或回归流程如下:
- 选择交叉验证数据集(数据分割);
- 选择最佳模型(训练参数);
- 测试集评估模型效能(确定最终模型);
通过看源代码发现相比标准的lasso聚类或回归它多了一个cluster的过程,通过比较自变量之间的相关系数大小进行聚类分析。
加载R包和数据
上gitlab下载该包的tar.gz文件,然后本地安装软件(可适用于windows和Linux)。
install.packages("NMF")
install.packages("D:/Downloads/clustlasso-master.tar.gz", repos = NULL, type = "source")
suppressWarnings(suppressMessages(library(clustlasso)))
加载所需要数据
# specify / set random seed
seed = 42
set.seed(seed)
# load example dataset
input.file = system.file("data", "NG-dataset.Rdata", package = "clustlasso")
load(input.file)
以及80%和20%切割数据集合
# pick 20% for test
test.frac = 0.2
# stratify by origin / population structure
ind.by.struct = split(seq(nrow(meta)), meta$pop_structure)
ind.sample = sapply(ind.by.struct, function(x) {
sample(x, round(test.frac * length(x)))
})
ind.test = unlist(ind.sample)
# split
X.test = X[ind.test, ]
y.test = y[ind.test]
meta.test = meta[ind.test, ]
X.train = X[-ind.test, ]
y.train = y[-ind.test]
meta.train = meta[-ind.test, ]
标准交叉验证lasso过程
该过程没有使用cluster方法。
- 选择交叉验证数据集(数据分割);
- 选择最佳模型(训练参数);
- 测试集评估模型效能(确定最终模型);
Cross-validattion process
交叉验证的目的是训练模型参数,调参的对象是lasso模型的lambda参数。可以设置n.folds和n.repeat参数。
# specify cross-validation parameters
n.folds = 10
n.lambda = 100
n.repeat = 3
# run cross-validation process
cv.res.lasso = lasso_cv(X.train, y.train, subgroup = meta.train$pop_structure,
n.lambda = n.lambda, n.folds = n.folds, n.repeat = n.repeat,
seed = seed, verbose = FALSE)
最佳参数展示show_cv_overall(modsel.criterion+best.eps)。模型标准和最佳特征均展示出来。
par(mfcol = c(1, 3))
show_cv_overall(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1)
Selecting the best model
最佳模型根据modsel.criterion参数确定,该参数可根据auc和balanced.accuracy.best确定。
layout(matrix(c(1, 2, 3), nrow = 1, byrow = TRUE), width = c(0.3, 0.3, 0.4), height = c(1))
perf.best.lasso = show_cv_best(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1, method = "lasso")
# print cross-validation performance of best model
print(perf.best.lasso)
提取最佳模型extract_best_model.
best.model.lasso = extract_best_model(cv.res.lasso, modsel.criterion = "balanced.accuracy.best", best.eps = 1)
Making predictions and measuring performance
根据上一步选择的最佳模型应用于测试集,进而评估模型的效能。
# make predictions
preds.lasso = predict_clustlasso(X.test, best.model.lasso)
# compute performance
perf.lasso = compute_perf(preds.lasso$preds, preds.lasso$probs, y.test)
# print
print(t(perf.lasso$perf))
可视化结果
par(mfcol = c(1, 2))
plot(perf.lasso$roc.curves[[1]], lwd = 2, main = "lasso - test set ROC curve")
grid()
plot(perf.lasso$pr.curves[[1]], lwd = 2, main = "lasso - test set precision / recall curve")
grid()
总结:调参后选择最佳参数确定最终模型对分类器构建至关重要,这里选择balanced.accuracy.best而没有选择auc(大家可以试试auc的结果如何)。
Cluster-lasso过程
与上面标准lasso流程类似,但增加了cluster过程。
Cross-validattion process
该过程多增加了screen.thresh和clust.thresh,该参数用于cluster过程。
# specify cross-validation parameters
n.folds = 10
n.lambda = 100
n.repeat = 3
# specify screening and clustering thresholds
screen.thresh = 0.95
clust.thresh = 0.95
# run cross-validation process
cv.res.cluster = clusterlasso_cv(X.train, y.train, subgroup = meta.train$pop_structure,
n.lambda = n.lambda, n.folds = n.folds, n.repeat = n.repeat,
seed = seed, screen.thresh = screen.thresh, clust.thresh = clust.thresh,
verbose = FALSE)
par(mfcol = c(1, 3))
show_cv_overall(cv.res.cluster, modsel.criterion = "balanced.accuracy.best",
best.eps = 1)
Selecting the best model
layout(matrix(c(1, 2, 3, 4, 4, 4), nrow = 2, byrow = TRUE), width = c(0.3,0.3, 0.4), height = c(0.6, 0.4))
perf.best.cluster = show_cv_best(cv.res.cluster, modsel.criterion = "balanced.accuracy.best",
best.eps = 1, method = "clusterlasso")
# print cross-validation performance of best model
print(perf.best.cluster)
best.model.cluster = extract_best_model(cv.res.cluster, modsel.criterion = "balanced.accuracy.best",
best.eps = 1, method = "clusterlasso")
Making predictions and measuring performance
# make predictions
preds.cluster = predict_clustlasso(X.test, best.model.cluster,
method = "clusterlasso")
# compute performance
perf.cluster = compute_perf(preds.cluster$preds, preds.cluster$probs, y.test)
# print
print(t(perf.cluster$perf))
比较两类方法的结果
比较standard lasso和cluster-lasso 方法在测试集上的预测效能以及特征的区别。
ROC曲线
plot(perf.lasso$roc.curves[[1]], lwd = 2, main = "test set ROC curves")
points(1 - (perf.lasso$perf$speci)/100, perf.lasso$perf$sensi/100, pch = 19, col = 1, cex = 1.25)
plot(perf.cluster$roc.curves[[1]], lwd = 2, col = 2, add = TRUE)
points(1 - (perf.cluster$perf$speci)/100, perf.cluster$perf$sensi/100,
pch = 17, col = 2, cex = 1.25)
grid()
abline(0, 1, lty = 2)
legend("bottomright", c("lasso", "cluster-lasso"), col = c(1, 2), lwd = 2)
特征
heatmap_correlation_signatures(X, best.model.lasso, best.model.cluster,
clust.min = 5, plot.title = "features correlation matrix")
Note: 最上面橘色和蓝色分布表示lasso和cluster-lasso选择出来的特征,两者重叠部分较多。从热图聚类结果看,聚类效果和cluster-lasso分类结果类似。