Calculating attributions using XAI and scRNA-seq FMs
Several methods exist for calculating feature attributions on deep learning models18,19,20,54,69,70. These methods can be broadly categorized into two groups: gradient-based and perturbation-based. Gradient-based methods such as IxG19,69, IG18 and DL20 quantify feature relevance using the gradients of the model output with respect to the input features. In contrast, perturbation-based methods such as SHAP54 and LIME70 measure relevance by quantifying how model outputs change when sets of input features are altered. This work focuses on gradient-based explainability methods because of their computational efficiency. Specifically, the cost of gradient-based approaches scales with the number of integration steps (m) rather than with the number of input features (for example, IG requires only O(m) backward passes per output, and IxG requires a single one). By contrast, perturbation-based approaches require repeated evaluations of the model on perturbed inputs; exact Shapley computation is exponential in the number of features (p), whereas KernelSHAP typically scales as O(p²)71,72. Consequently, perturbation methods become impractical in high-dimensional settings such as single-cell transcriptomic data.
scRNA-seq FMs are often trained in an unsupervised or semisupervised manner to generate biologically meaningful, high-dimensional embedding spaces. Computing attribution scores separately for each latent dimension is computationally intensive, limiting the feasibility of these methods in such contexts. Here, building on our previous work21, an efficient variation of IG for multidimensional embeddings was implemented using Captum73. This approach adds a final layer to the original scRNA-seq FM that computes a weighted sum of the embedding dimensions, using the cell’s own embedding values as weights, to yield a single scalar output. This is more efficient than calculating gradients against each latent variable independently and places greater emphasis on the dimensions that have large magnitudes in the cell’s embedding. Formally, let \(f:{\mathbb{R}}^{n}\to {\mathbb{R}}^{{d}_{h}}\) represent the scRNA-seq FM, which maps an input cell \(X\in {\mathbb{R}}^{n}\) (n input genes) to its embedding \(H=f(X)\) with \({d}_{h}\) dimensions. The attribution score for the ith input feature (gene) of \(X\) is defined as follows:
$${\mathrm{Att}}_{i}\left(\mathop{\sum }\limits_{k=1}^{{d}_{h}}\left[{f}_{k}\left(X\right)\times {f}_{k}\left(\cdot \right)\right],\,X\right)$$
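As a worked restatement of the equation above, the attribution target can be written as a scalar function \(g_X\) whose weights are fixed at the cell’s own embedding \(f(X)\). This is our reading of the \(f_k(X)\times f_k(\cdot)\) notation; the symbols \(g_X\), \(Z\) (the point at which the summation layer is evaluated) and \(X'\) (the baseline) are introduced here for illustration only. The gradient of the summation layer and the IG integral, with its m-step Riemann approximation, then become:

$$g_X(Z)=\sum_{k=1}^{d_h} f_k(X)\,f_k(Z),\qquad \frac{\partial g_X}{\partial Z_i}(Z)=\sum_{k=1}^{d_h} f_k(X)\,\frac{\partial f_k}{\partial Z_i}(Z)$$

$$\mathrm{IG}_i(X)=(X_i-X'_i)\int_0^1 \frac{\partial g_X}{\partial Z_i}\bigl(X'+\alpha(X-X')\bigr)\,d\alpha\;\approx\;(X_i-X'_i)\,\frac{1}{m}\sum_{j=1}^{m}\frac{\partial g_X}{\partial Z_i}\Bigl(X'+\tfrac{j}{m}(X-X')\Bigr)$$

Because \(g_X\) is scalar, one backward pass per interpolation step returns the gradient for all n genes at once, instead of one pass per embedding dimension (\(d_h\) passes). By the completeness property of IG, the attributions sum (up to discretization error) to \(g_X(X)-g_X(X')=\lVert f(X)\rVert^{2}-f(X)\cdot f(X')\); the squared-norm term is one way to see why embedding dimensions with large magnitudes dominate the attributions.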
Building on this framework, three gradient-based attribution methods were applied to compute feature-level attributions against the final summation layer. IxG, an extension of the saliency method19, multiplies each input feature by the gradient of the model output with respect to that feature. DL20 provides a more refined estimate by comparing each neuron’s activation with a reference activation and propagating contribution scores from the output back to the input. Finally, IG18 computes attributions by interpolating between a baseline and the observed cell state (Fig. 2a), calculating gradients for each interpolated input and integrating these gradients to obtain the final attribution score. For both DL and IG, a zero vector was chosen as the baseline, following common practice in image-classification tasks, to ensure consistent comparisons across cells without dataset-specific tuning. For all three methods, the attribution values are then multiplied by the gene expression vector (preprocessed to be compatible with a given FM) and renormalized to a standard scale factor (that is, 1,000) to yield the final score for each gene, quantifying its influence on the cell’s representation in the embedding. Using the SCimilarity model14 and a batch size of 500 cells, this framework processes approximately 775 cells per s with IG and 3,500 cells per s with DL on a single L40S GPU, illustrating the practical efficiency of these gradient-based approaches for large-scale single-cell data.
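The pipeline above can be illustrated with a minimal NumPy sketch. It is not the authors’ Captum implementation: the toy embedding model f(z) = tanh(Wz), the helper names, the number of steps m and the midpoint Riemann rule are assumptions for illustration, and we assume that “renormalized to 1,000” means rescaling each cell’s scores so that their absolute values sum to 1,000. In a PyTorch setting, the same wrapped scalar model would typically be passed to Captum’s IntegratedGradients, DeepLift or InputXGradient classes.

```python
import numpy as np

# Hypothetical toy "FM": f(z) = tanh(W z) maps n genes to a d_h-dim embedding.
rng = np.random.default_rng(0)
n_genes, d_h = 50, 8
W = rng.normal(scale=0.3, size=(d_h, n_genes))


def embed(z):
    """Embedding f(z) of one cell (vector of length d_h)."""
    return np.tanh(W @ z)


def head_grad(z, weights):
    """Gradient of g(z) = sum_k weights[k] * f_k(z) with respect to z.

    `weights` is the cell's own embedding f(X), treated as a constant.
    """
    h = np.tanh(W @ z)
    return W.T @ (weights * (1.0 - h ** 2))


def integrated_gradients(x, m=200):
    """Midpoint-rule IG from a zero baseline; one gradient call per step."""
    baseline = np.zeros_like(x)
    weights = embed(x)  # fixed weights of the summation layer
    alphas = (np.arange(m) + 0.5) / m
    grads = [head_grad(baseline + a * (x - baseline), weights) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)


def input_x_gradient(x):
    """IxG: input multiplied by the gradient of the weighted-sum head at x."""
    return x * head_grad(x, embed(x))


def gene_scores(attr, expression, scale=1000.0):
    """Multiply by expression, then rescale (assumed: |scores| sum to `scale`)."""
    scores = attr * expression
    total = np.abs(scores).sum()
    return scores * (scale / total) if total > 0 else scores


x = np.log1p(rng.poisson(2.0, size=n_genes).astype(float))  # toy log-normalized cell
ig = integrated_gradients(x)
ig_scores = gene_scores(ig, x)
ixg_scores = gene_scores(input_x_gradient(x), x)
```

Because f(0) = 0 for this toy model, the IG attributions should sum to approximately ||f(x)||², which is a convenient sanity check for any implementation of the weighted-sum head.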
Benchmarking attributions across scRNA-seq FMs
Gradient-based measures of gene importance were implemented for five FMs: SCimilarity14 (version 2023_01_rep0), scGPT13 (whole-human model), scFoundation12, an SCVI-based9 model trained on CELLxGENE Census data17 (2023-12-15-scvi-homo-sapiens) and a self-supervised model trained with random masking on the scTab dataset15,16. Additionally, these approaches were applied to two models15 that were pretrained with self-supervised learning on the scTab dataset and then fine-tuned on the Tabula Sapiens dataset74 for either cell type classification or gene expression reconstruction. Attribution scores were calculated with Captum73 using DL, IG and IxG. For each FM, attributions were calculated against the primary embedding layer; for the fine-tuned classification model, they were calculated against the penultimate layer.
The attributions collected from these models were compared on multiple tasks assessing their computational performance, biological relevance and resistance to technical artifacts. Two additional baselines were included: log-normalized expression and log-normalized expression with shuffled gene labels, averaged over 25 randomizations. Most tasks used a downsampled version of the Wells et al.59 dataset (n = 2,943 cells), sampling at most 15 cells from each donor for each of the 13 cell types that have matching gene sets from CIBERSORT75,76 (Supplementary Table 2). One additional task considered fibroblasts annotated with cell-cycle state from Riba et al.60. Speed was assessed as the time required to generate gene attributions for a single preprocessed cell on an L40S GPU, repeated ten times (Fig. 2).
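The two sampling steps above can be sketched in Python as follows. This is a minimal illustration, assuming a pandas metadata table with hypothetical ‘donor’ and ‘cell_type’ columns and a user-supplied metric function; the helper names are ours, not the authors’. The shuffled-label baseline permutes which column each gene name points to, so a gene-set metric computed on it reflects random gene sets drawn from the same matrix.

```python
import numpy as np
import pandas as pd


def downsample_per_donor(obs, max_cells=15, seed=0):
    """Keep at most `max_cells` random cells per (donor, cell type) group."""
    shuffled = obs.sample(frac=1.0, random_state=seed)  # random order within groups
    return shuffled.groupby(["donor", "cell_type"]).head(max_cells)


def shuffled_label_baseline(values, gene_names, metric, n_perm=25, seed=0):
    """Average `metric(values, labels)` over random permutations of gene labels."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(gene_names)
    scores = [metric(values, labels[rng.permutation(labels.size)]) for _ in range(n_perm)]
    return float(np.mean(scores))


def marker_mean(values, labels, markers=("g0", "g1", "g2", "g3", "g4")):
    """Example metric: mean value of a (hypothetical) marker gene set."""
    return float(values[:, np.isin(labels, list(markers))].mean())


rng = np.random.default_rng(1)
obs = pd.DataFrame({
    "donor": rng.choice(["d1", "d2"], size=200),
    "cell_type": rng.choice(["B cell", "T cell", "NK cell"], size=200),
})
subset = downsample_per_donor(obs)

genes = np.array([f"g{i}" for i in range(100)])
expr = rng.random((50, 100))
expr[:, :5] += 5.0  # columns g0-g4 carry the toy "marker" signal
observed = marker_mean(expr, genes)
baseline = shuffled_label_baseline(expr, genes, marker_mean)
```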
To evaluate robustness, for each cell type, the mean value (either attribution or log-normalized expression) of its established marker genes from CIBERSORT75,76 (Supplementary Table 2) was calculated in each cell, and the absolute Spearman correlation between this value and the cell’s total counts was determined. Two marker gene analyses were also performed using within-cell ranking. Within each cell, nonzero gene values were rank-normalized from 0 to 100 (scipy77 ‘rankdata’) and, for each cell type, the average rank percentile was computed per gene across cells in which it was detected. The mean value for all genes in each gene set (for example, ribosomal genes or cell type markers) is reported in Fig. 2. Lastly, cell-cycle analysis was performed using the G2/M phase genes from Tirosh et al.39. A G2/M score was calculated for each cell as the mean value of these genes, and Fig. 2 reports the log2 fold change of this score for G2/M-annotated cells relative to all other cells.
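The robustness, within-cell ranking and G2/M metrics described above might be computed as in the sketch below. It assumes a dense cells × genes array (attributions or log-normalized expression) and a linear mapping of ranks 1..k of the k nonzero values onto 0 to 100, which is one reasonable reading of “rank-normalized from 0 to 100”; the function names and toy data are hypothetical.

```python
import numpy as np
from scipy.stats import rankdata, spearmanr


def marker_count_correlation(values, total_counts, marker_idx):
    """|Spearman rho| between each cell's mean marker value and its total counts."""
    marker_mean = values[:, marker_idx].mean(axis=1)
    rho, _ = spearmanr(marker_mean, total_counts)
    return abs(rho)


def within_cell_rank_percentiles(values):
    """Rank nonzero values within each cell onto 0-100; undetected genes -> NaN."""
    out = np.full(values.shape, np.nan)
    for i, row in enumerate(values):
        nz = np.flatnonzero(row)
        if nz.size == 1:
            out[i, nz] = 100.0
        elif nz.size > 1:
            ranks = rankdata(row[nz])  # ranks 1..k, ties averaged
            out[i, nz] = 100.0 * (ranks - 1) / (nz.size - 1)
    return out


def mean_rank_of_gene_set(values, gene_idx):
    """Average rank percentile per gene (over cells where detected), then over the set."""
    per_gene = np.nanmean(within_cell_rank_percentiles(values)[:, gene_idx], axis=0)
    return float(np.nanmean(per_gene))


def g2m_log2_fold_change(values, g2m_idx, is_g2m):
    """log2 ratio of the mean G2/M score in G2/M-annotated cells vs all other cells."""
    score = values[:, g2m_idx].mean(axis=1)
    return float(np.log2(score[is_g2m].mean() / score[~is_g2m].mean()))


rng = np.random.default_rng(0)
vals = rng.random((40, 30)) * (rng.random((40, 30)) > 0.3)  # toy matrix with dropouts
rho = marker_count_correlation(vals, vals.sum(axis=1), [0, 1, 2])
avg_rank = mean_rank_of_gene_set(vals, [0, 1, 2])
```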
Each model uses its own fixed set of genes, and attributions can only be calculated for genes included in both that model’s input space and the dataset of interest. Figure 2 shows results including all genes that each model can use, with the exception of scGPT, which used 3,000 highly variable genes because of GPU memory limitations. Extended Data Fig. 2 shows the same analysis using only the genes included in the input spaces of all five FMs (3,497 and 1,395 genes for the Wells et al.59 and Riba et al.60 studies, respectively).
Datasets and gene sets
A total of 24 publicly available datasets were used for focused analyses in this paper (Supplementary Table 1 and Extended Data Fig. 1). Many tests in this paper used established gene sets associated with cell types or molecular phenotypes that are detailed in Supplementary Table 2.
Plotting and statistics
Plotting was performed in Python using seaborn78 and matplotlib79. Statistical tests were implemented using scipy77 and statsmodels80.
Top attributions per cell type
The Adams et al. lung dataset22 includes the 16 cell types that have a corresponding lung marker gene set from MSigDB2,81,82 (Supplementary Table 2). For each of these 16 cell types, the average log-normalized gene expression and the average attribution values were calculated for every gene. The average rank analysis (Fig. 3d,e) was performed as for the benchmarking analysis described above. Figure 3d shows the summary for four gene families: lung cell type marker genes from MSigDB2,81, ribosomal genes annotated from KEGG83,84, mitochondrial genes (starting with MT-) and marker genes from MSigDB that are also TFs85. The values for each individual cell type and gene family are shown in Extended Data Fig. 3a.
Correlation analysis
For each of the 16 cell types analyzed from Adams et al.22, the average log-normalized expression and the average attribution score of the marker genes were computed in each cell. The absolute Spearman correlation was then calculated between these values and three scRNA-seq quality control metrics: number of unique molecular identifiers (UMIs), number of detected genes and percentage of UMIs from mitochondrial genes (Fig. 3g,h and Extended Data Fig. 3b).
Robustness analysis
Each B cell from Adams et al.22 was downsampled by randomly removing a percentage of the counts in each cell (0%, 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80% and 90%). These downsampled datasets were then log-normalized and used to calculate attribution scores. For each downsampling percentage, the average attribution score for each gene was calculated across all B cells and ranked in descending order.
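One way to implement this downsampling is sketched below. We assume that “randomly removing a percentage of the counts” means drawing the retained UMIs without replacement (a multivariate hypergeometric draw via NumPy’s Generator.multivariate_hypergeometric) and that log-normalization uses a target sum of 10,000, as in the gene set scoring section; the attribution step itself is omitted, and the count matrix and helper names are hypothetical.

```python
import numpy as np


def downsample_counts(counts, fraction_removed, rng):
    """Remove `fraction_removed` of one cell's UMIs uniformly at random (no replacement)."""
    counts = np.asarray(counts, dtype=np.int64)
    n_keep = int(round(counts.sum() * (1.0 - fraction_removed)))
    return rng.multivariate_hypergeometric(counts, n_keep)


def downsample_matrix(counts, fraction_removed, seed=0):
    """Apply the per-cell downsampling to every row of a cells x genes matrix."""
    rng = np.random.default_rng(seed)
    return np.vstack([downsample_counts(row, fraction_removed, rng) for row in counts])


def log_normalize(counts, target_sum=1e4):
    totals = np.maximum(counts.sum(axis=1, keepdims=True), 1)
    return np.log1p(counts / totals * target_sum)


def rank_genes_by_mean(attributions, gene_names):
    """Gene names sorted by mean attribution across cells, highest first."""
    order = np.argsort(-attributions.mean(axis=0), kind="stable")
    return [gene_names[i] for i in order]


rng = np.random.default_rng(0)
b_cells = rng.poisson(3.0, size=(20, 200))  # toy B-cell count matrix
levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Each log-normalized matrix would then be passed through the attribution step.
downsampled = {f: log_normalize(downsample_matrix(b_cells, f)) for f in levels}
```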
Cross-study NMF and topic modeling
NMF was implemented using scikit-learn86 on author-annotated T cells from three studies: Adams et al.22, Cano-Gamez et al.26 and Deng et al.28. Analysis included 14,056 genes detected across all three studies. NMF was performed 200 times for both expression and attributions, varying three parameters: number of factors (5, 10, 15, 20 and 25), random state (114–133) and feature set (all 14,056 genes or 2,591 highly variable genes).
For each run, a ‘Treg-associated factor’ was identified by calculating t-statistics between Tregs and non-Tregs across all nine possible within-study and cross-study comparisons (scipy ‘ttest_ind’, alternative = ‘greater’). The factor with the highest minimum t-statistic across these nine comparisons was selected as the potential Treg-associated factor. The Treg-associated factor was considered to robustly highlight Tregs if the worst-performing Treg versus non-Treg comparison was significant, using a Bonferroni-adjusted P value of 2.7 × 10−6 (α = 0.01 divided by 3,600 tests: 400 trials, each with nine comparisons). Figure 3i–k shows results from a specific parameter set (ten factors, all genes, random state = 114), although similar trends were observed across most parameter combinations.
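The Treg-factor selection rule can be sketched with scikit-learn’s NMF and SciPy’s ttest_ind, as below. This is an illustrative reconstruction on simulated data, not the authors’ code: the helper names, max_iter=500 and the choice to treat the largest of the nine P values as the “worst-performing” comparison are assumptions.

```python
import itertools

import numpy as np
from scipy.stats import ttest_ind
from sklearn.decomposition import NMF

# 5 factor numbers x 20 random states x 2 feature sets = 200 runs per data type.
PARAM_GRID = list(itertools.product([5, 10, 15, 20, 25], range(114, 134), ["all_genes", "hvg"]))
ALPHA = 0.01 / (400 * 9)  # Bonferroni: 400 runs (expression + attributions) x 9 comparisons


def run_nmf(X, n_factors, seed):
    """Per-cell factor usage (cells x factors) from scikit-learn NMF."""
    return NMF(n_components=n_factors, random_state=seed, max_iter=500).fit_transform(X)


def treg_factor(usage, is_treg, study):
    """Factor with the highest minimum one-sided t-statistic over the 9 comparisons.

    Returns (factor index, largest P value among that factor's 9 comparisons).
    """
    studies = np.unique(study)
    best_j, best_t, best_p = None, -np.inf, 1.0
    for j in range(usage.shape[1]):
        tstats, pvals = [], []
        for a, b in itertools.product(studies, studies):  # Tregs from a vs non-Tregs from b
            res = ttest_ind(usage[is_treg & (study == a), j],
                            usage[~is_treg & (study == b), j], alternative="greater")
            tstats.append(res.statistic)
            pvals.append(res.pvalue)
        if np.min(tstats) > best_t:  # a NaN minimum (constant factor) never wins
            best_j, best_t, best_p = j, np.min(tstats), np.max(pvals)
    return best_j, best_p


rng = np.random.default_rng(0)
study = np.repeat(["s1", "s2", "s3"], 40)
is_treg = np.tile(np.arange(40) < 15, 3)
X = rng.gamma(1.0, 0.2, size=(120, 60))
X[is_treg, :10] += 5.0  # toy Treg expression program
factor, worst_p = treg_factor(run_nmf(X, 5, 114), is_treg, study)
robust = worst_p < ALPHA
```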
Generalization analysis was performed using 3,200 cells from the SCimilarity database14 that were not used for model training or validation. A total of 16 tissues were selected because they contained at least 100 Tregs and 100 CD4+ T cells, as predicted by SCimilarity (nn_prediction_dist < 0.02). The NMF models trained on the original T cell cohort were applied to this independent dataset to assess generalization (Extended Data Fig. 4e,f) and a two-sided paired t-test was performed comparing the mean Treg score and mean CD4+ score for each tissue.
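Applying factors learned on one cohort to new cells corresponds to NMF.transform in scikit-learn, which keeps the fitted factor loadings fixed. The sketch below, on simulated data with a hypothetical factor-picking helper, projects new cells onto a fitted model and runs SciPy’s paired t-test (ttest_rel, two-sided by default) on per-tissue mean usage.

```python
import numpy as np
from scipy.stats import ttest_rel
from sklearn.decomposition import NMF


def simulate(is_treg, rng, n_genes=60):
    """Toy nonnegative data: Tregs over-express the first 10 genes (hypothetical)."""
    X = rng.gamma(1.0, 0.2, size=(is_treg.size, n_genes))
    X[is_treg, :10] += 5.0
    return X


def pick_treg_factor(usage, is_treg):
    """Hypothetical helper: factor with the largest Treg-minus-other mean usage."""
    gap = usage[is_treg].mean(axis=0) - usage[~is_treg].mean(axis=0)
    return int(np.argmax(gap))


def per_tissue_paired_test(model, X_new, factor, tissue, is_treg):
    """Project new cells onto fixed NMF factors; paired t-test across tissues."""
    usage = model.transform(X_new)[:, factor]
    tissues = np.unique(tissue)
    treg = np.array([usage[(tissue == t) & is_treg].mean() for t in tissues])
    cd4 = np.array([usage[(tissue == t) & ~is_treg].mean() for t in tissues])
    return treg, cd4, ttest_rel(treg, cd4)


rng = np.random.default_rng(0)
train_treg = np.arange(90) < 30
model = NMF(n_components=5, random_state=114, max_iter=500)
factor = pick_treg_factor(model.fit_transform(simulate(train_treg, rng)), train_treg)

tissue = np.repeat([f"tissue{i}" for i in range(16)], 20)
new_treg = np.tile(np.arange(20) < 10, 16)
treg_means, cd4_means, result = per_tissue_paired_test(
    model, simulate(new_treg, rng), factor, tissue, new_treg)
```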
Additional validation was performed on two independent studies downloaded from CELLxGENE: the HLCA31 and a single-cell atlas of the ocular surface32. For the HLCA, the dataset was downsampled to include at most five cells of each cell type (annotation level 5) from each sample not included in the SCimilarity training set (n = 34,439 cells). Genes were filtered to those detected in at least 0.5% of cells (n = 15,133). For the ocular surface atlas, all epithelial cells (n = 12,990) were included and genes were similarly filtered to those detected in at least 0.5% of cells (n = 14,816). Attributions for both datasets were then calculated using SCimilarity and IG. NMF was performed using the same 200 parameter settings as the T cell analysis. Additionally, topic modeling was performed using scETM35 and the ‘amortized LDA’ method from scvi-tools33. scETM was run for 500 epochs with and without batch supervision using the ‘enable_batch_bias’ parameter. Otherwise, both scETM and amortized LDA were run with default parameters across the same range of topic numbers (k = 5, 10, 15, 20 and 25).
The statsmodels80 package was used to fit a linear model (FACTOR_SCORE ~ C(cell_type)) to test whether the transformed NMF or LDA values could be explained by cell type. The ‘ann_level_3’ and ‘celltype_legency’ columns were used for cell type annotations in the HLCA and ocular datasets, respectively. For each parameter set, factors or topics with an adjusted R² value greater than 0.25 were considered cell type associated. The same analysis was then repeated with the cell’s study of origin in place of cell type to determine whether factor usage was predicted by study.
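For a model with a single categorical predictor, the OLS fitted values are simply the group means, so the adjusted R² can be computed directly. The NumPy sketch below is intended to match what statsmodels reports as rsquared_adj for smf.ols("FACTOR_SCORE ~ C(cell_type)", data=df).fit(); the helper names, labels and toy usage matrix are hypothetical.

```python
import numpy as np


def categorical_adj_r2(score, labels):
    """Adjusted R^2 of OLS score ~ C(labels) (intercept + one dummy per extra level)."""
    score = np.asarray(score, dtype=float)
    levels, inv = np.unique(np.asarray(labels), return_inverse=True)
    group_means = np.bincount(inv, weights=score) / np.bincount(inv)
    fitted = group_means[inv]  # OLS with one categorical predictor fits group means
    ss_res = np.sum((score - fitted) ** 2)
    ss_tot = np.sum((score - score.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    n, n_params = score.size, levels.size  # parameters include the intercept
    return 1.0 - (1.0 - r2) * (n - 1) / (n - n_params)


def cell_type_associated(usage, labels, cutoff=0.25):
    """Indices of factors/topics whose usage is explained by the labels."""
    return [j for j in range(usage.shape[1]) if categorical_adj_r2(usage[:, j], labels) > cutoff]


rng = np.random.default_rng(0)
labels = np.repeat(["AT1", "AT2", "club"], 50)
usage = np.column_stack([
    np.repeat([0.0, 1.0, 2.0], 50) + rng.normal(0, 0.1, 150),  # tracks cell type
    rng.normal(0, 1, 150),                                      # unrelated to cell type
])
associated = cell_type_associated(usage, labels)
```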
Spatial transcriptomics analysis
Slide-tags spatial transcriptomics analysis used tonsil data from Russell et al.41. Cell type labels were combined into five broad immune cell categories: B cell (B_germinal_center, B_naive and B_memory), T cell (T_CD4, T_follicular_helper, T_CD8 and T_double_neg), plasma (plasma), NK (NK) and myeloid (mDC, myeloid and pDC). Simulated Visium spots were created by overlaying a grid of 25-µm spots and aggregating counts from cells within the border of each spot, converting 3,199 cells into 625 spots. Attributions were then calculated on log-normalized counts from the single cells and the simulated spots using SCimilarity and IG. Regression plots in Extended Data Fig. 7d were created with the ‘regplot’ function from seaborn78 using the default parameters.
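A minimal sketch of the spot simulation is shown below, assuming square 25-µm grid cells anchored at the minimum coordinate and assigning each cell to the spot that contains its centroid (the exact border rule is not specified in the text). Only occupied spots are returned, and the coordinates and counts are toy inputs.

```python
import numpy as np


def simulate_spots(xy, counts, spot_size=25.0):
    """Sum single-cell counts into a square grid of `spot_size`-micrometer spots.

    xy: cells x 2 coordinates (micrometers); counts: cells x genes count matrix.
    Returns (grid index of each occupied spot, spot of each cell, spot x gene counts).
    """
    origin = xy.min(axis=0)
    ij = np.floor((xy - origin) / spot_size).astype(int)  # grid cell of each centroid
    keys, spot_of_cell = np.unique(ij, axis=0, return_inverse=True)
    spot_of_cell = spot_of_cell.reshape(-1)  # guard against version-dependent shapes
    spot_counts = np.zeros((len(keys), counts.shape[1]), dtype=counts.dtype)
    np.add.at(spot_counts, spot_of_cell, counts)  # accumulate rows per spot
    return keys, spot_of_cell, spot_counts


xy = np.array([[0.0, 0.0], [10.0, 10.0], [30.0, 5.0], [60.0, 60.0]])
counts = np.ones((4, 3), dtype=int)
keys, spot_of_cell, spot_counts = simulate_spots(xy, counts)
```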
Dataset specificity and scanpy scoring
The ‘score_genes’ function from Scanpy was applied to two datasets: Deng et al.28 and Cano-Gamez et al.26. The queried gene set included all CD3 subunit genes (CD3D, CD3E, CD3G and CD247) and both CD8 chain genes (CD8A and CD8B). Density plots were created using the ‘kdeplot’ function in seaborn78 with default parameters. Extended Data Fig. 6b shows that ‘score_genes’ produces dataset-dependent results for cDC2 marker genes87 in cDC2 cells, depending on whether the function is applied to the full dataset22 or to a subset containing only dendritic cells and mast cells.
Implementing established gene set scoring methods
Gene set activation was evaluated by comparing mean attributions against five existing methods: adjusted neighborhood scoring (ANS)37, JASMINE38, Mean Expression, Scanpy4,39 and UCell36. For all methods, data were log-normalized in Scanpy with a target sum of 10,000. Scores were calculated separately for each dataset to accommodate methods using background distributions.
‘Mean expression’ averages the log-normalized expression for genes of interest. Scanpy scoring uses the ‘score_genes’ function with default parameters, which is based on the Tirosh et al.5,39 approach, comparing query gene expression against background genes with similar expression levels. ANS37 also uses control genes but identifies backgrounds through gene expression neighbor graphs rather than expression buckets. UCell36 and JASMINE38 are both rank-based gene set scoring approaches; UCell calculates Mann–Whitney U statistics from expression ranks36, while JASMINE computes likelihood or odds ratios from within-cell gene ranks38. ANS, UCell and JASMINE were implemented using the ANS_signature_scoring package37, with ANS using a control_size of 50 and JASMINE using the ‘likelihood’ model.
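To illustrate why background-based methods can give dataset-dependent scores (as noted for ‘score_genes’ above), the sketch below implements a simplified version of the Tirosh et al. idea next to plain mean expression. It is not Scanpy’s exact implementation: the equal-sized binning by rank, the sampling details and the function names are assumptions. Because the expression bins are computed from whichever cells are in X, subsetting the data can change the control genes and therefore the scores.

```python
import numpy as np


def mean_expression_score(X, gene_idx):
    """'Mean expression': average log-normalized value of the query genes per cell."""
    return X[:, list(gene_idx)].mean(axis=1)


def tirosh_like_score(X, gene_idx, n_bins=25, ctrl_size=50, seed=0):
    """Simplified Tirosh et al.-style score (not Scanpy's exact code).

    Genes are binned by average expression across all cells in X; for each bin
    holding a query gene, up to `ctrl_size` non-query genes are drawn as controls.
    Score = mean(query genes) - mean(control genes), per cell.
    """
    rng = np.random.default_rng(seed)
    gene_idx = list(gene_idx)
    avg = X.mean(axis=0)
    ranks = np.argsort(np.argsort(avg))  # 0 = lowest average expression
    bins = ranks * n_bins // avg.size     # equal-sized expression bins
    query = set(gene_idx)
    controls = []
    for b in np.unique(bins[gene_idx]):
        pool = [g for g in np.flatnonzero(bins == b) if g not in query]
        if pool:
            take = min(ctrl_size, len(pool))
            controls.extend(rng.choice(pool, size=take, replace=False).tolist())
    controls = sorted(set(controls))
    return X[:, gene_idx].mean(axis=1) - X[:, controls].mean(axis=1)


rng = np.random.default_rng(0)
X = rng.random((100, 500))
X[:20, :5] += 2.0  # the first 20 cells express the query genes highly
query = [0, 1, 2, 3, 4]
full = tirosh_like_score(X, query)
subset = tirosh_like_score(X[:40], query)  # same first 40 cells, different background
```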
Gene set scoring across SCimilarity validation dataset
Immune gene sets from CellMarker (version 2.0)88 and CIBERSORT75,76 were applied to PBMCs from van der Wijst et al.40. Seven cell types with matching gene sets from both sources were analyzed (Extended Data Fig. 6a). For each cell, consensus scores were created by averaging marker gene scores from both sources for the same cell type. These consensus scores were then averaged by annotated cell type identity and displayed in Fig. 4a.
For analysis across all 15 SCimilarity datasets, B cells and mast cells were analyzed because these populations appeared in the most unique studies (seven and six, respectively). CIBERSORT75,76 signatures were used to query these immune cell populations across datasets (Fig. 4d). Lung and kidney epithelial cell signatures were collected from CellMarker (version 2.0)88 (Fig. 4e,f), and dendritic cells were assessed using gene sets from both CellMarker (version 2.0) and Table 1 of Collin et al.87 (Extended Data Fig. 6c,d). Extended Data Fig. 6e–g was generated using cell-cycle gene sets for S and G2/M phases from Tirosh et al.39 to show that attributions can identify these cell states across cell types, including annotated fibroblasts60.
Supervised and unsupervised cell type prediction
Cell type classification performance was evaluated using six gene set scoring methods: ANS37, JASMINE38, mean attribution, mean expression, Scanpy4,39 and UCell36. For each cell type from the Adams et al. dataset22 and each method, marker gene set scores (Supplementary Table 2) were used to train univariate logistic regression classifiers using LogisticRegressionCV from scikit-learn86 with balanced class weights and a random state of 114. Performance was assessed across 25 independent train–test splits (80:20) using F1 score to account for class imbalance.
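A sketch of the supervised evaluation with scikit-learn is below. The per-split random_state values, stratified splitting and the simulated score are assumptions not stated in the text; class_weight="balanced", the classifier seed of 114, the 80:20 split and the F1 metric follow the description above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split


def univariate_f1(scores, is_type, n_splits=25, seed=114):
    """F1 of a one-feature logistic classifier over repeated 80:20 splits."""
    X = np.asarray(scores, dtype=float).reshape(-1, 1)
    y = np.asarray(is_type, dtype=int)
    f1s = []
    for split in range(n_splits):
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=0.2, random_state=split, stratify=y)  # stratify: assumption
        clf = LogisticRegressionCV(class_weight="balanced", random_state=seed)
        f1s.append(f1_score(y_te, clf.fit(X_tr, y_tr).predict(X_te)))
    return np.array(f1s)


rng = np.random.default_rng(0)
is_b_cell = np.arange(200) < 40
gene_set_score = 2.0 * is_b_cell + rng.normal(0.0, 0.5, 200)  # toy gene set score
f1s = univariate_f1(gene_set_score, is_b_cell, n_splits=5)
```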
Unsupervised classification used Otsu’s multithresholding method89, implemented with the ‘threshold_multiotsu’ function from scikit-image90. For each cell type, marker gene set scores were calculated for each cell and optimal thresholds were determined. Multiple threshold numbers (n = 1, 2 or 3) were evaluated, and the optimal n was selected by maximizing Cohen’s d effect size between the top group and the second-highest group, favoring a clear separation between populations. Cells in the top group (above the highest chosen threshold) were considered ‘hits’ and used to calculate an F1 score against the true labels (Extended Data Fig. 8).
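The threshold-selection rule can be sketched as follows. To keep the example dependency-light, it replaces scikit-image’s threshold_multiotsu with a small exhaustive multi-level Otsu search over a 32-bin histogram (maximizing between-class variance), so thresholds may differ slightly from the library’s; the Cohen’s d selection between the top and second-highest groups follows the text, and all names and data are hypothetical.

```python
from itertools import combinations

import numpy as np


def multi_otsu(values, n_thresholds, nbins=32):
    """Exhaustive multi-level Otsu: thresholds maximizing between-class variance."""
    hist, edges = np.histogram(values, bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    p = hist / hist.sum()
    mu_total = np.sum(p * centers)
    best_var, best_idx = -1.0, None
    for idx in combinations(range(1, nbins), n_thresholds):
        bounds = (0, *idx, nbins)
        var = 0.0
        for lo, hi in zip(bounds[:-1], bounds[1:]):
            w = p[lo:hi].sum()
            if w > 0:
                mu = np.sum(p[lo:hi] * centers[lo:hi]) / w
                var += w * (mu - mu_total) ** 2
        if var > best_var:
            best_var, best_idx = var, idx
    return edges[list(best_idx)]


def cohens_d(a, b):
    na, nb = len(a), len(b)
    pooled = np.sqrt(((na - 1) * np.var(a, ddof=1) + (nb - 1) * np.var(b, ddof=1)) / (na + nb - 2))
    return (np.mean(a) - np.mean(b)) / pooled


def otsu_hits(values, candidates=(1, 2, 3)):
    """Pick n by Cohen's d (top vs second-highest group); hits = top group."""
    best_d, best_thr = -np.inf, None
    for n in candidates:
        thr = multi_otsu(values, n)
        groups = np.digitize(values, thr)  # group labels 0..n
        top, second = values[groups == n], values[groups == n - 1]
        if len(top) < 2 or len(second) < 2:
            continue
        d = cohens_d(top, second)
        if d > best_d:
            best_d, best_thr = d, thr[-1]
    return values >= best_thr


def f1(pred, truth):
    tp = np.sum(pred & truth)
    return 2 * tp / (pred.sum() + truth.sum())


rng = np.random.default_rng(0)
values = np.concatenate([rng.normal(0, 1, 200), rng.normal(6, 0.5, 50)])
truth = np.arange(250) >= 200
hits = otsu_hits(values)
score = f1(hits, truth)
```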
MS1 full dataset queries
The MS1 phenotype in monocytes was established through NMF analysis by Reyes et al.10. Attributions across the SCimilarity cell database14 were queried using the top 99 genes associated with this factor to identify other cells in which this phenotype is active. The search spanned 22 million cells across 412 datasets, which were then restricted to 2.3 million in vivo cells from 244 studies that were confidently predicted as monocytes or macrophages (SCimilarity prediction_nn_dist < 0.02). For sample-level hit analysis (Fig. 4a), samples were included if they contained at least 25 monocytes and macrophages and came from a disease with at least three unique samples. Hits were called as the top 10% of monocytes or macrophages by MS1 score, and the percentage of hits was calculated for each sample (Fig. 4a). Other thresholds (5%, 15% and 20%) were also tested and yielded similar top diseases (Extended Data Fig. 9b).
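The sample-level hit analysis reduces to a few pandas operations, sketched below with hypothetical column names (‘disease’, ‘sample’, ‘ms1_score’) and toy data. We assume the top-10% cutoff is computed over the retained monocytes and macrophages after the sample and disease filters; computing it before filtering would be an equally plausible reading.

```python
import numpy as np
import pandas as pd


def ms1_hit_rates(cells, top_fraction=0.10, min_cells=25, min_samples=3):
    """Percentage of MS1 hits per (disease, sample)."""
    n_cells = cells.groupby("sample")["ms1_score"].transform("count")
    cells = cells[n_cells >= min_cells]
    n_samples = cells.groupby("disease")["sample"].transform("nunique")
    cells = cells[n_samples >= min_samples]
    cutoff = cells["ms1_score"].quantile(1.0 - top_fraction)  # global top-10% cutoff
    hit = cells["ms1_score"] >= cutoff
    return hit.groupby([cells["disease"], cells["sample"]]).mean().mul(100.0)


rng = np.random.default_rng(0)
rows = [("A", s) for s in ["s1", "s2", "s3", "s4"] for _ in range(30)]
rows += [("B", "s5")] * 30       # disease with a single sample: excluded
rows += [("A", "s_small")] * 10  # fewer than 25 cells: excluded
cells = pd.DataFrame(rows, columns=["disease", "sample"])
cells["ms1_score"] = rng.random(len(cells))
cells.loc[cells["sample"] == "s1", "ms1_score"] += 1.0  # s1 is enriched for MS1
rates = ms1_hit_rates(cells)
```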
MS1 cells in individual datasets
The MS1+ hits from the SCimilarity database were analyzed in the context of their respective datasets. Individual analyses were performed for multiple diseases using metadata from the original study: COVID-19 from Stephenson et al.42, HLH from Liu et al.47, SFTS from Park et al.46 and KD from Wang et al.48.
Consensus NMF was run on scRNA-seq data from the 8,282 monocytes and macrophages from Wang et al.48 using the cNMF Python package27 to compare results against COVID-19 and sepsis samples processed in Reyes et al.10. The expression data from these cells were run through the cNMF package using the top 3,000 variable genes, considering k values from 5 to 20 and repeating each trial 25 times. The optimal k was chosen as 14 on the basis of the balance between high stability and relatively low error (Extended Data Fig. 10a).
The standardized gene loadings from the KD monocytes were compared against the standardized gene loadings from Reyes et al.10 by computing pairwise Pearson correlation coefficients. These values were clustered using the seaborn78 clustermap function with the ‘correlation’ metric (Fig. 5e and Extended Data Fig. 10b). For each individual, a one-sided Mann–Whitney U-test was performed comparing the usage values for cells collected before IVIG against those collected after IVIG. This was repeated for each NMF factor and P values were corrected using the Benjamini–Hochberg procedure.
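The factor comparison and the IVIG test might look like the NumPy/SciPy sketch below. The one-sided direction (usage higher before IVIG), correcting all individual-by-factor P values together and the helper names are assumptions; the Benjamini–Hochberg step is written out explicitly rather than relying on a particular library version.

```python
import numpy as np
from scipy.stats import mannwhitneyu


def loading_correlations(loadings_a, loadings_b):
    """Pearson r between every factor in A and every factor in B (rows = factors)."""
    k = loadings_a.shape[0]
    return np.corrcoef(loadings_a, loadings_b)[:k, k:]


def benjamini_hochberg(pvals):
    """Benjamini-Hochberg adjusted P values, returned in the input order."""
    p = np.asarray(pvals, dtype=float)
    n = p.size
    order = np.argsort(p)
    scaled = p[order] * n / np.arange(1, n + 1)
    adjusted = np.minimum.accumulate(scaled[::-1])[::-1]  # enforce monotonicity
    out = np.empty(n)
    out[order] = np.minimum(adjusted, 1.0)
    return out


def ivig_tests(usage, individual, before_ivig):
    """One-sided MWU per (individual, factor): before > after (assumed direction)."""
    pvals, keys = [], []
    for ind in np.unique(individual):
        m = individual == ind
        for j in range(usage.shape[1]):
            res = mannwhitneyu(usage[m & before_ivig, j], usage[m & ~before_ivig, j],
                               alternative="greater")
            pvals.append(res.pvalue)
            keys.append((ind, j))
    return keys, benjamini_hochberg(pvals)


rng = np.random.default_rng(0)
A = rng.random((3, 50))
r = loading_correlations(A, A)

individual = np.repeat(["p1", "p2"], 60)
before = np.tile(np.arange(60) < 30, 2)
usage = rng.random((120, 4))
usage[before, 0] += 1.0  # toy factor 0 drops after IVIG
keys, q = ivig_tests(usage, individual, before)
```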
Material data collection, inclusion and ethics
All experiments in this study involving human tissue or data were conducted in accordance with the Declaration of Helsinki. For the induction experiment, peripheral blood samples were obtained from participants evaluated at Boston Children’s Hospital; consent was obtained from participants or their parents or guardians under Institutional Review Board protocol X10-01-0308.
MS1 induction experiments
Differentiation of myeloid cells from CD34+ bone marrow HSPCs was performed as previously described10. HSPCs were cultured in SFEM II with 75 nM StemRegenin 1, 3.5 nM UM729, 1× CC110 (StemCell Technologies) and 1× penicillin–streptomycin (Gibco). To initiate differentiation, cells were cultured in the same medium supplemented with 20% pooled healthy human serum (SeraCare) or serum from diseased individuals. Serum samples from individuals with multiple myeloma were obtained from BioIVT.
To assess MS1 induction, cells were stained with the following panel: CD14–FITC (clone M5E2; BioLegend), CD15–APC (clone W6D3; BioLegend), CD11b–AF700 (clone ICRF44; BioLegend), CD34–BV650 (clone 561; BioLegend), HLA-DR–PE/Cy7 (clone L243; BioLegend) and IL-1R2–PE (clone 34141; Thermo). Cells were resuspended in FACS buffer with 5% CountBright beads (Invitrogen) to allow the determination of absolute counts during analysis. Flow cytometry data were acquired on an LSR Fortessa (BD Biosciences) and analyzed using FlowJo version 10.10. For analysis of HLA-DRlow monocytes, the average HLA-DR intensity was calculated for internal controls and then, for each sample, the percentage of monocytes with HLA-DR intensity below that value was calculated. A two-sided linear mixed-effects model, treating the individual as a random effect, was fitted using the statsmodels80 package in Python.
Reporting summary
Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.