Abstract
In computational pathology, extracting and representing spatial features from gigapixel whole slide images (WSIs) are fundamental tasks, but due to their large size, WSIs are typically segmented into smaller tiles. A critical aspect of analyzing WSIs is how information across tiles is aggregated to predict outcomes such as patient prognosis. We introduce a model that combines a message-passing graph neural network (GNN) with a state space model (Mamba) to capture both local and global spatial relationships among the tiles in WSIs. The model’s effectiveness was demonstrated in predicting progression-free survival among patients with early-stage lung adenocarcinomas (LUAD). We compared the model with other state-of-the-art methods for tile-level information aggregation in WSIs, including statistics-based, multiple instance learning (MIL)-based, GNN-based, and GNN-transformer-based aggregation. Our model achieved the highest c-index (0.70) and has the largest number of parameters among comparison models yet maintained a short inference time. Additional experiments showed the impact of different types of node features and different tile sampling strategies on model performance. Code: https://github.com/rina-ding/gat-mamba.
Similar content being viewed by others
Introduction
Lung cancer results in more than 1.8 million deaths worldwide each year1. Non-small cell lung cancer (NSCLC), one of two lung cancer types, comprises around 85% of all lung malignancies in the United States2. Early stage (stage I, II by AJCC 8th edition) NSCLC patients are commonly treated with curative resection, but around 30–55% develop disease recurrence within the first five years of surgery3. This high recurrence rate reflects a need to identify early-stage NSCLC patients who may be at high risk of recurrence and give them personalized adjuvant therapies after the surgery. The current clinical standard for NSCLC prognosis and treatment planning is the tumor-node-metastasis (TNM) staging system that assesses tumor size, local invasion, and nodal and distance metastases. However, heterogeneous progression-free survival times are commonly observed among patients with identical TNM staging, suggesting this method is insufficient for risk stratification4.
Recently, several studies have shown the benefits of using quantitative histomorphologic features derived from Hematoxylin and Eosin (H&E)-stained WSIs to predict recurrence or survival in early-stage NSCLC5. Due to limitations on graphical processing unit (GPU) memory and the relatively large size of WSIs, WSIs are usually split into non-overlapping equal-sized tiles, and a tile aggregation method is needed to render a slide-level prediction. Most studies aggregate tile-level information using summary statistics such as mean to generate a slide or patient-level prediction67. Some studies utilize multiple instance learning (MIL)-based methods to aggregate the tiles by taking into account the importance of each tile to the prediction8,9. Other studies model the relationship between the tiles in the WSIs using local message-passing GNN-based approaches where each tile is a node and graph convolution operations are applied to capture the connectivity between nodes10,11,12. While promising, message-passing GNN-based approaches are limited by the aggregation operation, which only considers the node’s neighbors and fails to capture the global long-range dependency between the nodes. Several studies have attempted to leverage the global receptive field of transformers and combine them with the message-passing GNNs to capture both local and global relationships between the tiles13,14. Although transformers can capture the long-range dependencies among the nodes in the graph, handling their quadratic computational complexity associated with the self-attention mechanism is not practical, especially in applications that require large graphs like the ones in computational pathology.
A recent state space model, Mamba, has been shown to not only maintain the ability to capture long-range dependencies among the tokens in a sequence but also be more computationally efficient than the standard transformers15. Mamba achieves state-of-the-art results on multiple benchmarks related to long-range molecular graphs that, on average, have hundreds or thousands of nodes in each graph16,17. However, its potential in modeling large graphs from computational pathology has not been as widely explored.
In this work, we introduce an integrated message-passing GNN and state space modeling-based progression-free survival prediction pipeline in early-stage LUAD, the most common subtype of NSCLC. The contributions of work are as follows:
-
1.
We leverage an integrated message-passing GNN and state space model in computational pathology. GAT was used as the GNN, and Mamba was used as the state space model to capture both local and global tissue connectivity in the WSI.
-
2.
We systematically explored the effect of different node features and tile sampling strategies on model performance.
-
3.
Using patients from two publicly available datasets, we performed extensive experiments by comparing our method with baselines and conducting ablation studies to demonstrate the effectiveness of our proposed method.
The proposed GAT-Mamba pipeline, which consists of WSI tiling (a), node and edge feature extraction (b), graph construction (c), initializing the graph with the extracted node and edge features, and modeling on the graphs. BN: batch normalization. MLP: multi-layer perceptron. +: element-wise summation. N = 16 for positional encodings.
Related work
Survival analysis in computational pathology
Most works that utilize artificial intelligence for cancer prognosis can be divided into two major categories: hand-crafted features-based approaches and deep learning-based approaches. Hand-crafted features are developed using the domain knowledge of pathologists or oncologists18. A common approach is to extract quantitative features that describe the shape, texture, and geometric arrangement of all types of nuclei detected on the WSIs5. Various works have also focused on quantifying the density and spatial arrangement of tumor-infiltrating lymphocytes (TILs)6,19. After extracting the features from each tile, the features are usually aggregated using summary statistics and are fed into a simple linear Cox proportional hazards model to predict the prognosis. While relatively more interpretable and less computationally expensive, these hand-crafted features are usually targeted for a specific cancer or tissue type, limiting their broader utility.
Compared to hand-crafted features requiring feature engineering, deep learning-based approaches allow one to learn representation from the raw image data directly. Due to the computational complexity in training deep learning models and the lack of fine-grained region-level annotations, many works utilized attention MIL-based approaches. In the context of survival analysis using WSIs, a bag is a WSI, and each tile is an instance. If a WSI is from a high-risk patient, then at least one tile in the WSI must contain malignant tissues; if a WSI is from a low-risk patient, most or all the tiles must be benign or less malignant. Several works have trained the network to compute the attention score of each tile and aggregate the tile features using weighted pooling8,20.
Graph-based approaches in computational pathology
An emerging trend in computational pathology is to model a WSI as a graph and use message-passing GNNs to capture the spatial connectivity between different tissue regions. Given that these gigapixel WSIs are being divided into smaller tiles, graph representations allow spatial information across tiles to be captured and leveraged during model inference. Message-passing GNNs iteratively aggregate information from the neighboring nodes and update the current node information. GNNs can have different aggregate and update functions, and they are built to learn representations that reflect the topological structure of the graph data. For example, Chen et al.12 developed a graph convolutional network (GCN)-based survival analysis pipeline that models each WSI as an 8-nearest-neighbor graph based on the spatial coordinates of the tiles. Ding et al.10 used a graph isomorphism network (GIN) to predict molecular profiles from WSIs in colon cancer by constructing a graph from the WSIs based on a fixed Euclidean distance threshold derived from the tiles’ spatial coordinates. Wang et al.11 constructed a hierarchical graph from both cell-level and tile-level graphs and used a graph attention network (GAT) to predict progression-free survival in prostate cancer. Nevertheless, GNNs are prone to oversmoothing, and learned node representations become very similar across nodes after neighborhood information aggregation21. Moreover, neighborhood information aggregation is limited to local neighboring nodes, so these models cannot capture the global long-range relationships between the tissue regions. Combining GNNs with transformers to alleviate the over-smoothing issue from GNNs and better capture global node connectivity has yet to be fully explored. Zheng et al.13 devised a pathology image classification pipeline by first passing the WSI-constructed graph into a message-passing GNN and then passing the learned graph representation into a vision transformer. Sun et al.14 proposed a hybrid GAT-Transformer model where the WSI-constructed graphs were separately passed into a GAT and a transformer. However, transformers have quadratic computational complexity that is prone to overfitting, particularly in large graphs such as ones generated in digital pathology. Our model addresses these challenges by replacing the transformer with Mamba, which is a global feature learner with linear computational complexity.
In addition to model architecture, existing works have not systematically examined different graph construction methods and their influence on the model performance. Most of the graph-based works in computational pathology use either hand-crafted features11 or convolutional neural network (CNN)-extracted features based on either ImageNet transfer learning10,12,14 or self-supervised pretraining13 on relatively small-scale data. In this work, we examine the value of incorporating features extracted using recently published foundation models, comparing the model’s performance when using hand-crafted, CNN-derived, or foundation model-derived node features. Such exploration is important since the performance of a graph model depends on the quality of the input features. Besides the influence of node features on model performance, most existing works have not systematically examined the relationship between tile/node sampling strategy and model performance. Most works build graphs from the WSIs using all available tiles/nodes12,13,14, and some build graphs by filtering out tiles/nodes from benign or less malignant tissue regions10,11. We address this gap in the literature by conducting multiple experiments to examine the effect of node features and tile sampling strategies on model performance.
Long sequence modeling
Transformers have grown in popularity in both natural language processing and computer vision applications due to their self-attention mechanism that helps model long-range dependency in complex data. However, the transformer’s self-attention scales quadratically with respect to the number of tokens in the sequence. Attempts to devise various approximations of the attention mechanism using sparse attention or low-dimensional matrices have been made22,23. However, empirical observations have indicated that these approximations are not ideal for large-scale sequences because these approximations were developed at the expense of the very properties that make transformers effective15. Recently, a State Space Model (SSM) named Mamba15 was proposed to address the computational challenge associated with self-attention in transformers while maintaining superior performance. SSMs can generally be interpreted as a combination of recurrent neural networks (RNNs) and CNNs. To address the inability of SSMs to filter out irrelevant information when updating the sequence embedding, Mamba provides a mechanism for selective inputs (see Model Architecture for details). Recently, Mamba-based methods have been applied to digital pathology24,25. However, these works are based on the MIL framework, which does not explicitly capture any spatial relationship among the tissue regions like GAT-Mamba does.
Methods
Dataset
Two publicly available datasets, the National Lung Screening Trial (NLST, 132 patients) and the Cancer Genome Atlas (TCGA, 312 patients), were used. For both datasets, all selected patients were diagnosed with lung cancer and were surgically resected. The inclusion criteria include the patients being stage I or II LUAD, having at least one H&E-stained WSI, having progression status (either recurrence or lung cancer death), time to event for those who had progression, follow-up time for those who did not have progression, and for those who had progression, it occurred within five years of surgery. Cases with substantial artifacts, such as pen marks on the WSIs, were excluded. Table 1 summarizes each dataset.
Since both NLST and TCGA are publicly available data, additional approval was not required.
Graph construction
Figure 1 shows the GAT-Mamba workflow. Each WSI was modeled as a graph. The first step (Fig. 1a) is tiling the original WSI into non-overlapping tiles of either size 512 by 512 at 10x (1 micron per pixel (mpp)) or 1024 by 1024 at 20x (0.5 mpp) so that the total area covered by a tile is consistent across all patients which have different magnification levels available. Tiles with less than 20% tissue area were excluded.
The next steps involve initializing the graph with pre-defined node and edge features (Fig. 1b). There are two groups of node features. The first group is 1024 deep features extracted from a general-purpose self-supervised foundation model for pathology called UNI26. UNI was trained using more than 100,000 H&E-stained slides across 20 tissue types (e.g. heart, lung, kidney) from Massachusetts General Hospital and Brigham and Women’s Hospital that do not overlap with cases used in this work. As a result, UNI features are likely robust and pathology domain-specific, which can benefit our downstream analysis using lung pathology data. The second group is an N-dimensional sinusoidal positional encoding feature27 derived from the relative spatial coordinates of each tile/node within the WSI and N = 16 in this work. They are used as additional node features since they can explicitly encode the spatial location of each tile/node within the WSI. In LUAD, patients have one of five predominant histologic subtypes (lepidic, acinar, papillary, micropapillary, solid), each having different prognostic effects on patients. Lepidic is often associated with a better prognosis, acinar and papillary are associated with an intermediate prognosis, and micropapillary and solid are associated with a poorer prognosis28. To capture the heterogeneity of LUAD histologic subtypes and leverage that information in progression-free survival prediction, the subtype-subtype connection between two tiles was used as one of the edge features. For example, if an acinar tile is connected to a solid tile, then their subtype-subtype connection is “acinar-solid”. There are 21 combinations of such subtype-subtype connections among the five subtypes plus non-tumor. In other words, the subtype-subtype edge feature is a categorical feature that takes on 21 different values. The subtype of each tile was obtained using a pretrained ResNet18-based LUAD histologic subtype classifier29. Cosine similarity between the UNI-extracted deep features from each tile was used to quantify the subtype-subtype connection between the two nodes. Euclidean distance between two nodes was used to explicitly capture the spatial relationship between the nodes.
Similar to Chen et al.12 and Wu et al.30,31, each WSI is a graph with tiles being the nodes and the edge connectivity defined by the k-nearest neighbors approach where k = 8 and Euclidean distance between tiles was used as the distance metric (Fig. 1c). The assumption is that the immediate neighboring tiles provide context for each other and potentially share information.
Model architecture
The GAT branch
In this work, graph convolution operation was performed using GAT32, which uses an attention mechanism to learn the importance of each neighboring node to the current node. This mechanism focuses the network on the most relevant nodes to make predictions. Edge features were concatenated with the neighboring (j) and current (i) node features when computing the attention coefficients \(\alpha _{i, j}\):
where T represents transposition, \(\mathbin \Vert\) is concatenation, \({\varvec{a}}\) is the learnable parameter in the single-layer feedforward neural network that learns the attention coefficients, \({\varvec{W}}\) is a linear transformation shared across all node and edge features, \({\varvec{X}}\) is node features, \({\varvec{E}}\) is edge features, \({\mathcal {N}}_{i}\) is some neighboring node of current node i, and k represents all neighboring nodes j of the current node i. Once \(\alpha _{i, j}\) is obtained for all neighboring nodes, the current node’s features are updated by the weighted sum of its neighboring node features.
The Mamba branch
Similar to RNNs, SSMs map the input sequence \(x(t) \in {\mathbb {R}}^N\) to output sequence \(y(t) \in {\mathbb {R}}^N\) via a hidden state \(h(t) \in {\mathbb {R}}^N\) using a linear ordinary differential equation for continuous input:
where \({\varvec{A}}\) is a state matrix that compresses all the past information in the sequence, \({\varvec{B}}\) is the input matrix, and \({\varvec{C}}\) is the output matrix. Together, \({\varvec{A}}h(t)\) represents how the current state evolves over time, \({\varvec{B}}x(t)\) represents how the input affects the state, and \({\varvec{C}}h(t)\) represents how the current state translates to output. \({\varvec{A}}\) and \({\varvec{B}}\) are discretized using a step size \(\Delta\), such that \({\varvec{A}} = exp(\Delta {\varvec{A}})\), \({\varvec{B}} = (\Delta {\varvec{A}})^{-1}(exp(\Delta {\varvec{A}}) - I)\Delta B\). Both SSMs and its improved Structured State Space Sequence (S4) model33 can represent long sequences, but they are limited by their static representation that is not context-aware. That is, matrices \({\varvec{A}}\), \({\varvec{B}}\), and \({\varvec{C}}\) are always constant regardless of the input tokens x in a sequence. To address this issue, Mamba15 introduced a selection mechanism that allows the model to selectively retain information. Specifically, that is achieved by parameterizing \({\varvec{B}}\), \({\varvec{C}}\), and \(\Delta\) over the input x, where \({\varvec{B}}\) enables the model to control the influence of input \(x_t\) on the hidden state \(h_t\) and \({\varvec{C}}\) enables the model to control the influence of \(h_t\) on the output \(y_t\) based on the context. \(\Delta\) controls how much to focus on or ignore \(x_t\), and a larger \(\Delta\) means more focus is given to \(x_t\) as compared to the previous hidden states. In this work, each node in the graph is a token, and Mamba’s selection mechanism allows the model to minimize the influence of unimportant nodes at each step of hidden state computation. Similar to graph-based transformers, when modeling graph data using Mamba, positional information of the nodes needs to be explicitly encoded into the model since the graph connectivity information is lost when turning the graph into a sequence. In this work, sinusoidal positional encoding27 serves as the positional information for Mamba.
The GAT-Mamba pipeline
Suppose each patient has up to G WSIs, and each WSI is represented by a graph G. Each \(G_i\) has N nodes and E edges, and adjacency matrix \({\varvec{A}}\) \(\in\) \({\mathbb {R}}^{N \times N}\). There are two groups of node features, UNI node features \(\varvec{X_{UNI}}\) \(\in\) \({\mathbb {R}}^{N \times D_{UNI}}\), positional encoding node features \(\varvec{X_{PE}}\) \(\in\) \({\mathbb {R}}^{N \times D_{PE}}\). First, a linear layer \(L_{node}\) was applied to transform \(\varvec{X_{UNI}}\) (Algorithm 1 line 7). Then \(\varvec{{\hat{X}}_{UNI}}\) was concatenated with \(\varvec{X_{PE}}\) to form the final node features \({\varvec{X}}\) (Algorithm 1 line 8). There are two types of edge features: categorical edge features (subtype-subtype connection) \(\varvec{E_{cat}}\) \(\in\) \({\mathbb {R}}^{E \times D_{cat}}\) and continuous edge features \(\varvec{E_{cont}}\) \(\in\) \({\mathbb {R}}^{E \times D_{cont}}\). First, an edge embedding function (\(EB_{cat}\)) was used to transform the categorical features \(\varvec{E_{cat}}\) into continuous features \(\varvec{{\hat{E}}_{cat}}\) and a linear layer (\(L_{edge}\)) was used to transform \(\varvec{E_{cont}}\) into continuous features that share the same dimension as \(\varvec{{\hat{E}}_{cat}}\) (Algorithm 1 line 9). The final edge features \({\varvec{E}}\) were formed by summing these two groups of transformed edge features element-wise (Algorithm 1 line 10).
The final node features \({\varvec{X}}\) were passed into the Mamba branch, and both \({\varvec{X}}\) and \({\varvec{E}}\) were passed into the GAT branch. The output of the GAT branch was summed with the Mamba branch element-wise (Algorithm 2 line 8), and the resulting representation was passed to a global mean pooling layer to generate a patient-level embedding (Algorithm 1 line 14) before using an MLP layer to predict the risk score (Algorithm 1 line 16). The negative logarithm of Cox partial likelihood loss from DeepSurv34 was used as the training objective. Cox loss is a ranking loss that penalizes predicted risk scores that are not in concordance with the patients’ time to event. Algorithm 1 shows each step in the pipeline.
Implementation details
Five-fold cross-validation, stratified by progression-free survival status, was used. The train-validation-test split for each fold was done at the patient level by combining NLST and TCGA patients during model training (60%, 266 patients), validation (20%, 89 patients), and testing (20%, 89 patients).
The model had one GATMambaBlock, 64 hidden dimensions for the UNI node features \(\varvec{X_{UNI}}\), 16 hidden dimensions for the positional encoding features \(\varvec{X_{PE}}\), and 16 hidden dimensions for the edge features \({\varvec{E}}\). The model was trained using batch size 16, learning rate 0.00005, weight decay 0.0001, and dropout 0.3 with Adam optimizer. Early stopping with a tolerance of 5 epochs and a maximum of 200 epochs was used to monitor the validation loss. All models in this work were implemented using Torch Geometric35 and PyTorch 2.0 on NVIDIA-RTX-8000 GPUs. The code of GATMambaBlock was based on the one from16.
Evaluation metrics
All comparison and ablation experiments use the concordance index (C-index) on the test sets as the primary metric36. Briefly, C-index calculates the proportion of patients whose predicted risks and progression-free survival times are concordant among all uncensored patients. The secondary metric was the dynamic area under the receiver operating characteristic curve (AUC), which measures the model’s ability to distinguish patients who experience events at different time points (1, 3, and 5 years) and those who do not37. A paired t-test with multiple comparison correction38 was used as the statistical test to compare model performance. The inference time of each model was calculated by running the model in inference mode using the first fold’s test set.
Experiments and results
Comparison with baseline models
The effectiveness of our GAT-Mamba model was compared against six baseline models encompassing different state-of-the-art methods to aggregate tile-level features to make a WSI/patient-level prediction. To make a fair comparison, UNI features were used in all models, and the models’ hyperparameters were tuned to our dataset with the parameters recommended by the original works as a reference.
-
(1)
Clinical: A baseline clinical Cox model39 that includes age, gender, race, and overall pathological stage as variables. Smoking data was not included since 51% of patients were missing this information.
-
(2)
MLP: UNI features \(\varvec{X_{UNI}}\) extracted from each tile and summary statistics, including mean, standard deviation, and range, were used to aggregate tile-level features into patient-level features, which were then fed into a two-layer MLP optimized by Cox loss.
-
(3)
AttentionMIL: A MIL model that aggregates the tile-level UNI features using attention scores learned by the network8.
-
(4)
TransMIL: A transformer-based MIL model that uses the Nyström method to approximate the importance of each tile during the tile-level feature aggregation9.
-
(5)
MambaMIL: A Mamba-based MIL model that uses Mamba to calculate the importance of each tile during the tile-level feature aggregation25.
-
(6)
SCMIL: A self-attention/transformer-based MIL method that first uses a learnable module to filter out task-irrelevant tiles and then uses sparse self-attention to learn the interactions among tiles40.
-
(7)
WiKG: A graph-based approach that dynamically constructs the edge connections in the graph based on dot product similarity between the learned tile/node features41.
-
(8)
PatchGCN: A graph-based approach that uses DeepGCN42 as the GNN and builds graphs from WSIs using the 8-nearest neighbor approach to predict patient survival.
-
(9)
GPT: A graph transformer model13 that first passes the WSI-built graph into a GCN43 followed by a vision transformer to classify WSIs.
As shown in Table 2 and Fig. 2, GAT-Mamba outperformed all six baseline models regarding both C-index and dynamic AUC (p < 0.05 between GAT-Mamba and all baseline models except for MambaMIL). Figure 3 shows that GAT-Mamba had the lowest log-rank p-value in Kaplan-Meier analysis, indicating it had the best ability to separate low-risk and high-risk patients as compared to other models. The clinical Cox model was the worst-performing model, resulting in an average C-index of 0.608. PatchGCN was better than the clinical model but worse than the other GNN-based baseline (WiKG) and the rest of the compared models. GTP achieved a slightly better average C-index and had a smaller standard deviation than TransMIL. GTP, TransMIL, and MambaMIL performed better than MLP and AttentionMIL. SCMIL was the best-performing baseline but was also the most computationally expensive one as indicated by the number of parameters, training time, and inference time. In addition, as reflected in Fig. 2, GAT-Mamba’s performance distribution is relatively less spread out across folds than most baseline models, especially when using dynamic AUCs as the evaluation metric.
Kaplan-Meier curves of all baseline models and GAT-Mamba. For each fold in five-fold CV, test set patients were stratified into high-risk and low-risk groups by the median risk score in the training set of that fold. The curves were generated from the combined test set patients across the 5 folds. P-values from log-rank test are shown in each plot.
Ablation studies
The effectiveness of each component of GAT-Mamba was assessed by doing ablation experiments. Specifically, five ablated models were trained and evaluated.
-
(1)
GAT: This model only has the GAT branch.
-
(2)
Mamba: This model only has the Mamba branch.
-
(3)
GAT-Transformer: This model replaces the Mamba branch with a standard transformer.
-
(4)
GAT-MambaNo\({\varvec{E}}\): This model has no edge features as input for the GAT branch.
-
(5)
GAT-MambaNo\(\varvec{X_{PE}}\): This model does not have positional encoding node features \(\varvec{X_{PE}}\).
Table 3 shows that all the ablated models performed worse than the complete model. In particular, the complete model was statistically significantly better than the GAT and GAT-Transformer models (p < 0.05).
In addition to modeling ablation, node feature ablation was also performed to assess the impact of different node features on the model performance.
-
(1)
Hand-crafted: Studies have found the prognostic values of lymphocytes in the tumor microenvironment of lung cancer6,44. In this work, the publicly available HoVer-Net model pretrained on PanNuke dataset was first used for cancer nuclei and lymphocyte detection45. The detection results were manually verified by a pathologist (E.R.). A total of 57 features related to the density of lymphocytes and the spatial co-localization between lymphocytes and cancer nuclei were extracted from each tile from the WSI6,46.
-
(2)
LUADDeep: Different histologic subtypes of LUAD have been shown to be associated with patient prognosis28. 512 deep features were extracted from a ResNet18-based pretrained LUAD histologic subtype classifier29.
-
(3)
ResNet50IN: 1024 deep features were extracted from ResNet50 pretrained on ImageNet data47.
-
(4)
CONCH: 512 deep features were extracted from CONCH, a vision language pathology foundation model pretrained on 1.17M image caption pairs48.
-
(5)
PLIP: 512 deep features were extracted from PLIP, another vision-language pathology foundation model pretrained on over 200k image-text pairs from pathology Twitter49.
-
(6)
UNI: 1024 deep features were extracted from UNI, a large vision model trained on over 100M pathology images26.
Table 4 shows that GAT-Mamba achieved the best performance when using UNI node features as compared to the other types of features. The pathology foundation model-extracted features (CONCH, PLIP, UNI) and ResNet50IN resulted in better-performing models than the other two feature types, indicating that, in general, features extracted from any models that were trained on large datasets (either natural images or pathology images) would be relatively more robust than the ones trained on smaller datasets.
Results of tile sampling experiments. (a) represents the line graphs visualizing the average C-index and its standard deviation across different percentages of tiles sampled or when using only all (100%) aggressive (micro-papillary and solid) or all (100%) non-aggressive (non-tumor and lepidic) tiles, using UNI node features. (b) represents a bar graph showing the macro-average of the average C-index across 5, 10, 20, 30, 60, and 100 percent sampling for all six types of node features, and (c) represents the range of the average C-index (the difference between lowest and highest C-index values across all sampling methods for a given feature extraction approach).
Different tile sampling strategies
Since WSIs are large and can be broken down into hundreds or thousands of tiles, exploring the most computationally efficient way to construct and minimize the number of tiles/nodes used while maintaining the model performance is important. In this work, experiments were done to assess the impact of different tile sampling strategies on the model performance and how the impact would change when using different node features. Specifically, the first kind of tile sampling is based on the aggressiveness of each tile: (1) sampling only the tiles of the most aggressive subtypes (micropapillary and solid) and (2) sampling only the tiles of the least aggressive subtypes (non-tumor and lepidic). The second kind of tile sampling was based on the percentage of tiles (5–100 percent), and tiles were randomly selected. These tile sampling strategies were applied to all six groups of node features, respectively.
Figure 4a shows the trend of model performance under different sampling strategies when using UNI features as GAT-Mamba node features. Generally speaking, across all node feature types, using 100% tiles resulted in the best-performing models. However, the amount of performance improvement depends on the type of node features used. According to Fig. 4b and c, generally, when the features are more predictive (better macro-averaged C-index), the model performance will improve less when using more tiles to build the graph. For example, Hand-crafted features were the least predictive since they resulted in the lowest average C-index (0.623), and the C-index difference between the best and worst sampling percentages was the largest, meaning this model benefited the most from increasing the number of tiles. However, the model with UNI features did not benefit much from increasing the tiles.
In addition, across all node feature types, sampling only tiles of the least aggressive subtypes or the most aggressive subtypes generally resulted in models that have slightly worse performance than any of the random sampling methods (5 to 100 percent tiles sampling), which implicitly includes tiles of all subtypes. This indicates that including all types of tissue regions in the analysis is important.
Model prediction visualization
For each fold in the five-fold cross-validation, the median risk score from the training set was applied to the test set to get low-risk and high-risk patients. After all test set patients’ risk categories were obtained, we examined the characteristics of the two groups. In total, there were 232 high-risk and 212 low-risk patients. Figure 5a shows that the two groups had significantly different progression-free survival distributions. Figure 5b, c indicates that the high-risk patients generally have a more solid subtype than the low-risk patients, and Fig. 5b also shows that the low-risk patients tend to have lepidic as the predominant subtype. Figure 5d indicates that low-risk patients have a higher lymphocyte density as compared to high-risk patients, and Fig. 5e shows that there is more tumor-lymphocyte co-localization in low-risk patients than in high-risk patients.
Visualization of characteristics of GAT-Mamba predicted risk groups. (a) represents the Kaplan-Meier curves of low and high-risk groups where the log-rank test p-value indicates a statistically significant difference in the progression-free survival distribution of the two groups. (b) represents the overall distribution of predominant histologic subtypes in low (left) and high (right) risk patients. (c) represents the distribution of the percentage of solid tiles in each patient in low and high-risk groups. (d)(e) shows the distribution of four hand-crafted features between low and high-risk groups, with (d) representing the number of lymphocytes divided by the number of all nuclei, (e) representing the TIL abundance score46.
Characteristics of tiles for model-predicted false negative (FN), true positive (TP), and false positive (FP) patients. (a) shows the distribution of the percentage of non-tumor tiles. (b) shows the distribution of the percentage of solid tiles. (c) shows the distribution of the total number of tiles. **** p < 0.001, ns: p \(\ge\) 0.05 using Wilcoxon rank sum test.
Error analysis
The purpose of error analysis in this work is to understand the potential reasons why the model makes a wrong prediction. From a clinical standpoint, false negative predictions are not ideal for prognosis since false negatives mean that a patient potentially would miss opportunities for treatment intervention after underestimating the risk of progression. A true positive is when the patient develops a progression event within 5 years of surgery and has a predicted risk score that falls in the high-risk category. There were a total of 61 (13.7%) false negatives and 115 (25.9%) false positives. We found that for false negatives, they had nearly twice as many non-tumor tiles as compared to the true positive patients (Fig. 6a) and had much less solid tissue than the true positives (Fig. 6b). We further discovered that, on average, the false negatives had larger tissues as compared to the true positives (Fig. 6c). Together, these trends indicate that no matter the overall size of the tissue, they have to contain a decent amount of malignant tumor regions for the model to predict prognosis correctly. There is no obvious trend for false positives.
Discussion
Due to their gigapixel size, WSIs are typically divided into smaller tiles. Optimizing how tiles are represented and aggregated directly impacts a model’s ability to utilize this information for WSI-level predictions. In this work, we proposed a novel pipeline that integrates GAT, a message-passing GNN, with Mamba, a state space model, to capture local and global spatial relationships among the tiles in the WSIs to predict progression-free survival in LUAD patients.
Our proposed GAT-Mamba outperformed all baseline models and ablations that use different ways to aggregate tile-level features to render a WSI-level prediction and a clinical feature-based model. Experiment results show that the clinical model achieved a C-index of 0.608, but it was inferior to all other models that incorporated WSI-derived features, indicating the potential of adding quantitative image features in improving prognosis. Among the four non-graph-based baseline models (MLP, AttentionMIL, TransMIL, MambaMIL, and SCMIL), MLP and AttentionMIL had similar performance, and TransMIL, MambaMIL, and SCMIL had the best performance among these models. This indicates that TransMIL’s and SCMIL’s self-attention mechanism and MambaMIL’s long-range dependency modeling technique might have helped capture the long-range correlation information between different tissue regions, and such correlation contributed to improved performance, whereas MLP and AttentionMIL did not leverage such correlation since they treated each tile independently. The reason why all MIL-based methods do not outperform GAT-Mamba is likely because they do not take into account tissue regions’ spatial context, which provides prognostic value in various cancer types19,50. In addition, as indicated in Table 2, the inference time of all MIL methods is considerably longer than the other non-MIL methods. As for message-passing graph-based baselines/ablations (GAT, PatchGCN, and WiKG), GAT performed much better than PatchGCN, indicating that the GAT node aggregation method might be more effective for this work. WiKG outperformed PathGCN but not GAT. Since the dynamically-learned edge connections in WiKG are solely based on the node/edge feature similarity instead of spatial locations, spatial information was not encoded in the learned representation, which led to suboptimal performance as compared to GAT which models spatial information. The two models that leveraged a combined GNN and transformer achieved very similar results, but they did not outperform GAT alone. In terms of computational efficiency, our approach had the highest representational capacity, evidenced by the number of parameters in the model, yet achieved similar inference times compared to other methods, except for attention MIL and TransMIL, which have much longer inference times than the rest. Although GAT-Mamba has more parameters than most other models, its performance and inference time are still advantageous. By considering the trade-offs between model size, inference time, and accuracy, GAT-Mamba is still the most desirable, especially in a high-stakes field like medicine, where accuracy is important.
GAT-Mamba outperformed the two GNN-transformer-based models because of the input selection mechanism in Mamba. In relatively large graphs like the ones used in this study (691 nodes on average), it is important to filter out noisy nodes that do not contribute valuable information to the prognosis. Instead of compressing all the information like transformers do, Mamba only selectively retains the most informative nodes. The fact that Mamba works better than transformers on pathology graphs also resonates with the findings that Mamba performed better on large long-range molecular data than the transformer-based approaches16. As indicated by the improved performance of the model after combining both GAT and Mamba components in the ablation studies, GAT-Mamba benefited from both local (GAT) and global (Mamba) feature learners that retained the most informative nodes based on both local and global contexts in the WSI. The benefit of using local and global feature learners in a parallel manner is also evident when comparing the performance of GAT-Mamba to the one of GTP13, a model that uses a local feature learner first and then the global feature learner in a sequential manner. The sequential structure in GTP might have imposed a stronger bias on the benefit of local information and therefore limited the model’s ability to fully capture the global information.
Besides model architecture, different node features were also compared in this study. Unlike CNNs which take in raw images as input, graph models take in user-provided node (and edge) features as input, indicating that the quality of input node features has a direct impact on the quality of representation learned by the graph models. To date, ResNet50IN-extracted features are one of the most common types of node features in digital pathology graph modeling10,12,14. Our node feature ablation study shows that pathology foundation model-extracted features PLIP, CONCH, and UNI resulted in better model performance than ImageNet-pretrained-model-extracted features or LUAD domain-specific features. This suggests one can leverage pathology foundation models as feature extractors in their work. UNI features resulted in the best-performing model compared to PLIP and CONCH, which might be because UNI was trained on the most amount of data. In addition, several related works build graphs from the WSIs using all available tiles/nodes12,13,14. Results from our tile-sampling experiments suggest that one may not need to use all the tiles from the WSI to construct the tile-level graph when robust node features are used. Since WSIs are large and can be broken down into hundreds or thousands of tiles, good node features might help save storage and computational costs associated with bigger graphs.
Although UNI features resulted in the best model performance, they are not as interpretable as other worse-performing features, such as hand-crafted features and histologic subtypes. One way to enhance interpretability is to leverage the more interpretable features such as LUAD subtypes and immune infiltration in a post-hoc analysis of the model prediction. We found that the model-predicted low-risk patients tend to have a lepidic predominant subtype, whereas the high-risk patients tend to have a solid predominant subtype. This is consistent with the fact that lepidic is the least aggressive subtype, whereas solid is the most aggressive one28. In addition, more immune cells and greater immune-tumor cell co-localization appear in low-risk patients as compared to the high-risk ones, which aligns with our current understanding of lung cancer where immune infiltration is associated with favorable disease outcome6,44,51. Finally, error analysis indicates that the inherent limitation associated with the tissue sampling (a large portion of the benign region) during the surgical resection might have been attributed to some false negative predictions. Although there are still some more malignant regions, they might be insufficient for the model to recognize the aggressiveness of the entire tumor. In general, for prognostic tasks, false positives and false negatives can have important implications: false positives may lead to overtreatment, while false negatives may prevent someone with aggressive disease from getting beneficial treatment intensification.
Although our proposed GAT-Mamba pipeline showed promising results, we note several areas for improvement. For example, results from error analysis indicate that one can potentially improve the model performance by explicitly adding more importance to tiles with more aggressive subtypes (e.g., solid) while reducing the importance of tiles with less aggressive subtypes (e.g., non-tumor). Moreover, although UNI node features were shown to be the best among all types of node features compared in this study, further performance improvements could be achieved by fine-tuning UNI with additional disease-specific data. In addition, while edge features related to LUAD subtype-subtype connection may not directly apply to other cancer types, the approach of leveraging the characteristics of each tile (such as different subtypes) as domain-specific edge features can be generalized. Also, the model’s generalizability could be improved if the training dataset could include more diversity in tissue staining, tissue preparation methods (e.g., flash-frozen or formalin-fixed paraffin embedded tissue), and patient characteristics. Finally, prognostic models could be improved by incorporating more than just WSI-based information, such as combining both pathology and radiology data, providing a more comprehensive view of the patient’s prognosis.
Conclusion
In summary, we demonstrated the benefit of combining a message-passing GNN and a state space model Mamba to capture both local and global spatial relationships between different regions in the WSIs. The model outperformed all baselines for predicting progression-free survival in early-stage LUAD patients. Experiments show the impact of different types of node features and different tile sampling strategies on model performance. Although the prediction task of this work is on LUAD prognosis, one can easily use the pipeline on any other digital pathology WSI-based tasks, such as classification.
Data availability
The two datasets used in this study are publicly available at https://wiki.cancerimagingarchive.net/display/NLST/NLST+Pathology for NLST and https://portal.gdc.cancer.gov/projects/TCGA-LUAD for TCGA.
References
Sung, H. et al. Global cancer statistics 2020: Globocan estimates of incidence and mortality worldwide for 36 cancers in 185 countries. CA Cancer J. Clinic. 71, 209–249 (2021).
Molina, J. R., Yang, P., Cassivi, S. D., Schild, S. E. & Adjei, A. A. Non-small cell lung cancer: Epidemiology, risk factors, treatment, and survivorship. Mayo Clin. Proc. 83, 584–594 (2008).
Uramoto, H. & Tanaka, F. Recurrence after surgery in patients with nsclc. Trans. Lung Cancer Res. 3, 242 (2014).
Woodard, G. A., Jones, K. D. & Jablons, D. M. Lung cancer staging and prognosis. Lung Cancer Treatment Res. 47–75 (2016).
Wang, S. et al. Computational staining of pathology images to study the tumor microenvironment in lung cancer. Cancer Res. 80, 2056–2066 (2020).
Ding, R. et al. Image analysis reveals molecularly distinct patterns of tils in nsclc associated with treatment outcome. NPJ Precision Oncol. 6, 1–15 (2022).
Wang, X. et al. A prognostic and predictive computational pathology image signature for added benefit of adjuvant chemotherapy in early stage non-small-cell lung cancer. EBioMedicine 69, 103481 (2021).
Ilse, M., Tomczak, J. & Welling, M. Attention-based deep multiple instance learning. In: International conference on machine learning, 2127–2136 (PMLR, 2018).
Shao, Z. et al. Transmil: Transformer based correlated multiple instance learning for whole slide image classification. Adv. Neural Inform. Process. Syst. 34, 2136–2147 (2021).
Ding, K., Zhou, M., Wang, H., Zhang, S. & Metaxas, D. N. Spatially aware graph neural networks and cross-level molecular profile prediction in colon cancer histopathology: a retrospective multi-cohort study. Lancet Digital Health 4, e787–e795 (2022).
Wang, Z. et al. Hierarchical graph pathomic network for progression free survival prediction. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part VIII 24, 227–237 (Springer, 2021).
Chen, R. J. et al. Whole slide images are 2d point clouds: Context-aware survival prediction using patch-based graph convolutional networks. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part VIII 24, 339–349 (Springer, 2021).
Zheng, Y. et al. A graph-transformer for whole slide image classification. IEEE Trans. Med. Imaging 41, 3003–3015 (2022).
Sun, X. et al. Tgmil: A hybrid multi-instance learning model based on the transformer and the graph attention network for whole-slide images classification of renal cell carcinoma. Comput. Methods Programs Biomed. 242, 107789 (2023).
Gu, A. & Dao, T. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752 (2023).
Wang, C., Tsepa, O., Ma, J. & Wang, B. Graph-mamba: Towards long-range graph sequence modeling with selective state spaces. arXiv preprint arXiv:2402.00789 (2024).
Dwivedi, V. P. et al. Long range graph benchmark. Adv. Neural Inform. Processing Syst. 35, 22326–22340 (2022).
Bera, K., Schalper, K. A., Rimm, D. L., Velcheti, V. & Madabhushi, A. Artificial intelligence in digital pathology-new tools for diagnosis and precision oncology. Nat. Rev. Clin. Oncol. 16, 703–715 (2019).
AbdulJabbar, K. et al. Geospatial immune variability illuminates differential evolution of lung adenocarcinoma. Nat. Med. 26, 1054–1062 (2020).
Yao, J., Zhu, X., Jonnagaddala, J., Hawkins, N. & Huang, J. Whole slide images based cancer survival prediction using attention guided deep multiple instance learning networks. Med. Image Analy. 65, 101789 (2020).
Wu, X., Ajorlou, A., Wu, Z. & Jadbabaie, A. Demystifying oversmoothing in attention-based graph neural networks. Adv. Neural Inform. Process. Syst. 36 (2024).
Zaheer, M. et al. Big bird: Transformers for longer sequences. Adv. Neural Inform. Process. Syst. 33, 17283–17297 (2020).
Choromanski, K. et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794 (2020).
Fang, Z. et al. Mammil: Multiple instance learning for whole slide images with state space models. arXiv preprint arXiv:2403.05160 (2024).
Yang, S., Wang, Y. & Chen, H. Mambamil: Enhancing long sequence modeling with sequence reordering in computational pathology. arXiv preprint arXiv:2403.06800 (2024).
Chen, R. J. et al. Towards a general-purpose foundation model for computational pathology. Nat. Med. 30, 850–862 (2024).
Dosovitskiy, A. et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020).
Kuhn, E. et al. Adenocarcinoma classification: Pattern and prognosis. Pathologica 110, 5–11 (2018).
Ding, R., Yadav, A., Rodriguez, E., da Silva, A. C. A. L. & Hsu, W. Tailoring pretext tasks to improve self-supervised learning in histopathologic subtype classification of lung adenocarcinomas. Comput. Biol. Med. 166, 107484 (2023).
Wu, J. et al. A precision diagnostic framework of renal cell carcinoma on whole-slide images using deep learning. In: 2021 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), 2104–2111 (IEEE, 2021).
Wu, J. et al. A personalized diagnostic generation framework based on multi-source heterogeneous data. In: 2021 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), 2096–2103 (IEEE, 2021).
Velickovic, P. et al. Graph attention networks. stat 1050, 10–48550 (2017).
Gu, A., Goel, K. & Ré, C. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396 (2021).
Katzman, J. L. et al. Deepsurv: personalized treatment recommender system using a cox proportional hazards deep neural network. BMC Med. Res. Methodol. 18, 1–12 (2018).
Fey, M. & Lenssen, J. E. Fast graph representation learning with pytorch geometric. arXiv preprint arXiv:1903.02428 (2019).
Harrell, F. E. Jr., Lee, K. L. & Mark, D. B. Multivariable prognostic models: issues in developing models, evaluating assumptions and adequacy, and measuring and reducing errors. Stat. Med. 15, 361–387 (1996).
Hung, H. & Chiang, C.-T. Estimation methods for time-dependent auc models with survival data. Can. J. Stat. 38, 8–26 (2010).
Benjamini, Y. & Hochberg, Y. Controlling the false discovery rate: a practical and powerful approach to multiple testing. J. Royal Stat. Soc. Series B (Methodological) 57, 289–300 (1995).
Fox, J. & Weisberg, S. Cox proportional-hazards regression for survival data. An R and S-PLUS companion to applied regression 2002 (2002).
Yang, Z., Liu, H. & Wang, X. Scmil: Sparse context-aware multiple instance learning for predicting cancer survival probability distribution in whole slide images. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, 448–458 (Springer, 2024).
Li, J. et al. Dynamic graph representation with knowledge-aware attention for histopathology whole slide image analysis. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 11323–11332 (2024).
Li, G., Muller, M., Thabet, A. & Ghanem, B. Deepgcns: Can gcns go as deep as cnns?. In: Proceedings of the IEEE/CVF international conference on computer vision, 9267–9276 (2019).
Kipf, T. N. & Welling, M. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016).
Saltz, J. et al. Spatial organization and molecular correlation of tumor-infiltrating lymphocytes using deep learning on pathology images. Cell Reports 23, 181–193 (2018).
Graham, S. et al. Hover-net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images. Med. Image Analy. 58, 101563 (2019).
Shaban, M. et al. A novel digital score for abundance of tumour infiltrating lymphocytes predicts disease free survival in oral squamous cell carcinoma. Sci. Reports 9, 13341 (2019).
He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image recognition. In: Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778 (2016).
Lu, M. Y. et al. A visual-language foundation model for computational pathology. Nat. Med. 30, 863–874 (2024).
Huang, Z., Bianchi, F., Yuksekgonul, M., Montine, T. J. & Zou, J. A visual-language foundation model for pathology image analysis using medical twitter. Nat. Med. 29, 2307–2316 (2023).
Beck, A. H. et al. Systematic analysis of breast cancer morphology uncovers stromal features associated with survival. Sci. Trans. Med. 3, 108ra113-108ra113 (2011).
Bremnes, R. M. et al. The role of tumor-infiltrating lymphocytes in development, progression, and prognosis of non-small cell lung cancer. J. Thoracic Oncol. 11, 789–800 (2016).
Acknowledgements
This work was funded by the NIH/National Cancer Institute R01 CA210360, NIH/National Cancer Institute U2C CA271898, the V Foundation, and the NIH/National Institute for Biomedical Imaging and Bioengineering R01 EB031993.
Author information
Authors and Affiliations
Contributions
R.D. and W. H. designed the study and conceived the experiments; R.D. conducted the experiments and analyzed the results; K.L. provided technical guidance to the model implementation; E.R. and A.S. provided clinical guidance; All authors reviewed the manuscript.
Corresponding author
Ethics declarations
Competing interests
W.H. receives research grant support from Early Diagnostics Inc. and consulting fees from LungLifeAI. All the remaining authors declare no conflict of interest.
Additional information
Publisher’s note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Ding, R., Luong, KD., Rodriguez, E. et al. Combining graph neural network and Mamba to capture local and global tissue spatial relationships in whole slide images. Sci Rep 15, 18261 (2025). https://doi.org/10.1038/s41598-025-99042-4
Received:
Accepted:
Published:
DOI: https://doi.org/10.1038/s41598-025-99042-4