DeepJ: Graph Convolutional Transformers with Differentiable Pooling for Patient Trajectory Modeling

作者:Deyi Li et.al.

论文链接:http://arxiv.org/abs/2506.15809

发布日期:2025-06-18

解读时间:2025-07-19 19:14:55

论文摘要

In recent years, graph learning has gained significant interest for modeling complex interactions among medical events in structured Electronic Health Record (EHR) data. However, existing graph-based approaches often work in a static manner, either restricting interactions within individual encounters or collapsing all historical encounters into a single snapshot. As a result, when it is necessary to identify meaningful groups of medical events spanning longitudinal encounters, existing methods are inadequate in modeling interactions cross encounters while accounting for temporal dependencies. To address this limitation, we introduce Deep Patient Journey (DeepJ), a novel graph convolutional transformer model with differentiable graph pooling to effectively capture intra-encounter and inter-encounter medical event interactions. DeepJ can identify groups of temporally and functionally related medical events, offering valuable insights into key event clusters pertinent to patient outcome prediction. DeepJ significantly outperformed five state-of-the-art baseline models while enhancing interpretability, demonstrating its potential for improved patient risk stratification.

AI解读

好的,我们来一起深入剖析这篇关于DeepJ的论文,我会尽力用清晰易懂的方式解释它的方法和技术细节。

1. 核心方法与创新点

这篇论文的核心方法是提出了一种叫做 Deep Patient Journey (DeepJ) 的新型模型,它结合了图卷积Transformer和可微图池化技术,用来建模患者电子病历(EHR)中医疗事件的复杂关系和时间演变。 简单来说,DeepJ就像一个“医疗事件关系挖掘机”,能够从病人的就医记录中发现隐藏的、有意义的医疗事件组合(比如疾病、药物、检查等),并利用这些信息来预测病人的健康结果(比如ICU死亡率、急性肾损伤等)。

主要创新点:

* 跨诊疗事件关系建模: 现有方法通常只关注单次就诊内部的医疗事件关系,或者把所有就诊记录简单地合并成一个快照。DeepJ 的一大亮点是能够同时捕捉单次就诊内部(intra-encounter)和多次就诊之间(inter-encounter)的医疗事件关系。 想象一下,病人一年前的糖尿病诊断可能会导致最近一次住院时开具胰岛素处方。DeepJ 能够识别并建模这种跨越时间的依赖关系。
* 图结构学习 (Graph Structure Learning, GSL): DeepJ 扩展了 Graph Convolutional Transformer (GCT) 架构,通过自注意力机制学习医疗事件之间的连接。 简单来说,GSL就像一个“关系发现引擎”,能够自动识别哪些医疗事件之间存在关联,并用图结构来表示这些关系。
* 临床模块发现 (Clinical Module Discovery, CMD): DeepJ 引入了一个基于可微图池化技术的层级聚类模块,用于识别具有临床意义的医疗事件子群(临床模块)。 临床模块就像一个个“医疗事件组合”,例如,一组与心脏衰竭相关的药物、检查和症状,或者一组与肺癌相关的治疗方案。
* 增强可解释性: DeepJ 能够为每个病人生成一个独特的临床轨迹图,展示关键的医疗事件和它们之间的关系。这有助于医生理解病人的病情发展过程,并做出更明智的决策。
* 性能提升: 在两个真实世界的数据集上,DeepJ 的预测性能显著优于其他五种最先进的基线模型。

与现有方法的区别:

| 特性 | DeepJ | 现有方法 |
| ---------------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| 时间依赖性 | 考虑跨诊疗事件的时间依赖性 | 通常仅限于单次就诊或静态快照 |
| 图结构 | 学习个性化图结构 | 依赖于预定义的知识库或简单的共现关系 |
| 关系建模 | 同时建模 intra-encounter 和 inter-encounter 关系 | 通常只关注 intra-encounter 关系 |
| 临床模块发现 | 自动发现临床模块 | 通常使用全局池化,无法区分不同的临床状况 |
| 可解释性 | 提供个性化临床轨迹图 | 通常提供 population-level 的解释,缺乏个体针对性 |

2. 算法细节与流程

DeepJ 的算法流程可以概括为以下几个步骤:

1. 输入嵌入和时间编码:

* 每个医疗代码(如诊断代码、药物代码)都会被转换成一个向量表示(嵌入)。这就像把每个医疗代码“翻译”成一个数值向量,让模型能够理解它们的含义。
* 为了捕捉就诊之间的时间关系,DeepJ 使用时间编码技术将就诊的时间信息融入到医疗代码的嵌入中。

2. 图结构学习 (GSL):

* GSL 使用扩展的图卷积Transformer (EGCT) 模块来学习医疗代码之间的关系。每个 EGCT 模块都包含一个自注意力层和一个前馈神经网络。
* 自注意力机制: 自注意力机制允许模型关注不同医疗代码之间的相互作用,并为每个代码分配一个权重,表示它与其他代码的相关程度。
* 掩码 (Masking): 为了保证时间上的因果关系,DeepJ 使用掩码技术来防止未来的医疗代码影响过去的医疗代码。
* GSL 的输出是一个图,其中节点表示医疗代码,边表示代码之间的关系强度(注意力权重)。

3. 临床模块发现 (CMD):

* CMD 使用可微图池化技术将图结构聚类成若干个临床模块。每个模块代表一组相互关联的医疗代码。
* 图卷积网络 (GCN): CMD 使用 GCN 来聚合邻居节点的信息,从而增强节点表示。
* 可微池化: 可微池化允许模型学习如何将节点分配到不同的临床模块,同时保持图结构的可微性,以便进行端到端训练。

4. 临床模块加权和预测:

* DeepJ 使用一个简单的注意力机制来为每个临床模块分配一个权重,表示它对最终预测的重要性。
* 最后,加权的临床模块表示被输入到一个前馈神经网络分类器中,用于预测病人的健康结果。

算法流程图:

```
Input EHR sequence -> Embedding & Time Encoding -> GSL (EGCT blocks, attention weights) -> CMD (DiffPool blocks, cluster assignments) -> Clinical Module Weighting -> FFNN Classifier -> Outcome Prediction
```

算法的技术优势和创新之处:

* 端到端学习: DeepJ 能够进行端到端训练,这意味着模型的所有参数都可以通过梯度下降进行优化,从而提高预测性能。
* 可解释性: DeepJ 能够生成可解释的临床轨迹图,帮助医生理解病人的病情发展过程。
* 灵活性: DeepJ 可以应用于不同的临床预测任务和数据集。

3. 详细解读论文第三部分

第三部分着重描述了模型的具体实现细节和数学公式。 我们将分解关键公式,并解释其背后的意义:

* 输入嵌入和时间编码

* $$ Z = E + TE(V) \in \mathbb{R}^{SeqLen \times d_{model}} $$

* 公式解读: `Z` 是输入到 Transformer 第一层的数据。 `E` 是医疗代码经过嵌入层的表示,将每个医疗代码转换成一个 `d_model` 维的向量。 `TE(V)` 是时间编码矩阵,将就诊的时间信息加入到 `E` 中。这样做的好处是模型可以区分不同时间发生的相同医疗代码。
* 技术意义: 通过将代码嵌入和时间信息结合,模型能够同时理解代码的语义信息和时间上下文。
* `SeqLen` 是序列长度,表示一个病人所有就诊中医疗代码的总数(考虑padding)。

* $$TE(t_p, d) = \begin{cases}
sin(\frac{t_p}{t_{max}^{d/d_{model}}}), & \text{if } d \text{ is even} \\
cos(\frac{t_p}{t_{max}^{(d-1)/d_{model}}}), & \text{if } d \text{ is odd}
\end{cases}$$

* 公式解读: 这是时间编码 (TE) 的具体公式。 `t_p` 是当前 encounter 距离第一个 encounter 的时间差。 `t_max` 是数据集中最大的时间差。 `d` 是特征的索引,从 1 到 `d_model`。
* 技术意义: 这个公式使用正弦和余弦函数将时间信息编码成一个向量。 不同的频率用于不同的维度,使得模型可以区分不同时间尺度上的时间关系。 这种时间编码方法最初在 Transformer 模型中使用,因为它能够很好地表示时间关系,并且可以处理任意长度的序列。

* 图结构学习 (GSL)

* $$ \alpha_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d_{model}}} + \mathcal{M}_{i,j} $$

* 公式解读: 这是自注意力机制中计算注意力得分的公式。 `q_i` 是代码 `i` 的查询向量,`k_j` 是代码 `j` 的键向量。 `d_{model}` 是模型维度。 `\mathcal{M}_{i,j}` 是掩码,用于防止未来的代码影响过去的代码,以及防止padding 代码参与计算。
* 技术意义: 点积注意力是计算两个向量相似度的一种高效方法。 除以 `sqrt(d_model)` 是为了防止点积过大,导致 softmax 函数梯度消失。 掩码是保证模型因果性的关键。

* $$ \mathcal{M}_{i,j} = \begin{cases}
-\infty, & \text{if code } i \text{ happened before code } j \text{ or either is a padding code} \\
0, & \text{otherwise}
\end{cases} $$

* 公式解读: 这是掩码的具体定义。 如果代码 `i` 发生在代码 `j` 之前,或者代码 `i` 或 `j` 是填充代码,则掩码值为负无穷大,否则为 0。
* 技术意义: 这个掩码确保了模型只能关注过去的代码,从而保证了模型的因果性。同时,填充代码不会影响模型的学习。

* $$ w_{i,j} = \frac{exp(\alpha_{i,j})}{\sum_{j'} exp(\alpha_{i,j'})} $$

* 公式解读: 这是计算注意力权重的公式。 通过 softmax 函数将注意力得分归一化为概率分布。
* 技术意义: softmax 函数将注意力得分转换为权重,使得权重之和为 1。这样,模型可以根据权重来聚合不同代码的信息。

* $$ \tilde{e}_i = \sum_{j} w_{i,j} v_j $$

* 公式解读: 这是计算更新后的代码嵌入的公式。 `v_j` 是代码 `j` 的值向量。 `w_{i,j}` 是代码 `i` 和 `j` 之间的注意力权重。
* 技术意义: 这个公式将所有代码的值向量加权求和,权重由注意力机制决定。 这样,更新后的代码嵌入包含了其他代码的信息,从而增强了代码的表示能力。

* $$L_{KL} = \sum_{l=1}^{N} KL(A^{(l-1)} || A^{(l)}), \text{where } A^{(0)} = CO$$

* 公式解读: 这是 KL 散度损失函数,用于鼓励每一层 EGCT 的注意力权重矩阵与上一层的注意力权重矩阵相似。 第一层的注意力权重矩阵被初始化为代码共现矩阵 CO。
* 技术意义: 这个损失函数可以帮助模型学习到更稳定的图结构。代码共现矩阵可以提供关于代码之间关系的先验知识,而 KL 散度损失函数可以防止模型过度偏离这些先验知识。

* 临床模块发现 (CMD)

* $$H^{(l)} = GCN_{\theta_1}^{(l)}(A^{(l)}, X^{(l)}) \in \mathbb{R}^{N^{(l)} \times d_{model}}$$

* 公式解读: 使用图卷积网络 (GCN) 更新节点表示。 `A^{(l)}` 是第 `l` 层的邻接矩阵。 `X^{(l)}` 是第 `l` 层的节点特征。 `N^{(l)}` 是第 `l` 层的节点数。
* 技术意义: GCN 可以聚合邻居节点的信息,从而增强节点表示。

* $$S^{(l)} = softmax(GCN_{\theta_2}^{(l)}(A^{(l)}, X^{(l)})) \in \mathbb{R}^{N^{(l)} \times N^{(l+1)}}$$

* 公式解读: 使用图卷积网络 (GCN) 将节点分配到不同的簇。 `N^{(l+1)}` 是第 `l+1` 层的簇数。
* 技术意义: softmax 函数将 GCN 的输出转换为簇分配概率。

* $$X^{(l+1)} = S^{(l)T} H^{(l)} \in \mathbb{R}^{N^{(l+1)} \times d_{model}}$$

* 公式解读: 计算池化后的节点特征。
* 技术意义: 将每个簇的节点特征加权求和,权重由簇分配概率决定。

* $$A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)} \in \mathbb{R}^{N^{(l+1)} \times N^{(l+1)}}$$

* 公式解读: 计算池化后的邻接矩阵。
* 技术意义: 根据簇之间的连接关系更新邻接矩阵。

* $$L_{link} = \sum_{l=1}^{M} ||A^{(l)} - S^{(l)} S^{(l)T}||_F^2$$

* 公式解读: 链接预测损失函数,用于鼓励簇分配矩阵 S 能够重构邻接矩阵 A。
* 技术意义: 这个损失函数可以帮助模型学习到更好的簇分配。

* $$L_{entropy} = \sum_{l=1}^{M} \frac{1}{N^{(l)}} \sum_{i=1}^{N^{(l)}} Ent(S_i^{(l)})$$

* 公式解读: 熵正则化损失函数,用于鼓励簇分配更加明确。 `Ent(x)` 是熵函数。
* 技术意义: 这个损失函数可以防止节点被分配到多个簇,从而提高簇的质量。

* 临床模块加权和预测

* $$\alpha_q = softmax(X_q^{(M)} w)$$

* 公式解读: 计算每个簇的注意力权重。 `X_q^{(M)}` 是第 `M` 层的簇特征。 `w` 是可学习的参数。
* 技术意义: 注意力机制可以帮助模型选择最重要的簇,从而提高预测性能。

* $$G_{final} = \sum_{q=1}^{N^{(M)}} \alpha_q X_q^{(M)}$$

* 公式解读: 计算最终的图表示。
* 技术意义: 将所有簇的特征加权求和,权重由注意力机制决定。

* 总损失函数

* $$L_{total} = L_{outcome} + \lambda_{KL} L_{KL} + \lambda_{link} L_{link} + \lambda_{entropy} L_{entropy}$$

* 公式解读: 总损失函数是预测损失、KL 散度损失、链接预测损失和熵正则化损失的加权和。 `λ` 是超参数,用于控制不同损失函数的权重。
* 技术意义: 通过调整超参数,可以平衡不同损失函数的影响,从而优化模型的性能。

4. 实现细节与注意事项

* 图结构学习 (GSL):
* 实现细节: 使用 PyTorch 或 TensorFlow 等深度学习框架实现 EGCT 模块。 使用多头注意力机制来提高模型的表示能力。
* 实现难点: 训练 Transformer 模型需要大量的计算资源。 可以使用梯度累积或混合精度训练等技术来减少内存占用。
* 优化建议: 可以使用预训练的词向量来初始化代码嵌入。 可以使用学习率衰减或早停等技术来防止过拟合。
* 参数设置: `d_model` 通常设置为 128 或 256。 注意力头数通常设置为 8 或 16。 EGCT 模块数通常设置为 3 或 6。
* 临床模块发现 (CMD):
* 实现细节: 使用 PyTorch Geometric 或 DGL 等图神经网络库实现 GCN 和可微池化。
* 实现难点: 可微池化的计算复杂度较高。 可以使用图采样或图简化等技术来减少计算量。
* 优化建议: 可以使用谱聚类或 Louvain 算法等图聚类算法来初始化簇分配。
* 参数设置: GCN 层数通常设置为 2 或 3。 簇数需要根据数据集的特点进行调整。
* 损失函数权重:
* 实现细节: 可以使用网格搜索或贝叶斯优化等技术来调整损失函数权重。
* 注意事项: 损失函数权重需要根据数据集的特点进行调整。 如果 KL 散度损失权重过大,可能会导致模型无法学习到新的图结构。 如果链接预测损失权重过大,可能会导致簇分配过于依赖邻接矩阵。 如果熵正则化损失权重过大,可能会导致簇分配过于明确,失去灵活性。

总而言之, DeepJ 是一个复杂的模型, 需要仔细调整参数和优化训练过程才能达到最佳性能。 这篇论文提供了一个很好的起点, 可以帮助研究人员和开发人员构建自己的 EHR 分析模型。
返回论文列表