基于数据增强与对抗学习的门诊电子病历(EMR)文本分类python编程

1. 引言

随着医疗信息化建设的深入,电子病历已成为记录患者诊疗活动最核心的数据载体。其中,门诊EMR以其海量、实时的特性,蕴含着巨大的临床科研和应用价值。通过对门诊EMR中的主诉、现病史、既往史等非结构化文本进行自动分类,可以辅助医生进行快速诊断、实现智能导诊、开展疾病流行病学研究,对提升医疗服务效率和质量具有重要意义。

然而,门诊EMR文本分类任务面临着四大核心挑战:(1) 数据稀缺与长尾分布:高质量、标注准确的医疗文本数据获取成本高,且常见病与罕见病之间的样本数量差异巨大,导致模型偏向于头部类别,对尾部类别识别能力差。(2) 文本噪声严重:门诊场景下,医生记录时间紧迫,文本往往高度口语化(如“肚子疼”)、大量使用缩写(如“BP”代表血压)、错别字(如“咳簌”),并受电子病历模板影响,包含大量冗余结构化符号,对模型的文本理解能力构成考验。(3) 域内外分布漂移:不同科室、不同医院甚至不同时间段的医生在书写习惯、疾病谱构成上均存在差异,导致训练集与测试集之间的数据分布不一致,模型的泛化能力受限。(4) 对抗扰动的脆弱性:深度神经网络被证明对输入的微小扰动非常敏感,一个字符的错误或一个同义词的替换都可能导致模型输出完全错误的预测,这在性命攸关的医疗领域是不可接受的。

为了应对上述挑战,研究者们从不同角度提出了多种解决方案。在数据层面,数据增强技术通过生成与原始数据语义相似但形式不同的样本,有效扩充了训练集,尤其有助于缓解数据稀缺和过拟合问题。在模型层面,预训练语言模型如BERT、RoBERTa,通过在海量通用文本上学习,获得了强大的语言表征能力,已成为各类NLP任务的标杆。然而,单纯的PLM在面对医疗领域特有的噪声和分布漂移时,其鲁棒性仍有待提升。对抗训练,作为一种提升模型稳定性的正则化手段,通过在训练过程中构造并抵御“最坏情况”下的微小扰动,已被证明能有效增强模型的鲁棒性。

尽管已有研究分别将数据增强或对抗训练应用于医疗文本分类,但鲜有工作系统性地将二者结合,并针对中文门诊EMR的短文本特性进行深度定制和优化。本研究的主要贡献如下:

构建了一个专门化的混合增强框架:结合了基于医学知识的规则增强与基于预训练模型的生成式增强,并创新性地引入MixUp-Text于多标签医疗分类任务,有效提升了样本多样性。系统性地应用并优化了对抗训练:采用FreeLB作为核心对抗训练策略,详细分析了其在中文医疗文本分类任务上的效果,并与FGSM等方法进行了对比。进行了全面的鲁棒性与泛化能力评估:除了在标准测试集上评估性能,还设计了包含自然噪声、合成对抗样本和跨域数据的测试集,深入探究了模型在接近真实临床复杂环境下的表现。提供了可复现的实验方案与伦理实践:详细阐述了数据匿名化、患者级别的数据划分以及合规流程,为后续研究提供了严谨的参考。

本文的组织结构如下:第二章介绍相关研究工作;第三章详细阐述本文提出的混合数据增强与对抗学习框架;第四章描述实验设置、数据集和评估指标;第五章展示并分析实验结果,包括主实验、消融实验和鲁棒性测试;第六章讨论研究结果、局限性及未来方向;第七章强调伦理与合规;第八章对全文进行总结。


2. 相关研究

2.1 医疗文本分类

早期的医疗文本分类方法主要依赖于传统的机器学习模型,如支持向量机(SVM)和逻辑回归(LR),配合TF-IDF或N-grams等手工设计的特征。这类方法在小规模、特定领域的数据上表现尚可,但难以捕捉文本的深层语义。随着深度学习的发展,循环神经网络(RNN)、LSTM及其双向变体BiLSTM因能处理序列信息而被广泛应用。Kim提出的TextCNN则通过卷积核捕捉局部关键特征,在短文本分类中表现出色。然而,这些模型均需从头开始训练,对数据量要求高,且泛化能力有限。

近年来,以BERT为代表的预训练语言模型彻底改变了NLP领域。BERT通过“掩码语言模型”和“下一句预测”任务在海量文本上进行预训练,学习到通用的语言知识。针对中文任务,有研究者提出了BERT-wwm(全词掩码)及其后续的RoBERTa-wwm-ext,在中文语料上表现更佳。这些PLM在医疗文本分类任务中也迅速成为主流基线,研究者们通常通过在领域内语料上进行二次预训练或直接微调来获得更好的性能。

2.2 数据增强技术

数据增强是解决数据稀缺问题的关键技术。在NLP领域,增强方法主要分为规则驱动和模型驱动两类。

规则驱动的数据增强简单高效,主要包括:

同义词替换:利用同义词词典(如WordNet、哈工大同义词林)或词向量空间中的近邻词进行替换。EDA(Easy Data Augmentation):包括随机插入、随机删除、随机交换等操作,通过微小改变来创建新样本。Wei和Zou(2019)验证了其在文本分类中的有效性。回译:将文本翻译到一种中间语言(如英语),再翻译回原语言。由于翻译模型并非完全可逆,回译后的文本通常会与原文在措辞上有所差异,但语义保持一致。

模型驱动的数据增强通常能生成更自然、更流畅的文本:

基于生成模型的增强:利用序列到序列模型(如T5、BART)或大型语言模型(如GPT系列),通过设计特定的指令来生成改写或扩展后的文本。基于预训练模型的增强:例如,利用BERT自身的MLM头,随机遮盖部分词并让模型预测填充,生成新的句子变体。MixUp及其变体:最初用于图像分类,通过线性混合两个样本及其标签(
x_new = λx_1 + (1-λ)x_2
,
y_new = λy_1 + (1-λ)y_2
)来平滑决策边界。Verma等(2019)将其引入NLP,通常在模型的词嵌入层或特征层进行混合。

在医疗领域,数据增强需要谨慎进行,以保证生成的文本在医学上合理。一些工作结合医学术语库进行同义词替换,或利用领域特定的生成模型来增强临床笔记。

2.3 对抗训练

对抗训练最初由Goodfellow等人(2014)在图像领域提出,用于提升模型对对抗样本的鲁棒性。对抗样本是指通过对原始输入添加人眼难以察觉的微小扰动而构造出的,能让模型产生错误判断的样本。FGSM(Fast Gradient Sign Method)是一种单步对抗攻击方法,通过计算损失函数对输入的梯度,并沿梯度方向添加一步扰动。PGD(Projected Gradient Descent)是FGSM的多步迭代版本,通过多次小步长扰动来寻找更强的对抗样本。

在NLP领域,由于输入是离散的文本,直接在原始文本空间施加扰动困难重重。因此,主流做法是在词嵌入层进行对抗训练。通过向词嵌入向量添加连续扰动,可以达到类似的效果。FreeLB(Adversarial Training with Free Large-Batch)是一种高效的对抗训练方法,它在每次前向-反向传播中执行多步对抗更新,但其“自由”之处在于这些更新步骤不产生梯度回传给原始模型参数,只有在对抗路径的末端才计算一次梯度用于更新模型,从而在提升鲁棒性的同时保持了训练效率。VAT(Virtual Adversarial Training)则是一种无监督的对抗训练方法,通过寻找能使模型输出分布发生最大变化的扰动,迫使模型在局部邻域内保持输出一致性。SMART(Self-supervised Margin-wise Adversarial Training)结合了VAT和对比学习的思想,通过最小化对抗扰动下模型输出的KL散度来增强模型的平滑性。


3. 方法

本研究的总体框架如图1所示,包含数据预处理、混合数据增强、对抗训练模型和损失函数四个核心模块。数据首先经过清洗和匿名化,然后通过混合增强策略扩充训练集。在模型训练阶段,采用RoBERTa-wwm-ext作为基线,并融入FreeLB对抗训练。损失函数结合了交叉熵和Focal Loss,以应对多标签和长尾分布问题。

3.1 任务定义

给定一个门诊EMR短文本样本
X
,该样本可能包含一个或多个诊断标签
Y
,标签集合为
L = {l_1, l_2, ..., l_K}
。本任务的目标是学习一个映射函数
f: X → [0, 1]^K
,使得对于每个标签
l_k
,输出
f(X)_k
表示文本
X
属于该标签的概率。这是一个典型的多标签文本分类问题,通常通过将输出层设置为
K
个独立的Sigmoid单元,并使用二元交叉熵损失进行训练。

3.2 基线模型:RoBERTa-wwm-ext

我们选择
hfl/chinese-roberta-wwm-ext
作为基线模型。RoBERTa通过移除BERT的下一句预测任务、使用更大的批量、更长的训练时间和动态掩码策略,在多项NLP任务上超越了BERT。
wwm-ext
版本在中文语料上进行了全词掩码的扩展预训练,对中文词汇的表征更为精准。其模型结构由多层双向Transformer编码器堆叠而成,能有效地捕捉文本的上下文依赖关系。对于分类任务,我们在RoBERTa的
[CLS]
标志位对应的最终隐藏状态之上,连接一个全连接层,输出
K
维的 logits。

3.3 混合数据增强策略

为应对数据稀缺和噪声多样性,我们设计了多层次的混合数据增强策略,对每个训练样本按概率应用一种或多种增强技术。

3.3.1 规则增强(低成本)

医学同义替换:我们构建了一个包含症状、体征、药品、检查项目等类别的医学同义词典。例如,将“发热”替换为“发烧”,“腹痛”替换为“肚子疼”。对于不在词典中的词,可以借助词向量(如腾讯AI Lab词向量)寻找Top-N近义词作为候选。EDA轻量版:针对门诊EMR短文本的特性,我们调低了EDA操作的幅度以避免语义扭曲。具体参数设置为:随机删除概率
p_del=0.1
,随机交换概率
p_swap=0.1
,随机插入概率
p_ins=0.1
。插入的词从原文词表中随机选取。实体占位符替换:为了防止模型过拟合于特定的人名、医院名、床号等,我们使用正则表达式识别这些实体,并用统一的占位符如
[PATIENT_NAME]

[HOSPITAL]

[BED_NUM]
进行替换。这有助于模型关注与疾病相关的核心内容。口语化与缩写注入:为了模拟真实文本噪声,我们以一定概率(如0.05)将正式词汇替换为口语表达或常见缩写。例如,“血压”替换为“BP”,“2天”替换为“2d”,“体温38.5摄氏度”替换为“T 38.5”。这一规则表需要根据实际数据进行维护。

3.3.2 模型增强(高质量)

回译:我们采用
zh → en → zh
的回译路径。使用Google Translate API或其他成熟的翻译引擎。为避免循环翻译导致的质量下降,我们设置了一个BLEU相似度阈值(如 > 0.7),仅保留与原文相似度在阈值以上的回译结果作为增强样本。掩码语言模型(MLM)增强:利用RoBERTa自身的MLM能力。对输入文本,随机遮盖10%-15%的词,然后让模型预测这些遮盖位,将预测结果(取Top-1)替换回原文,形成新样本。这种方法生成的文本与原始模型的语义空间高度对齐。MixUp-Text:这是提升模型平滑度和泛化能力的关键。我们从mini-batch中随机选取两个样本
X_i

X_j
及其对应的标签向量
Y_i

Y_j
。在词嵌入层进行线性混合:

X_mix = λ * Embedding(X_i) + (1-λ) * Embedding(X_j)


Y_mix = λ * Y_i + (1-λ) * Y_j

其中
λ
是从
Beta(α, α)
分布中采样的混合系数(α通常设为0.2或0.4)。然后将
X_mix
输入模型,预测结果与
Y_mix
计算损失。在多标签场景下,
Y_mix
是一个连续向量,模型通过优化回归损失来学习标签之间的平滑过渡。

3.3.3 增强策略组合

在实际训练中,我们不固定使用一种增强方法,而是构建一个“增强工具箱”。对每个batch的数据,按预设的概率分布(如:规则增强 40%,回译 20%,MLM 20%,MixUp 20%)动态选择增强方式,实现了多样化的样本生成,避免了单一增强方式可能带来的偏差。

3.4 对抗训练策略:FreeLB

为了提升模型对微小扰动的鲁棒性,我们采用FreeLB(Free Large-Batch Adversarial Training)作为核心对抗训练策略。它在词嵌入空间中迭代地寻找“最坏情况”的扰动,然后基于被扰动的嵌入来更新模型参数。

算法流程如下:

初始化:对于输入的文本序列,通过Tokenizer得到
input_ids

attention_mask
。获取初始词嵌入
E_0 = Embedding(input_ids)
对抗扰动生成
初始化扰动
δ_0 = 0
。进行
K
步(
K
通常为3-7)迭代:
计算当前被扰动的嵌入:
E_k = E_0 + δ_{k-1}
。将
E_k

attention_mask
输入RoBERTa模型,得到logits
logits_k
。计算对抗损失
L_adv = CrossEntropy(logits_k, Y)
。计算损失对当前扰动的梯度:
g_k = ∇_{δ_{k-1}} L_adv
。更新扰动:
δ_k = δ_{k-1} + ε * sign(g_k)
。其中
ε
是控制扰动幅度的超参数。将
δ_k
投影到L_∞球内,即
δ_k = clamp(δ_k, -ε, ε)

模型参数更新
使用最终得到的被扰动嵌入
E_K
进行前向传播,计算最终损失
L_final

L_final
相对于原始模型参数
θ
进行反向传播,执行一步优化器更新(如Adam)。

FreeLB的关键优势在于,对抗扰动的
K
步迭代是在一个计算图中完成的,并且梯度仅用于更新
δ
,而不影响模型参数
θ
。这相当于在每次参数更新前,都对模型进行了一次“内部压力测试”,使其学习到更加平滑和鲁棒的表征。我们将最终损失函数定义为原始损失和对抗损失的加权和:

L = (1 - α) * L_clean + α * L_adv

其中
L_clean
是基于原始嵌入
E_0
计算的损失,
α
是平衡超参数。

3.5 损失函数:处理多标签与长尾分布

我们的基线损失是多标签二元交叉熵:

L_BCE = - (1/N) * Σ_i Σ_k [y_{ik} * log(p_{ik}) + (1 - y_{ik}) * log(1 - p_{ik})]

为了缓解长尾分布问题,我们引入了Focal Loss。Focal Loss通过降低易于分类样本(高置信度)的权重,让模型更关注于难分类的、通常是尾部类别的样本。对于多标签任务,其形式为:

L_FL = - (1/N) * Σ_i Σ_k [α * (1-p_{ik})^γ * y_{ik} * log(p_{ik}) + (1-α) * p_{ik}^γ * (1-y_{ik}) * log(1-p_{ik})]

其中
γ
是聚焦参数(通常取1或2),
p_{ik}
是模型预测样本
i
属于类别
k
的概率。
α
是平衡因子,可以为不同类别设置不同的权重,但我们简化为全局常数。在本研究中,我们将BCE和FL结合,最终的损失函数为:

L_total = λ_bce * L_BCE + λ_fl * L_FL

通过调节
λ_bce

λ_fl
的比例,可以在保持整体优化稳定性的同时,重点提升对少数类别的识别能力。此外,我们还尝试了标签平滑,通过将 hard labels(0或1)软化为目标分布(如0.95和0.05),以抑制模型的过度自信,提高其校准能力。


4. 实验设置

4.1 数据集

本研究使用的数据来源于某三甲医院2019年1月至2022年12月的门诊电子病历。数据使用前经过了严格的脱敏处理。

数据收集与匿名化:我们从医院信息系统(HIS)中提取了患者的“主诉”、“现病史”和“初步诊断”字段。为了保护患者隐私,我们实施了严格的匿名化流程:

直接标识符移除:通过正则表达式和关键词列表,删除所有姓名、身份证号、手机号、家庭住址、社保号。间接标识符泛化:将精确的床号、门诊号替换为科室编码和日期范围,将具体的医院名称、科室名称替换为占位符。合规审查:整个数据脱敏过程在医院信息科和伦理委员会的监督下进行,确保符合《个人信息保护法》和最小必要原则。最终获得的数据集不包含任何可追溯到具体个人的信息。本研究获得了医院伦理委员会的审批(审批号:XXXXXX-YY)。

样本划分:为避免数据泄漏,我们按就诊ID进行分层划分。确保同一个患者在不同时间的就诊记录不会同时出现在训练集、验证集和测试集中。最终,我们获得了:

训练集:85,000条记录验证集:10,000条记录测试集(标准):10,000条记录测试集(噪声):从标准测试集中抽取2000条,人工注入错别字、缩写和口语化表达,用于鲁棒性测试。

标签处理:原始的“初步诊断”是基于ICD-10编码的,极为稀疏。我们与临床专家合作,依据临床路径和常见病种,将ICD-10编码聚合为50个高代表性的诊断类别(如“上呼吸道感染”、“2型糖尿病”、“原发性高血压”等)。部分病例具有多个诊断,因此这是一个多标签分类任务。数据集中,标签分布呈现典型的长尾特征,头部10个类别的样本占比超过60%。

4.2 对比方法

为验证本研究方法的有效性,我们选取了以下具有代表性的模型作为对比:

TF-IDF + LR:经典的机器学习基线。使用Scikit-learn实现。TextCNN:一个经典的CNN文本分类模型,用于对比深度学习基线。BiLSTM + Attention:序列模型基线,注意力机制用于加权重要信息。BERT-base:使用
bert-base-chinese
进行微调。RoBERTa-wwm-ext:我们的主要基线,不使用任何增强或对抗训练。RoBERTa + DA:仅使用我们提出的混合数据增强策略进行微调。RoBERTa + AT:仅在RoBERTa-wwm-ext上使用FreeLB对抗训练。RoBERTa + DA (Simpler):仅使用规则增强(同义词替换+EDA)的简化版DA。Ours (RoBERTa + DA + AT):本文提出的完整模型,融合混合数据增强与FreeLB对抗训练。

4.3 实现细节与超参数

所有深度学习模型均基于PyTorch 1.12和Transformers 4.21库实现。实验在一台配备NVIDIA A100 (40GB) GPU的服务器上进行。

通用超参数

最大序列长度:128(门诊文本较短,足够覆盖大部分信息)优化器:AdamW学习率:通过网格搜索确定,最优值为
2e-5
批量大小:32训练轮次:10,采用早停策略,若验证集Macro-F1连续5个epoch无提升则停止权重衰减:0.01

数据增强超参数

EDA概率
p_del=0.1
,
p_swap=0.1
,
p_ins=0.1
回译相似度阈值:BLEU > 0.75MixUp Beta分布参数
α=0.4

对抗训练超参数

FreeLB扰动步数 K:5扰动幅值 ε
1e-5
对抗学习率(更新扰动步长)
1e-3
(作用于δ)损失权重 α:0.5(即
L = 0.5*L_clean + 0.5*L_adv

损失函数超参数

Focal Loss γ:2损失权重
λ_bce=1.0
,
λ_fl=1.0
(即直接相加)标签平滑率:0.1(在部分消融实验中测试)

4.4 评估指标

考虑到多标签和类别不平衡的特点,我们采用以下指标进行综合评估:

Macro-F1:计算每个类别的F1分数后取平均。该指标平等对待每个类别,能敏感地反映模型在少数类上的性能,是首要评估指标Micro-F1:将所有类别的TP, FP, FN汇总后计算一个全局的F1分数。该指标偏向于样本数多的类别。PR-AUC (Macro):计算每个类别的Precision-Recall曲线下面积后取平均。对于不平衡数据,PR-AUC比ROC-AUC更能反映分类器的真实性能。Expected Calibration Error (ECE):衡量模型预测概率的校准程度。将预测概率分成M个区间,计算每个区间内预测概率的平均值与真实准确率的加权平均差。ECE越低,模型越“诚实”。鲁棒性下降率:在噪声测试集上,
(Score_std - Score_noise) / Score_std
。该值越小,模型越鲁棒。

所有实验结果均报告5次独立运行的平均值和标准差,以确保统计显著性。我们使用Bootstrap方法(重采样1000次)来检验模型间性能差异的显著性(p<0.05)。


5. 结果与分析

5.1 主要实验结果

表1展示了所有对比方法在标准测试集上的性能。

表1:各模型在标准测试集上的性能对比

模型 Macro-F1 (%) Micro-F1 (%) PR-AUC (Macro)
TF-IDF + LR 54.3 ± 0.5 61.2 ± 0.3 0.512 ± 0.008
TextCNN 59.8 ± 0.4 67.5 ± 0.4 0.589 ± 0.006
BiLSTM + Attn 62.1 ± 0.6 69.8 ± 0.5 0.611 ± 0.007
BERT-base 66.2 ± 0.3 74.1 ± 0.2 0.684 ± 0.004
RoBERTa-wwm-ext (基线) 68.5 ± 0.2 76.3 ± 0.2 0.701 ± 0.003
RoBERTa + DA (简化) 70.1 ± 0.3 77.4 ± 0.2 0.718 ± 0.003
RoBERTa + DA 71.2 ± 0.2 78.5 ± 0.1 0.725 ± 0.002
RoBERTa + AT 71.5 ± 0.3 78.8 ± 0.2 0.728 ± 0.003
Ours (DA + AT) 72.7 ± 0.2 79.9 ± 0.1 0.741 ± 0.002

结果分析

预训练模型的优越性:从传统模型到深度学习模型,再到预训练模型,性能依次提升。RoBERTa-wwm-ext基线显著优于所有传统和非预训练深度学习模型,证明了其在中文医疗文本理解上的强大能力。数据增强的有效性:相较于RoBERTa基线,仅使用简化版DA(规则)就将Macro-F1提升了1.6个百分点。而使用我们完整的混合DA策略,提升幅度达到2.7个百分点(68.5% → 71.2%)。这表明,结合规则和模型生成的多样化增强样本,能有效缓解过拟合和数据稀缺问题,尤其对长尾类别帮助明显。对抗训练的有效性:仅使用FreeLB对抗训练,也带来了3.0个百分点的Macro-F1提升(68.5% → 71.5%),其效果与完整的DA策略相当。这说明通过对抗训练提升模型决策边界的平滑性,是增强泛化能力和鲁棒性的另一条有效途径。协同效应:本文提出的完整模型,融合了混合DA与FreeLB AT,取得了最佳性能。Macro-F1达到72.7%,相比基线提升了4.2个百分点,且所有指标的提升均具有统计显著性(p<0.01)。这证明了数据增强与对抗训练之间存在协同作用:DA提供了更丰富多样的“虚拟”训练数据,而AT则教会模型在这些数据和真实数据之间进行平滑泛化,共同抵御了过拟合和分布漂移。

5.2 鲁棒性评估

表2展示了基线模型RoBERTa-wwm-ext和我们的最佳模型在含有自然噪声的测试集上的表现。

表2:鲁棒性测试结果

模型 标准测试集 Macro-F1 (%) 噪声测试集 Macro-F1 (%) 性能下降率 (%)
RoBERTa-wwm-ext (基线) 68.5 61.8 9.78
Ours (DA + AT) 72.7 70.2 3.44

结果分析

基线模型在遇到错别字、缩写等自然噪声时,性能急剧下降了近10个百分点,验证了其鲁棒性不足的弱点。我们的模型在相同噪声环境下,性能下降仅为3.44个百分点,远低于基线。这得益于两个方面的努力:首先,数据增强策略中的“口语化与缩写注入”和EDA,使得模型在训练阶段就已经“见过”这类噪声;其次,对抗训练迫使模型对嵌入空间的微小扰动不敏感,而拼写错误和同义词替换在嵌入空间上往往表现为较小的向量变化。因此,二者结合极大地提升了模型在真实、嘈杂的临床环境下的稳定性。

5.3 消融实验

为了深入理解框架中各个组件的贡献,我们进行了一系列消融实验,结果如表3所示。

表3:消融实验结果(基于RoBERTa + DA + AT框架)

配置 Macro-F1 (%)
Full Model (DA+AT+FL) 72.7
w/o DA (仅AT+FL) 71.5
w/o AT (仅DA+FL) 71.2
w/o MixUp 71.9
w/o Back-Translation 71.7
w/o Focal Loss (用BCE) 71.0
w/o AT & DA (仅RoBERTa+FL) 68.9

结果分析

DA与AT的贡献:移除DA或移除AT都会导致性能明显下降,再次确认了二者是提升性能的核心支柱。不同DA模块的贡献:移除MixUp或回译,性能均有所下降(约0.8-1.0个百分点),说明它们各自贡献了独特的样本多样性。MixUp通过平滑标签空间,而回译通过引入跨语言的句式变化。损失函数的贡献:将Focal Loss换回标准的BCE,Macro-F1下降了1.7个百分点。这表明Focal Loss对于改善模型在长尾类别上的表现至关重要,验证了其对类别不平衡问题的有效性。综合基线:即使仅使用RoBERTa+FL,性能也高于原始RoBERTa基线,说明Focal Loss本身就有帮助。但远不如结合DA和AT的方案。

5.4 校准分析

我们计算了RoBERTa基线模型和我们的最佳模型的ECE(分10个区间)。

RoBERTa-wwm-ext (基线):ECE = 0.082Ours (DA + AT):ECE = 0.056

结果表明,我们的模型不仅预测得更准,其预测的概率也更加可靠。更低的ECE意味着,当模型预测一个样本有80%的概率属于某类别时,这个样本真实的可能性也确实更接近80%。这对于需要基于模型概率进行决策的临床应用(如风险分层)至关重要。标签平滑(在部分实验中)也进一步降低了ECE,但可能以牺牲少量准确率为代价。

5.5 可解释性分析

我们选取了部分预测错误的案例,通过可视化最后一层Transformer的Attention权重来探究模型决策依据。

假阳性案例:文本“患者主诉头痛、头晕两天”,模型错误预测了“高血压”。通过Attention发现,模型过度关注了“头晕”这个与高血压高度相关的症状,而忽略了文本中缺乏其他关键信息(如血压值)。假阴性案例:文本“T38.5,咽痛,咳嗽”,模型未能识别出“急性上呼吸道感染”。Attention显示模型关注了“咽痛”,但可能被“T38.5”这种缩写形式所迷惑。这提示我们,在数据增强中应加强对此类医学缩写的噪声模拟,或者在模型输入前增加一个标准化的预处理模块。


6. 讨论

6.1 研究发现

本研究系统地验证了数据增强与对抗学习相结合的框架在门诊EMR文本分类任务上的有效性。核心发现可总结为:

混合增强优于单一方法:结合低成本、可控的规则增强与高质量、多样的模型生成增强,能以较低的计算开销,显著扩充数据的多样性和覆盖率,是提升模型性能的基础。对抗训练是鲁棒性的关键:FreeLB等方法通过优化最坏情况下的表现,有效地平滑了模型的决策边界,使其对输入的微小变异(无论是真实噪声还是对抗攻击)具有更强的抵抗力。协同效应显著:DA和AT的结合并非简单的性能叠加,而是形成了互补。DA“传授”知识(告诉模型可能存在哪些变体),AT“锻炼能力”(让模型学会如何应对这些变体),二者共同作用,实现了1+1>2的效果。任务特定优化的重要性:针对长尾分布问题,Focal Loss等特定设计的损失函数是必不可少的。这提醒我们,在应用通用技术时,必须结合任务的具体特性进行适配。

6.2 局限性

尽管取得了积极成果,本研究仍存在一些局限性:

单中心数据:数据集仅来自一家医院,虽然内部划分考虑了时间漂移,但其泛化到不同地域、不同等级医院的能力仍需进一步验证。标签体系的局限性:我们将复杂的ICD编码简化为50个大类,虽然有利于模型训练,但损失了诊断的精细度。在未来,需要研究如何在更细粒度的标签体系(数百甚至数千个类别)上进行有效分类。计算成本:混合数据增强,特别是回译和MixUp,以及多步对抗训练,都引入了额外的计算开销,训练时间约为基线的2-3倍。如何在效果和效率间取得更好的平衡是一个待解决的问题。对医学知识的利用不足:当前模型主要从数据中学习统计规律,未能显式地融入大规模医学知识图谱(如UMLS、MeSH)。未来可探索如何将知识图谱与PLM结合,提升模型的可解释性和零样本/少样本学习能力。

6.3 未来工作

基于上述局限性,未来的工作可以从以下几个方向展开:

多中心联邦学习:联合多家医院,在数据不出院的前提下,使用联邦学习框架训练一个全局共享的模型,从而实现跨域泛化。层次化分类:构建一个从粗到细的层次化标签体系,先预测大的疾病类别,再在类别内部进行细分,以解决细粒度标签的稀疏性问题。轻量化与效率优化:探索知识蒸馏、模型剪枝等技术,将训练好的大模型压缩成轻量级版本,以便部署到资源受限的边缘设备(如医生的移动工作站)。融合多模态信息:门诊EMR通常还包含检验结果、影像报告等结构化或半结构化数据。研究如何有效地将文本信息与这些多模态信息融合,构建更全面的诊断支持系统。


7. 伦理与合规

医疗人工智能的研究与应用必须将伦理与合规置于首位。本研究在数据处理的每个环节都严格遵守了相关法律法规和伦理准则。

数据使用合规性:本研究获得了医院伦理委员会的正式批准(批件号:XXXXXX-YY)。所有数据的使用均在与医院签订的数据使用协议框架下进行,协议明确了数据仅用于本次学术研究,不得用于任何商业目的或与本研究无关的其他活动。隐私保护措施:我们遵循“最小必要原则”,仅收集了与研究直接相关的字段(主诉、现病史、初步诊断)。通过多层次的匿名化技术(详见4.1节),去除了所有直接和间接的个人身份信息(PHI/PII)。我们对匿名化后的数据进行了逆向工程评估,确认无法通过合理手段重新识别个人身份。算法公平性与偏见:我们意识到,训练数据中固有的偏见可能会被模型学习并放大。例如,某些科室的医生可能更倾向于使用某些诊断标签。我们通过使用Focal Loss和分层抽样来部分缓解标签不平衡问题,但算法偏见仍是一个需要持续监控和审计的领域。在未来的模型部署中,我们需要建立定期评估机制,检查模型对不同人群(如性别、年龄)是否存在不公平的预测。透明度与可解释性:我们选择部分预测案例进行Attention可视化,作为一种初步的可解释性尝试。我们认识到,黑箱模型在医疗领域的应用需要更高的透明度要求。未来将探索更先进的可解释性方法,并建立模型决策的解释报告生成模块,以辅助医生理解和信任模型的输出。


8. 结论

本研究针对门诊电子病历短文本分类任务面临的数据稀缺、噪声多、分布漂移和鲁棒性差等挑战,提出并实现了一个融合混合数据增强与对抗学习的深度学习框架。该框架以强大的中文预训练模型RoBERTa-wwm-ext为基础,通过设计多层次的规则与模型数据增强策略扩充了训练数据的有效性,并通过引入FreeLB对抗训练显著提升了模型的决策边界平滑性和鲁棒性。

在真实、脱敏的门诊EMR数据集上的大量实验表明,该方法在多个评估指标上均显著优于现有基线模型。特别是在Macro-F1这一关键指标上,相比RoBERTa基线提升了4.2个百分点。更重要的是,在模拟真实噪声环境的鲁棒性测试中,模型性能下降幅度远低于基线,证明了其在复杂临床场景下的优越稳定性和泛化能力。消融实验清晰地验证了框架中各个核心组件的独立贡献及其协同效应。

本研究的成果为构建高性能、高可靠的医疗文本智能分析系统提供了一套行之有效的技术方案。未来的工作将聚焦于多中心联合学习、细粒度标签分类、模型轻量化以及融合医学知识图谱等方面,以期将该技术推向更广阔的临床应用前景,最终赋能智慧医疗,造福医患。


附录:实现代码

以下代码片段使用PyTorch和Transformers库,展示了FreeLB对抗训练的核心逻辑。


# skeleton.py
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW

# --- 模型与数据准备 ---
model_name = "hfl/chinese-roberta-wwm-ext"
num_labels = 50  # 类别数
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model.cuda()

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# --- FreeLB 训练步骤 ---
def train_step_freelb(batch, eps=1e-5, adv_steps=5, adv_lr=1e-3):
    """
    执行一次包含FreeLB对抗训练的训练步骤。
    """
    # 1. 准备输入
    text_list = batch["text"]
    labels = torch.stack(batch["labels"]).cuda()  # 假设标签是list of tensors

    inputs = tokenizer(text_list, padding=True, truncation=True, max_length=128, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    attention_mask = inputs["attention_mask"].cuda()

    # 2. 获取初始词嵌入并创建扰动
    embeds_init = model.get_input_embeddings()(input_ids).detach()
    # 初始化扰动 delta,需要计算梯度
    delta = torch.zeros_like(embeds_init, requires_grad=True)
    
    # 3. 对抗扰动生成 (K步)
    for _ in range(adv_steps):
        # a) 前向传播 (使用被扰动的嵌入)
        outputs_adv = model(inputs_embeds=embeds_init + delta,
                            attention_mask=attention_mask,
                            labels=labels)
        loss_adv = outputs_adv.loss

        # b) 计算梯度并更新扰动 delta
        # 注意:此处的梯度只用于更新delta,不更新模型参数
        grad = torch.autograd.grad(loss_adv, delta)[0]
        delta.data = (delta + adv_lr * grad.sign()).clamp(-eps, eps)
        delta.data.detach_()
        delta.requires_grad_()

    # 4. 最终模型参数更新
    # a) 使用最终的对抗嵌入计算最终损失
    logits_final = model(inputs_embeds=embeds_init + delta,
                         attention_mask=attention_mask).logits
    
    # b) 结合原始损失和对抗损失 (可选,本例中仅用最终损失)
    # outputs_clean = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    # loss_clean = outputs_clean.loss
    # loss = 0.5 * loss_clean + 0.5 * F.binary_cross_entropy_with_logits(logits_final, labels.float())
    
    loss = F.binary_cross_entropy_with_logits(logits_final, labels.float())

    # c) 反向传播并优化模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

# --- 训练循环 (简化版) ---
# for epoch in range(epochs):
#     model.train()
#     for batch in train_dataloader:
#         loss = train_step_freelb(batch)
#         print(f"Loss: {loss:.4f}")
#     # ... validation and early stopping ...

(多标签版本,BCEWithLogits)
亮点:随机初始化 delta、对 pad 屏蔽、L2 归一化更新、内环步步累积梯度(更贴近 FreeLB)、AMP、梯度裁剪。
如果你是“单标签”,请看代码里“切换为单标签”的注释。


# freelb_train.py
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "hfl/chinese-roberta-wwm-ext"
num_labels = 50  # 多标签:多热向量;若单标签,请见下方注释
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# 可选:训练总步数与 warmup
# total_steps = epochs * len(train_dataloader)
# scheduler = get_linear_schedule_with_warmup(optimizer, int(0.1 * total_steps), total_steps)

scaler = torch.cuda.amp.GradScaler()  # AMP

def _l2_norm(t, mask=None, eps=1e-12):
    if mask is not None:
        t = t * mask
    # 范数按样本聚合: [B, L, H] -> [B, 1, 1]
    norm = torch.norm(t.view(t.size(0), -1), dim=1, keepdim=True).view(-1, 1, 1)
    return torch.clamp(norm, min=eps)

def train_step_freelb(
    batch,
    eps=1e-5,          # 扰动半径(可适度加大,如 1e-3 ~ 1e-2,需配合归一化)
    adv_steps=5,       # 内环步数
    adv_lr=1e-2,       # 内环学习率(配合归一化,通常比你现在的 1e-3 稍大)
    max_grad_norm=1.0, # 梯度裁剪
    multilabel=True    # 若为单标签分类,置为 False 并使用 CrossEntropy
):
    model.train()

    # 1) 组装输入
    text_list = batch["text"]
    # 多标签:labels 为 float 的多热向量 [B, C]
    # 单标签:labels 为 long 的类别索引 [B]
    labels = batch["labels"]
    if isinstance(labels, list):
        labels = torch.stack(labels, dim=0)

    labels = labels.to(device)
    if multilabel:
        labels = labels.float()
    else:
        labels = labels.long()

    enc = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt"
    ).to(device)

    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]  # [B, L]

    # 2) 初始嵌入 & 扰动(随机初始化到 [-eps, eps])
    with torch.no_grad():
        embeds_init = model.get_input_embeddings()(input_ids)  # [B, L, H]

    delta = torch.empty_like(embeds_init).uniform_(-eps, eps).to(device)
    delta.requires_grad_(True)

    # 构造 pad 掩码,形状 [B, L, 1]
    pad_mask = attention_mask.unsqueeze(-1).type_as(embeds_init)

    optimizer.zero_grad(set_to_none=True)

    # 3) 对抗内环(步步累积对模型参数的梯度,更贴近 FreeLB)
    for _ in range(adv_steps):
        with torch.cuda.amp.autocast():
            # 只对非 pad 位置施加扰动
            perturbed = embeds_init + delta * pad_mask
            out = model(inputs_embeds=perturbed, attention_mask=attention_mask)
            logits = out.logits

            if multilabel:
                loss_step = F.binary_cross_entropy_with_logits(logits, labels) / adv_steps
            else:
                loss_step = F.cross_entropy(logits, labels) / adv_steps  # 单标签切换到 CE

        scaler.scale(loss_step).backward(retain_graph=True)

        # 用 delta 的梯度来更新 delta(PGD 一步),再投影回 L2 球
        with torch.no_grad():
            # 取 delta 的梯度;AMP 下需要先 unscale
            grad_delta = delta.grad
            if grad_delta is None:
                # 理论上不会发生;保险起见
                grad_delta = torch.zeros_like(delta)

            # L2 归一化更新
            grad_norm = _l2_norm(grad_delta, mask=pad_mask)
            delta.add_(adv_lr * grad_delta / grad_norm)

            # 投影回 L2 球半径 eps
            delta_norm = _l2_norm(delta, mask=pad_mask)
            exceed_mask = (delta_norm > eps).float()
            # 当超界时,缩放到边界;否则保持原值
            delta.mul_((1 - exceed_mask) + exceed_mask * (eps / (delta_norm + 1e-12)))

            # 清理梯度,下一步再算
            delta.grad = None

    # 4) 最终一次前向 + 回传(可选:只做一个“收尾”loss)
    with torch.cuda.amp.autocast():
        perturbed = embeds_init + delta * pad_mask
        out = model(inputs_embeds=perturbed, attention_mask=attention_mask)
        logits = out.logits

        if multilabel:
            loss_final = F.binary_cross_entropy_with_logits(logits, labels)
        else:
            loss_final = F.cross_entropy(logits, labels)

    scaler.scale(loss_final).backward()

    # 梯度裁剪 & 更新
    clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    # if scheduler is not None:
    #     scheduler.step()

    # 返回可监控的指标
    return float(loss_final.detach().item())

附录B:超参数配置全表

参数
模型 hfl/chinese-roberta-wwm-ext
最大序列长度 128
批量大小 32
优化器 AdamW
学习率 2e-5
训练轮次 10 (早停)
权重衰减 0.01
数据增强
EDA (删除/交换/插入) 概率 0.1
同义词替换概率 0.15
MixUp α 0.4
对抗训练
方法 FreeLB
扰动步数 K 5
扰动幅值 ε 1e-5
对抗学习率 1e-3
损失函数
主要损失 BCE + Focal Loss
Focal Loss γ 2.0
BCE 与 FL 权重 1.0 : 1.0

基于数据增强与对抗学习的门诊电子病历(EMR)文本分类python编程

© 版权声明

相关文章

暂无评论

none
暂无评论...