宏AUROC与微AUROC的区别总结
特性 | 宏AUROC (Macro AUROC) | 微AUROC (Micro AUROC) |
---|---|---|
计算方式 | 对每个类别单独计算AUROC,然后取平均 | 将所有类别的结果展平后计算整体AUROC |
权重 | 每个类别的权重相同 | 样本较多的类别会占更大比重 |
适用场景 | 类别不平衡的数据集 | 类别平衡的数据集 |
偏向 | 强调所有类别的表现 | 偏向样本数量多的类别 |
优势 | 不受类别分布不平衡的影响 | 能较好反映整体样本的分类性能 |
相关代码的部分解读
import torch
from torch import nn # 导入PyTorch的神经网络模块
import lightning.pytorch as lp # 导入Lightning模块
from torch.utils.data import DataLoader # 数据加载器
from torchvision import transforms # 图像数据转换工具
import torch.nn.functional as F # 常用的函数操作,如激活函数
import os # 操作系统相关模块
import subprocess # 执行系统命令
from lightning.pytorch.tuner import Tuner # Lightning中的超参数调优器
from lightning.pytorch.loggers import TensorBoardLogger # 日志记录器
from lightning.pytorch.callbacks import (
ModelCheckpoint, # 模型检查点回调
LearningRateMonitor, # 学习率监控回调
TQDMProgressBar # 进度条显示回调
)
# 导入自定义的模型结构
from clinical_ts.xresnet1d import xresnet1d50, xresnet1d101
from clinical_ts.inception1d import inception1d
from clinical_ts.s4_model import S4Model
# 导入一些自定义工具函数和回调
from clinical_ts.misc_utils import add_default_args, LRMonitorCallback
#######################
# 特定的工具和函数导入
from clinical_ts.timeseries_utils import * # 时间序列相关的工具函数
from clinical_ts.schedulers import * # 学习率调度器
from clinical_ts.eval_utils_cafa import multiclass_roc_curve # 多分类ROC曲线计算
from clinical_ts.bootstrap_utils import * # 自举法工具
# 导入用于MIMIC数据集的预处理函数
from mimic_ecg_preprocessing import prepare_mimic_ecg
# 导入sklearn中的评估指标
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
# 文件路径相关工具
from pathlib import Path
import numpy as np # 数值计算库
import pandas as pd # 数据处理库
# 检查是否可以使用MLflow进行实验记录
MLFLOW_AVAILABLE = True
try:
import mlflow # MLflow库,用于记录实验和模型
import mlflow.pytorch # 与PyTorch集成的MLflow模块
import argparse # 解析命令行参数
# 定义一个函数,将`argparse.Namespace`对象转换为字典
def namespace_to_dict(namespace):
# 遍历命名空间中的所有变量并递归转换
return {
k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
for k, v in vars(namespace).items()
}
except ImportError:
MLFLOW_AVAILABLE = False # 如果无法导入MLflow,设置为不可用
# 获取当前Git仓库的短哈希值,用于追踪代码版本
def get_git_revision_short_hash():
# 返回空字符串(此处简化实现)
return "" # 原始代码中可以通过Git命令获取版本信息
# 多标签编码函数,将标签转换为多热编码形式
def multihot_encode(x, num_classes):
# 初始化一个全0的数组
res = np.zeros(num_classes, dtype=np.float32)
# 将对应类别的位置设置为1
for y in x:
res[y] = 1
return res # 返回多热编码的结果
############################################################################################################
# 定义在全局范围内的函数,避免与Pickle的兼容性问题
# 计算多分类ROC曲线并返回结果的平坦数组形式
def mcrc_flat(targs, preds, classes):
# 计算多分类ROC曲线
_, _, res = multiclass_roc_curve(targs, preds, classes=classes)
# 将结果转为NumPy数组并返回
return np.array(list(res.values()))
# 准备标签的一致性映射(用于多层级的标签结构)
def prepare_consistency_mapping(codes_unique, codes_unique_all, propagate_all=False):
# 初始化空字典
res = {}
# 遍历所有唯一的标签代码
for c in codes_unique:
if propagate_all:
# 如果启用传播,将当前标签的所有前缀加入映射
res[c] = [c[:i] for i in range(3, len(c) + 1)]
else:
# 仅当类别存在时才传播
res[c] = np.intersect1d(
[c[:i] for i in range(3, len(c) + 1)], codes_unique_all
)
return res # 返回映射字典
class Main_ECG(lp.LightningModule):
# 定义一个继承自LightningModule的类,用于ECG(心电图)信号的训练与评估
def __init__(self, hparams):
# 初始化函数,用于初始化模型、损失函数以及其他超参数
super().__init__() # 调用父类的构造函数
self.save_hyperparameters(hparams) # 保存超参数
self.lr = self.hparams.lr # 从超参数中读取学习率
print(hparams) # 打印超参数,方便调试
# 根据微调数据集的类型确定类别数量
if(hparams.finetune_dataset == "thew"):
num_classes = 5
elif(hparams.finetune_dataset == "ribeiro_train"):
num_classes = 6
elif(hparams.finetune_dataset == "ptbxl_super"):
num_classes = 5
elif(hparams.finetune_dataset == "ptbxl_sub"):
num_classes = 24
elif(hparams.finetune_dataset == "ptbxl_all"):
num_classes = 71
elif(hparams.finetune_dataset.startswith("segrhythm")):
num_classes = int(hparams.finetune_dataset[9:]) # 截取类别数量
elif(hparams.finetune_dataset.startswith("rhythm")):
num_classes = int(hparams.finetune_dataset[6:]) # 截取类别数量
elif(hparams.finetune_dataset.startswith("mimic")):
# 使用自定义的`prepare_mimic_ecg`函数加载类别映射
_, lbl_itos = prepare_mimic_ecg(self.hparams.finetune_dataset, Path(self.hparams.data.split(",")[0]))
num_classes = len(lbl_itos) # 类别数量为标签映射的长度
# 根据数据集类型选择损失函数,二分类和多分类的损失函数不同
self.criterion = (
F.cross_entropy if (hparams.finetune_dataset == "thew" or
hparams.finetune_dataset.startswith("segrhythm"))
else F.binary_cross_entropy_with_logits
)
# 根据指定架构初始化模型
if(hparams.architecture == "xresnet1d50"):
self.model = xresnet1d50(input_channels=hparams.input_channels, num_classes=num_classes)
elif(hparams.architecture == "xresnet1d101"):
self.model = xresnet1d101(input_channels=hparams.input_channels, num_classes=num_classes)
elif(hparams.architecture == "inception1d"):
self.model = inception1d(input_channels=hparams.input_channels, num_classes=num_classes)
elif(hparams.architecture == "s4"):
# 初始化S4模型,支持自定义的多层结构
self.model = S4Model(
d_input=hparams.input_channels, d_output=num_classes,
l_max=self.hparams.input_size, d_state=self.hparams.s4_n,
d_model=self.hparams.s4_h, n_layers=self.hparams.s4_layers, bidirectional=True
)
else:
assert False # 若架构无效,触发异常
def forward(self, x, **kwargs):
# 前向传播函数,定义如何通过模型传递数据
# 将NaN值替换为0,避免错误
x[torch.isnan(x)] = 0
return self.model(x, **kwargs) # 返回模型的输出
def on_validation_epoch_end(self):
# 每个验证轮次结束时执行,将所有验证预测和标签传入评估函数
for i in range(len(self.val_preds)):
self.on_valtest_epoch_eval(
{"preds": self.val_preds[i], "targs": self.val_targs[i]},
dataloader_idx=i, test=False
)
# 清空预测和标签列表,准备下一轮
self.val_preds[i].clear()
self.val_targs[i].clear()
def on_test_epoch_end(self):
# 每个测试轮次结束时执行,将所有测试预测和标签传入评估函数
for i in range(len(self.test_preds)):
self.on_valtest_epoch_eval(
{"preds": self.test_preds[i], "targs": self.test_targs[i]},
dataloader_idx=i, test=True
)
# 清空预测和标签列表
self.test_preds[i].clear()
self.test_targs[i].clear()
def eval_scores(self, targs, preds, classes=None, bootstrap=False):
# 评估函数,计算多分类ROC曲线及AUC
_, _, res = multiclass_roc_curve(targs, preds, classes=classes)
if bootstrap:
# 使用自举法计算分数及置信区间
point, low, high, nans = empirical_bootstrap(
(targs, preds), mcrc_flat,
n_iterations=self.hparams.bootstrap_iterations,
score_fn_kwargs={"classes": classes}, ignore_nans=True
)
res2 = {}
# 将结果存入字典
for i, k in enumerate(res.keys()):
res2[k] = point[i]
res2[f"{k}_low"] = low[i]
res2[f"{k}_high"] = high[i]
res2[f"{k}_nans"] = nans[i]
return res2 # 返回自举法结果
return res # 返回普通评估结果
def on_valtest_epoch_eval(self, outputs_all, dataloader_idx, test=False):
# 计算单次验证或测试的评估结果
preds_all = torch.cat(outputs_all["preds"]).cpu()
targs_all = torch.cat(outputs_all["targs"]).cpu()
# 根据任务类型使用softmax或sigmoid进行归一化处理
if self.hparams.finetune_dataset == "thew" or self.hparams.finetune_dataset.startswith("segrhythm"):
preds_all = F.softmax(preds_all.float(), dim=-1)
targs_all = torch.eye(len(self.lbl_itos))[targs_all].to(preds_all.device)
else:
preds_all = torch.sigmoid(preds_all.float())
preds_all = preds_all.numpy()
targs_all = targs_all.numpy()
# 计算未聚合的AUC
res = self.eval_scores(targs_all, preds_all, classes=self.lbl_itos, bootstrap=test)
res = {f"{k}_auc_noagg_{'test' if test else 'val'}{dataloader_idx}": v for k, v in res.items()}
res = {k.replace("(", "_").replace(")", "_"): v for k, v in res.items()} # 避免MLflow错误
self.log_dict(res) # 记录结果
print(f"epoch {self.current_epoch} {'test' if test else 'val'} noagg:", res[f"macro_auc_noagg_{'test' if test else 'val'}{dataloader_idx}"])
# 聚合预测并重新评估
preds_all_agg, targs_all_agg = aggregate_predictions(
preds_all, targs_all,
self.test_idmaps[dataloader_idx] if test else self.val_idmaps[dataloader_idx],
aggregate_fn=np.mean
)
res_agg = self.eval_scores(targs_all_agg, preds_all_agg, classes=self.lbl_itos, bootstrap=test)
res_agg = {f"{k}_auc_agg_{'test' if test else 'val'}{dataloader_idx}": v for k, v in res_agg.items()}
res_agg = {k.replace("(", "_").replace(")", "_"): v for k, v in res_agg.items()}
self.log_dict(res_agg) # 记录聚合后的结果
# 如果测试时需要导出预测结果
if test and self.hparams.export_predictions_path != "":
df_test = pd.read_pickle(Path(self.hparams.export_predictions_path) / f"df_test{dataloader_idx}.pkl")
df_test["preds"] = list(preds_all_agg)
df_test["targs"] = list(targs_all_agg)
df_test.to_pickle(Path(self.hparams.export_predictions_path) / f"df_test{dataloader_idx}.pkl")
print(f"epoch {self.current_epoch} {'test' if test else 'val'} agg:", res_agg[f"macro_auc_agg_{'test' if test else 'val'}{dataloader_idx}"])
def setup(self, stage):
rhythm = self.hparams.finetune_dataset.startswith("rhythm")
# 判断数据集是否为节律(rhythm)相关数据集,如果是,则为多分类任务,需要读取类别数
if rhythm:
num_classes_rhythm = int(hparams.finetune_dataset[6:]) # 提取节律数据集的类别数量
# 配置数据集的参数
chunkify_train = self.hparams.chunkify_train # 是否将训练数据集块化处理
chunk_length_train = int(self.hparams.chunk_length_train * self.hparams.input_size) if chunkify_train else 0
# 计算训练时的块长度,如果不进行块化,则为0
stride_train = int(self.hparams.stride_fraction_train * self.hparams.input_size)
# 计算训练时的步幅
chunkify_valtest = True # 默认验证和测试数据集都进行块化
chunk_length_valtest = self.hparams.input_size if chunkify_valtest else 0
# 验证/测试数据集的块长度
stride_valtest = int(self.hparams.stride_fraction_valtest * self.hparams.input_size)
# 验证/测试数据集的步幅
# 初始化训练、验证和测试的数据集列表
train_datasets = []
val_datasets = []
test_datasets = []
# 初始化数据集的均值、标准差和标签映射为空
self.ds_mean = None
self.ds_std = None
self.lbl_itos = None
# 遍历所有指定的数据集文件夹
for i, target_folder in enumerate(list(self.hparams.data.split(","))):
target_folder = Path(target_folder) # 转换路径为Path对象
# 加载数据集,返回映射数据、标签映射表、均值和标准差
df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
print("Folder:", target_folder, "Samples:", len(df_mapped)) # 打印数据集信息
# 初始化数据集的均值和标准差,如果没有设置过
if self.ds_mean is None:
if self.hparams.finetune_dataset.startswith("rhythm") or self.hparams.finetune_dataset.startswith("segrhythm"):
self.ds_mean = np.array([0., 0.]) # 对节律数据集使用默认均值和标准差
self.ds_std = np.array([1., 1.])
else:
# 使用PTB-XL数据集的默认均值和标准差
self.ds_mean = np.array([-0.00184586, -0.00130277, 0.00017031, -0.00091313,
-0.00148835, -0.00174687, -0.00077071, -0.00207407,
0.00054329, 0.00155546, -0.00114379, -0.00035649])
self.ds_std = np.array([0.16401004, 0.1647168, 0.23374124, 0.33767231,
0.33362807, 0.30583013, 0.2731171, 0.27554379,
0.17128962, 0.14030828, 0.14606956, 0.14656108])
# 如果是PTB-XL数据集,选择不同的标签映射
if self.hparams.finetune_dataset.startswith("ptbxl"):
if self.hparams.finetune_dataset == "ptbxl_super":
ptb_xl_label = "label_diag_superclass"
elif self.hparams.finetune_dataset == "ptbxl_sub":
ptb_xl_label = "label_diag_subclass"
elif self.hparams.finetune_dataset == "ptbxl_all":
ptb_xl_label = "label_all"
# 使用标签映射表,并进行多标签编码
lbl_itos = np.array(lbl_itos[ptb_xl_label])
df_mapped["label"] = df_mapped[ptb_xl_label + "_filtered_numeric"].apply(
lambda x: multihot_encode(x, len(lbl_itos))
)
elif self.hparams.finetune_dataset == "ribeiro_train":
# 过滤出有标签的数据,并进行多标签编码
df_mapped = df_mapped[df_mapped.strat_fold >= 0].copy()
df_mapped["label"] = df_mapped["label"].apply(
lambda x: multihot_encode(x, len(lbl_itos))
)
elif self.hparams.finetune_dataset.startswith("segrhythm"):
# 如果是节律分割任务,过滤出有效类别
num_classes_segrhythm = int(hparams.finetune_dataset[9:])
df_mapped = df_mapped[df_mapped.label.apply(lambda x: x < num_classes_segrhythm)]
lbl_itos = lbl_itos[:num_classes_segrhythm]
elif self.hparams.finetune_dataset.startswith("mimic"):
# 如果是MIMIC数据集,使用自定义函数准备数据
df_mapped, lbl_itos = prepare_mimic_ecg(
self.hparams.finetune_dataset, target_folder, df_mapped=df_mapped
)
# 如果标签映射表为空,则设置标签映射
if self.lbl_itos is None:
self.lbl_itos = lbl_itos[:num_classes_rhythm] if rhythm else lbl_itos
# 根据是否为节律数据集,配置数据转换(如是否进行标签多标签编码)
if rhythm:
if self.hparams.segmentation:
tfms_ptb_xl_cpc = ToTensor(transpose_label=True) # 使用分割任务的标签转换
else:
# 将标签转换为多标签格式
def annotation_to_multilabel(lbl):
lbl_unique = np.unique(lbl) # 提取唯一标签
lbl_unique = [x for x in lbl_unique if x < num_classes_rhythm]
return multihot_encode(lbl_unique, num_classes_rhythm)
tfms_ptb_xl_cpc = transforms.Compose([Transform(annotation_to_multilabel), ToTensor()])
else:
assert not self.hparams.segmentation # 确保非分割任务时不启用分割
tfms_ptb_xl_cpc = ToTensor() if not self.hparams.normalize else transforms.Compose(
[Normalize(self.ds_mean, self.ds_std), ToTensor()]
)
# 根据数据的fold字段划分训练、验证和测试集
max_fold_id = df_mapped.fold.max() # 获取最大fold ID
df_train = df_mapped[df_mapped.fold < max_fold_id - 1]
df_val = df_mapped[df_mapped.fold == max_fold_id - 1]
df_test = df_mapped[df_mapped.fold == max_fold_id]
# 创建TimeseriesDatasetCrops对象,并将其添加到对应的数据集列表中
train_datasets.append(TimeseriesDatasetCrops(
df_train, self.hparams.input_size, data_folder=target_folder,
chunk_length=chunk_length_train, min_chunk_length=self.hparams.input_size,
stride=stride_train, transforms=tfms_ptb_xl_cpc, col_lbl="label",
memmap_filename=target_folder / "memmap.npy"
))
val_datasets.append(TimeseriesDatasetCrops(
df_val, self.hparams.input_size, data_folder=target_folder,
chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size,
stride=stride_valtest, transforms=tfms_ptb_xl_cpc, col_lbl="label",
memmap_filename=target_folder / "memmap.npy"
))
test_datasets.append(TimeseriesDatasetCrops(
df_test, self.hparams.input_size, data_folder=target_folder,
chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size,
stride=stride_valtest, transforms=tfms_ptb_xl_cpc, col_lbl="label",
memmap_filename=target_folder / "memmap.npy"
))
# 如果指定了预测导出路径,则保存标签映射和测试数据集
if self.hparams.export_predictions_path != "":
np.save(Path(self.hparams.export_predictions_path) / "lbl_itos.npy", self.lbl_itos)
df_test.to_pickle(Path(self.hparams.export_predictions_path) / f"df_test{len(test_datasets) - 1}.pkl")
print("\n", target_folder)
print(f"train dataset: {len(train_datasets[-1])} samples")
print(f"val dataset: {len(val_datasets[-1])} samples")
print(f"test dataset: {len(test_datasets[-1])} samples")
# 如果有多个数据集,则将它们合并为一个整体
if len(train_datasets) > 1:
print("\nCombined:")
self.train_dataset = ConcatDatasetTimeseriesDatasetCrops(train_datasets)
self.val_datasets = [ConcatDatasetTimeseriesDatasetCrops(val_datasets)] + val_datasets
# 如果只有一个数据文件夹,则直接设置训练、验证和测试集
else: # just a single data folder
self.train_dataset = train_datasets[0] # 直接使用单个训练数据集
self.val_datasets = val_datasets # 设置验证数据集
self.test_datasets = test_datasets # 设置测试数据集
# 为验证和测试的预测和标签创建空列表
self.val_preds = [[] for _ in range(len(self.val_datasets))]
# 创建与验证数据集数量相同的空列表用于存储验证预测
self.val_targs = [[] for _ in range(len(self.val_datasets))]
# 创建与验证数据集数量相同的空列表用于存储验证标签
self.test_preds = [[] for _ in range(len(self.test_datasets))]
# 创建与测试数据集数量相同的空列表用于存储测试预测
self.test_targs = [[] for _ in range(len(self.test_datasets))]
# 创建与测试数据集数量相同的空列表用于存储测试标签
# 存储id映射,用于聚合预测
self.val_idmaps = [ds.get_id_mapping() for ds in self.val_datasets]
# 为每个验证数据集存储其ID映射,用于聚合预测
self.test_idmaps = [ds.get_id_mapping() for ds in self.test_datasets]
# 为每个测试数据集存储其ID映射
def train_dataloader(self):
# 返回训练数据加载器
return DataLoader(
self.train_dataset, batch_size=self.hparams.batch_size,
num_workers=8, shuffle=True, drop_last=True
)
# 使用PyTorch的DataLoader加载训练数据,启用多线程加速,并打乱数据
def val_dataloader(self):
# 返回验证数据加载器列表,每个数据集一个加载器
return [
DataLoader(ds, batch_size=self.hparams.batch_size, num_workers=8)
for ds in self.val_datasets
]
def test_dataloader(self):
# 返回测试数据加载器列表,每个数据集一个加载器
return [
DataLoader(ds, batch_size=self.hparams.batch_size, num_workers=8)
for ds in self.test_datasets
]
def _step(self, data_batch, batch_idx, train, test=False, dataloader_idx=0):
# 执行单个批次的训练、验证或测试步骤
preds_all = self.forward(data_batch[0]) # 通过模型前向传播得到预测
loss = self.criterion(preds_all, data_batch[1]) # 计算损失
self.log(
"train_loss" if train else ("test_loss" if test else "val_loss"), loss
) # 日志记录损失
# 根据训练/测试阶段,保存预测和标签
if not train and not test:
self.val_preds[dataloader_idx].append(preds_all.detach())
# 保存验证预测
self.val_targs[dataloader_idx].append(data_batch[1])
# 保存验证标签
elif not train and test:
self.test_preds[dataloader_idx].append(preds_all.detach())
# 保存测试预测
self.test_targs[dataloader_idx].append(data_batch[1])
# 保存测试标签
return loss # 返回损失
def training_step(self, train_batch, batch_idx):
# 执行单个训练步骤
return self._step(train_batch, batch_idx, train=True)
def validation_step(self, val_batch, batch_idx, dataloader_idx=0):
# 执行单个验证步骤
return self._step(val_batch, batch_idx, train=False, test=False, dataloader_idx=dataloader_idx)
def test_step(self, test_batch, batch_idx, dataloader_idx=0):
# 执行单个测试步骤
return self._step(test_batch, batch_idx, train=False, test=True, dataloader_idx=dataloader_idx)
def configure_optimizers(self):
# 根据超参数选择优化器
if self.hparams.optimizer == "sgd":
opt = torch.optim.SGD
elif self.hparams.optimizer == "adam":
opt = torch.optim.AdamW
else:
raise NotImplementedError("Unknown Optimizer.") # 不支持的优化器类型
params = self.parameters() # 获取模型参数
# 初始化优化器
optimizer = opt(params, self.lr, weight_decay=self.hparams.weight_decay)
# 根据学习率调度策略设置学习率调度器
if self.hparams.lr_schedule == "const":
scheduler = get_constant_schedule(optimizer)
elif self.hparams.lr_schedule == "warmup-const":
scheduler = get_constant_schedule_with_warmup(
optimizer, self.hparams.lr_num_warmup_steps
)
elif self.hparams.lr_schedule == "warmup-cos":
scheduler = get_cosine_schedule_with_warmup(
optimizer, self.hparams.lr_num_warmup_steps,
self.hparams.epochs * len(self.train_dataloader()), num_cycles=0.5
)
elif self.hparams.lr_schedule == "warmup-cos-restart":
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer, self.hparams.lr_num_warmup_steps,
self.hparams.epochs * len(self.train_dataloader()),
num_cycles=self.hparams.epochs - 1
)
elif self.hparams.lr_schedule == "warmup-poly":
scheduler = get_polynomial_decay_schedule_with_warmup(
optimizer, self.hparams.lr_num_warmup_steps,
self.hparams.epochs * len(self.train_dataloader()),
num_cycles=self.hparams.epochs - 1
)
elif self.hparams.lr_schedule == "warmup-invsqrt":
scheduler = get_invsqrt_decay_schedule_with_warmup(
optimizer, self.hparams.lr_num_warmup_steps
)
elif self.hparams.lr_schedule == "linear":
scheduler = get_linear_schedule_with_warmup(
optimizer, 0, self.hparams.epochs * len(self.train_dataloader())
)
else:
assert False # 不支持的调度策略
# 返回优化器和调度器的配置
return (
[optimizer],
[{'scheduler': scheduler, 'interval': 'step', 'frequency': 1}]
)
def load_weights_from_checkpoint(self, checkpoint):
""" 从检查点文件加载模型权重 """
checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
# 加载检查点文件
pretrained_dict = checkpoint["state_dict"] # 提取权重字典
model_dict = self.state_dict() # 获取模型的当前状态字典
# 过滤出与模型匹配的权重
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict) # 更新模型的状态字典
self.load_state_dict(model_dict) # 加载新的状态字典
def load_state_dict(self, state_dict):
# 兼容S4模型的load_state_dict函数
for name, param in self.named_parameters():
param.data = state_dict[name].data.to(param.device) # 将参数加载到对应设备
for name, param in self.named_buffers():
param.data = state_dict[name].data.to(param.device) # 将缓存也加载到对应设备
######################################################################################################
# MISC
######################################################################################################
def load_from_checkpoint(pl_model, checkpoint_path):
"""
从检查点文件加载权重的函数,兼容 S4 模型。
"""
# 加载检查点文件,将其映射到合适的存储位置(如 CPU 或 GPU)
lightning_state_dict = torch.load(checkpoint_path)
# 从检查点中提取出模型的状态字典
state_dict = lightning_state_dict["state_dict"]
# 遍历模型中的所有参数
for name, param in pl_model.named_parameters():
# 将检查点中的参数数据加载到当前模型的参数中
param.data = state_dict[name].data
# 遍历模型中的所有缓冲区(如 BatchNorm 的均值和方差等)
for name, param in pl_model.named_buffers():
# 将检查点中的缓存数据加载到当前模型的缓存区中
param.data = state_dict[name].data
#####################################################################################################
# ARGPARSER:定义与模型和应用相关的命令行参数解析器函数
#####################################################################################################
# 定义与模型结构相关的参数
def add_model_specific_args(parser):
# 添加参数 --input-channels:输入的通道数,默认为12(如ECG信号通常有12导联)
parser.add_argument("--input-channels", type=int, default=12)
# 添加参数 --architecture:指定模型的架构,可以选择 xresnet1d50, xresnet1d101, inception1d, s4
# 默认为 xresnet1d50
parser.add_argument("--architecture", type=str, help="xresnet1d50/xresnet1d101/inception1d/s4", default="xresnet1d50")
# 添加参数 --s4-n:S4模型的参数N,默认为8(Sashimi论文中的默认值为64)
parser.add_argument("--s4-n", type=int, default=8, help='S4: N (Sashimi default:64)')
# 添加参数 --s4-h:S4模型的隐藏层大小H,默认为512(Sashimi中的默认值为64)
parser.add_argument("--s4-h", type=int, default=512, help='S4: H (Sashimi default:64)')
# 添加参数 --s4-layers:S4模型中的层数,默认为4(Sashimi默认值为8)
parser.add_argument("--s4-layers", type=int, default=4, help='S4: number of layers (Sashimi default:8)')
# 添加参数 --s4-batchnorm:布尔型参数,使用BatchNorm而不是LayerNorm,如果添加该参数则启用BatchNorm
parser.add_argument("--s4-batchnorm", action='store_true', help='S4: use BN instead of LN')
# 添加参数 --s4-prenorm:布尔型参数,启用S4的prenorm(预归一化)机制
parser.add_argument("--s4-prenorm", action='store_true', help='S4: use prenorm')
# 返回解析器对象,包含所有添加的模型相关参数
return parser
# 定义与应用逻辑相关的参数
def add_application_specific_args(parser):
# 添加参数 --normalize:布尔型参数,如果启用则根据数据集的统计信息对输入进行归一化
parser.add_argument("--normalize", action='store_true', help='Normalize input using dataset stats')
# 添加参数 --finetune-dataset:指定微调使用的数据集,默认值为 ptbxl_all
parser.add_argument("--finetune-dataset", type=str, help="...", default="ptbxl_all")
# 添加参数 --chunk-length-train:设置训练时每次读取数据块的长度(以输入大小的倍数表示),默认值为1.0
parser.add_argument("--chunk-length-train", type=float, default=1., help="training chunk length in multiples of input size")
# 添加参数 --stride-fraction-train:训练时的数据滑动步幅(以输入大小的倍数表示),默认值为1.0
parser.add_argument("--stride-fraction-train", type=float, default=1., help="training stride in multiples of input size")
# 添加参数 --stride-fraction-valtest:验证和测试时的数据滑动步幅(以输入大小的倍数表示),默认值为1.0
parser.add_argument("--stride-fraction-valtest", type=float, default=1., help="val/test stride in multiples of input size")
# 添加参数 --chunkify-train:布尔型参数,启用训练时的数据块化处理(chunkify)
parser.add_argument("--chunkify-train", action='store_true')
# 添加参数 --segmentation:布尔型参数,启用分割任务(segmentation)
parser.add_argument("--segmentation", action='store_true')
# 添加参数 --eval-only:用于指定模型检查点的路径,仅进行模型评估,默认为空字符串
parser.add_argument("--eval-only", type=str, help="path to model checkpoint for evaluation", default="")
# 添加参数 --bootstrap-iterations:设置用于分数估计的自举(bootstrap)迭代次数,默认值为1000
parser.add_argument("--bootstrap-iterations", type=int, help="number of bootstrap iterations for score estimation", default=1000)
# 添加参数 --export-predictions-path:设置导出预测结果的路径,默认为空字符串(即不导出)
parser.add_argument("--export-predictions-path", type=str, default="", help="path to directory to export predictions")
# 返回解析器对象,包含所有添加的应用相关参数
return parser
###################################################################################################
#MAIN
###################################################################################################
###################################################################################################
# MAIN: 这部分是脚本的入口
###################################################################################################
if __name__ == '__main__':
parser = add_default_args() # 调用自定义函数`add_default_args()`,初始化参数解析器
parser = add_model_specific_args(parser) # 添加与模型相关的命令行参数
parser = add_application_specific_args(parser) # 添加与应用程序逻辑相关的命令行参数
hparams = parser.parse_args() # 解析传入的命令行参数并存储在`hparams`对象中
hparams.executable = "main_ecg" # 设置可执行文件名为`main_ecg`
hparams.revision = get_git_revision_short_hash() # 获取当前Git的短哈希值,跟踪代码版本
# 如果只用于评估(eval-only),将训练的epoch数设为0
if(hparams.eval_only != ""):
hparams.epochs = 0
# 如果指定的输出路径不存在,则创建该路径
if not os.path.exists(hparams.output_path):
os.makedirs(hparams.output_path)
# 初始化`Main_ECG`模型,并将超参数传入
model = Main_ECG(hparams)
# 初始化TensorBoard日志记录器,设置日志保存的目录
logger = TensorBoardLogger(
save_dir=hparams.output_path, # 设置日志的保存路径
name="" # 在TensorBoard中没有给特定的日志名称
)
# 输出日志的路径
print("Output directory:", logger.log_dir)
# 如果MLflow可用,设置MLflow的实验,并启用自动记录模型的参数和指标
if(MLFLOW_AVAILABLE):
mlflow.set_experiment(hparams.executable) # 设置MLflow实验名称
mlflow.pytorch.autolog(log_models=False) # 自动记录,但不记录模型本身
# 配置模型检查点保存的回调函数
checkpoint_callback = ModelCheckpoint(
dirpath=logger.log_dir, # 设置模型保存路径
filename="best_model", # 模型文件的名称
save_top_k=1, # 保存表现最好的一个模型
save_last=True, # 保存最后一个epoch的模型
verbose=True, # 打印保存信息
monitor="macro_auc_agg_val0", # 监控`macro_auc_agg_val0`指标
mode='max' # 模型指标越大越好
)
# 学习率监控器回调,记录每步的学习率
lr_monitor = LearningRateMonitor(logging_interval="step")
# 定义回调函数列表,包含模型检查点和学习率监控器
callbacks = [checkpoint_callback, lr_monitor]
# 如果需要显示进度条,则添加TQDM进度条回调
if(hparams.refresh_rate > 0):
callbacks.append(TQDMProgressBar(refresh_rate=hparams.refresh_rate))
# 初始化PyTorch Lightning的训练器`Trainer`,传入超参数和回调函数
trainer = lp.Trainer(
num_sanity_val_steps=0, # 不进行验证集的检查步骤(调试用)
accumulate_grad_batches=hparams.accumulate, # 梯度累积的步数
max_epochs=hparams.epochs, # 最大训练轮数
min_epochs=hparams.epochs, # 最小训练轮数
default_root_dir=hparams.output_path, # 设置训练的根目录
logger=logger, # 使用初始化的TensorBoard日志记录器
callbacks=callbacks, # 使用指定的回调函数列表
benchmark=True, # 开启性能基准测试
accelerator="gpu" if hparams.gpus > 0 else "cpu", # 判断是否使用GPU
devices=hparams.gpus if hparams.gpus > 0 else 1, # 设置使用的GPU数量
num_nodes=hparams.num_nodes, # 使用的节点数量
precision=hparams.precision, # 设置计算精度(16位或32位)
enable_progress_bar=hparams.refresh_rate > 0 # 控制是否启用进度条
)
# 如果启用自动批次大小调整,使用Tuner调整批次大小
if(hparams.auto_batch_size):
tuner = Tuner(trainer) # 初始化Tuner
tuner.scale_batch_size(model, mode="binsearch") # 二分搜索找到合适的批次大小
# 如果启用学习率查找,则使用Tuner查找最优学习率
if(hparams.lr_find):
tuner = Tuner(trainer) # 初始化Tuner
lr_finder = tuner.lr_find(model) # 查找学习率
# 如果有训练任务并且不只是用于评估
if(hparams.epochs > 0 and hparams.eval_only == ""):
# 如果MLflow可用,则在MLflow中开始一个新的运行
if(MLFLOW_AVAILABLE):
with mlflow.start_run(run_name=hparams.metadata) as run:
# 记录超参数到MLflow
for k, v in dict(hparams._get_kwargs()).items():
mlflow.log_param(k, " " if v == "" else v) # 处理空字符串的特殊情况
# 开始训练,并保存表现最好的模型
trainer.fit(model, ckpt_path=None if hparams.resume == "" else hparams.resume)
# 在测试集上测试模型
trainer.test(model, ckpt_path="best")
else:
# 如果没有MLflow,直接训练和测试模型
trainer.fit(model, ckpt_path=None if hparams.resume == "" else hparams.resume)
trainer.test(model, ckpt_path="best")
# 如果指定了`eval_only`参数,只进行测试
elif(hparams.eval_only != ""):
# 如果MLflow可用,在MLflow中记录测试运行
if(MLFLOW_AVAILABLE):
with mlflow.start_run(run_name=hparams.metadata) as run:
# 记录超参数到MLflow
for k, v in dict(hparams._get_kwargs()).items():
mlflow.log_param(k, " " if v == "" else v) # 处理空字符串的特殊情况
# 测试模型,使用指定的检查点文件
trainer.test(model, ckpt_path=hparams.eval_only)
else:
# 如果没有MLflow,直接测试模型
trainer.test(model, ckpt_path=hparams.eval_only)
改进思路
1.外部验证
2.为什么用2.5秒
3.只用宏AUROC来评测对罕见病具有局限性
4.加入更多的临床元数据,比如血常规
今日进展:
先研究数据集构建
相关代码部分:
import numpy as np # 导入NumPy库,用于数值计算
import pandas as pd # 导入Pandas库,用于数据处理
def multihot_encode(x, num_classes):
# 创建一个长度为num_classes的全0数组,数据类型为float32
res = np.zeros(num_classes, dtype=np.float32)
# 遍历x中的每个标签,将对应位置设置为1
for y in x:
res[y] = 1
# 返回多热编码后的数组
return res
############################################################################################################
def prepare_consistency_mapping(codes_unique, codes_unique_all, propagate_all=False):
# 初始化一个空字典,用于存储一致性映射结果
res = {}
# 遍历所有唯一的代码
for c in codes_unique:
if propagate_all:
# 如果启用全传播,将当前标签的所有前缀存入字典
res[c] = [c[:i] for i in range(3, len(c) + 1)]
else:
# 否则,仅当前缀在所有代码中存在时才映射
res[c] = np.intersect1d(
[c[:i] for i in range(3, len(c) + 1)], codes_unique_all
)
# 返回映射结果字典
return res
def prepare_mimic_ecg(finetune_dataset, target_folder, df_mapped=None, df_diags=None):
'''解析微调数据集的参数,并加载MIMIC数据集'''
# 定义一个工具函数,用于将嵌套列表展平
def flatten(l):
return [item for sublist in l for item in sublist]
# 保存初始的诊断数据框
df_diags_initial = df_diags
# 从微调数据集名称中提取各部分
subsettrain = finetune_dataset.split("_")[1]
labelsettrain = finetune_dataset.split("_")[2]
subsettest = finetune_dataset.split("_")[3]
labelsettest = finetune_dataset.split("_")[4]
min_cnt = int(finetune_dataset.split("_")[5])
# 检查是否有位数限制
if len(finetune_dataset.split("_")) < 7:
digits = None
propagate_all = False
else:
digits = finetune_dataset.split("_")[6]
if digits[-1] == "A":
propagate_all = True
digits = int(digits[:-1])
else:
propagate_all = False
digits = int(digits)
# 如果诊断数据已传入,则使用它,否则从文件中加载
if df_diags is not None:
df_diags = df_diags
else:
if (target_folder / "records_w_diag_icd10.pkl").exists():
df_diags = pd.read_pickle(target_folder / "records_w_diag_icd10.pkl")
else:
df_diags = pd.read_csv(target_folder / "records_w_diag_icd10.csv")
df_diags.drop('Unnamed: 0', axis=1, inplace=True)
df_diags['ecg_time'] = pd.to_datetime(df_diags["ecg_time"])
df_diags['dod'] = pd.to_datetime(df_diags["dod"])
for c in ['ed_diag_ed', 'ed_diag_hosp', 'hosp_diag_hosp', 'all_diag_hosp', 'all_diag_all']:
df_diags[c] = df_diags[c].apply(lambda x: eval(x))
# 根据用户选择的标签集设置训练和测试标签
if labelsettrain.startswith("hosp"):
df_diags["label_train"] = df_diags["all_diag_hosp"]
labelsettrain = labelsettrain[len("hosp"):]
elif labelsettrain.startswith("ed"):
df_diags["label_train"] = df_diags["ed_diag_ed"]
labelsettrain = labelsettrain[len("ed"):]
elif labelsettrain.startswith("all"):
df_diags["label_train"] = df_diags["all_diag_all"]
labelsettrain = labelsettrain[len("all"):]
else:
assert False
if labelsettest.startswith("hosp"):
df_diags["label_test"] = df_diags["all_diag_hosp"]
elif labelsettest.startswith("ed"):
df_diags["label_test"] = df_diags["ed_diag_ed"]
elif labelsettest.startswith("all"):
df_diags["label_test"] = df_diags["all_diag_all"]
else:
assert False
# 检查每个样本是否有ICD标签
df_diags["has_statements_train"] = df_diags["label_train"].apply(lambda x: len(x) > 0)
df_diags["has_statements_test"] = df_diags["label_test"].apply(lambda x: len(x) > 0)
# 如果指定了位数限制,则截断ICD代码
if digits is not None:
df_diags["label_train"] = df_diags["label_train"].apply(
lambda x: list(set([y.strip()[:digits] for y in x]))
)
df_diags["label_test"] = df_diags["label_test"].apply(
lambda x: list(set([y.strip()[:digits] for y in x]))
)
# 删除标签中的占位符X
df_diags["label_train"] = df_diags["label_train"].apply(lambda x: list(set([y.rstrip("X") for y in x])))
df_diags["label_test"] = df_diags["label_test"].apply(lambda x: list(set([y.rstrip("X") for y in x])))
# 根据用户指定的标签集进行过滤
if labelsettrain == "af":
df_diags["label_train"] = df_diags["label_train"].apply(lambda x: [c for c in x if c.startswith("I48")])
elif labelsettrain != "":
df_diags["label_train"] = df_diags["label_train"].apply(lambda x: [c for c in x if c[0] in labelsettrain])
# 使用一致性映射处理标签
col_flattrain = flatten(np.array(df_diags["label_train"]))
cons_maptrain = prepare_consistency_mapping(np.unique(col_flattrain), np.unique(col_flattrain), propagate_all)
df_diags["label_train"] = df_diags["label_train"].apply(
lambda x: list(set(flatten([cons_maptrain[y] for y in x])))
)
# 多热编码训练和测试标签
lbl_itos = np.unique(flatten(df_diags["label_train"]))
lbl_stoi = {s: i for i, s in enumerate(lbl_itos)}
df_diags["label_train"] = df_diags["label_train"].apply(
lambda x: multihot_encode([lbl_stoi[y] for y in x if y in lbl_itos], len(lbl_itos))
)
# 返回处理后的数据和标签映射
return df_diags, lbl_itos
下载好了有关的数据集
标记一下各文件夹的作用
edstays文件夹:
病人住院情况在 edstays 表中跟踪。edstays 表中的每一行都有一个唯一的 stay_id,代表急诊室中唯一的病人住院时间。edstays 表包含以下列:subject_id、hadm_id、stay_id、intime 和 outtime。住院时间(intime)表示患者被急诊室收治的时间,出院时间(outtime)表示患者从急诊室出院的时间。如果患者在急诊室住院后又入院,则 hadm_id 栏将填入代表其住院时间的标识符。hadm_id 可与 MIMIC-IV 中的 hadm_id 相链接,以获取患者住院时间的更多详情。最后,每个人都有一个唯一的 subject_id,在 EDstays 表中,有多次 ED 住院经历的病人在不同的住院经历中会有相同的 subject_id。请注意,subject_id 可以与 MIMIC-IV 链接,以获取病人的人口统计数据。subject_id 还可以与 MIMIC-CXR 中的 PatientID DICOM 属性链接,以获取病人的胸部 X 光片(如果拍过的话)。
diagnosis文件夹:
诊断表提供病人在国际疾病分类(ICD)第九版或第十版(ICD-9 或 ICD-10)中的编码诊断。这些诊断由训练有素的编码员在患者出院后确定,用于医院计费。诊断表中有六列:subject_id、stay_id、seq_num、icd_code、icd_version 和 icd_title。单次住院最多可使用 9 个 ICD 代码。seq_num 列提供了 ICD 代码的伪排序,值 1 通常表示相关性最高,值 9 表示相关性最低。icd_code 列提供使用 ICD 本体的诊断编码表示,icd_version 列为 9 或 10,表示使用的本体是 ICD-9 还是 ICD-10,icd_title 列提供 ICD 代码的文字描述。
medrecon文件夹:
medrecon 表提供每位患者的药物对账,即患者在急诊室住院前服用的药物列表。
medrecon 表有九列:subject_id、stay_id、charttime、name、gsn、ndc、etc_rn、etccode 和 etcdescription。图表时间提供记录药品调节的日期和时间。name 栏提供药品的文字描述,gsn 栏提供通用序列号 (GSN),ndc 栏提供国家药品代码 (NDC)。请注意,如果 gsn 或 ndc 为 0,则表示该值缺失。以 etc 为前缀的列提供了一个本体,用于将类似类别的药品归为一类。请注意,由于一种药物可在本体中分为多个组,因此一种药物可能有不止一行。例如,药物 Adderal (1) 是一种中枢神经系统兴奋剂,(2) 是一种注意力缺陷-多动疗法,(3) 是一种嗜睡症疗法。因此,入院前服用阿德拉的患者在 medrecon 表中会有三行,分别由顺序单调递增的整数 etc_rn 划分。etccode 提供本体组的编码形式,etcdescription 提供本体组的文本描述。
pyxis文件夹:
pyxis 表提供急诊室自动配药系统 BD Pyxis MedStation 的配药信息。
pyxis 表有九列:subject_id、stay_id、charttime、med_rn、name、gsn_rn 和 gsn。图表时间提供了配药的时间。如果同时配发了多种药物,med_rn 列会对这些药物进行划分。name(名称)列提供配药的文字描述,还可能包含配方等辅助信息。gsn 列提供通用序列号(GSN)(如果有的话),gsn_rn 列出与同一种药品相关的多个 GSN 值。请注意,gsn 为 0 表示缺少 GSN。
triage文件夹:
分诊表提供分诊时收集的病人信息。
所有到急诊室就诊的病人都会立即被分流,这一过程包括评估病人的健康状况和确定就诊原因。分诊表有 11 列:subject_id、stay_id、temperature、heartrate、respirrate、o2sat、sbp、dbp、pain、acuity 和 chiefcomplaint。分诊时收集的生命体征包括患者体温(temperature)、心率(heartrate)、呼吸频率(resprate)、血氧饱和度(o2sat)、收缩压(sbp)和舒张压(dbp)。虽然生命体征可以自由文本形式记录,但去标识化方法只保留了数字生命体征。疼痛一栏提供了患者报告的疼痛程度。主诉(chiefcomplaint)是一个自由文本字段,包含患者报告的到急诊室就诊的原因。主诉字段通常是以逗号分隔的条目列表。主诉字段中的 PHI 已被三个下划线(‘____’)取代。根据分诊评估,护理提供者将指定一个整数的严重程度(敏锐度),其中 1 表示最高严重程度,5 表示最低严重程度。
vitalsign文件夹:
vitalsign 表包含病人住院期间记录的非周期性生命体征。vitalsign 表有 11 列:subject_id、stay_id、charttime、temperature、heartrate、respirrate、o2sat、sbp、dbp、rhythm 和 pain。生命体征表中的生命体征与分诊表中收集的生命体征类似。心律栏还提供了病人的心律。图表时间提供了记录生命体征的时间。