pyLemur Walkthrough#

The goal of pyLemur is to simplify the analysis of multi-condition single-cell data. If you have collected a single-cell RNA-seq dataset with more than one condition, LEMUR predicts for each cell and gene how much the expression would change if the cell had been in the other condition.

pyLemur is a Python implementation of the LEMUR model; there is also an R package called lemur, which provides additional functionality: identifying neighborhoods of cells that show consistent differential expression values and a pseudo-bulk test to validate the findings.

pyLemur implements a novel framework to disentangle the effects of known covariates, latent cell states, and their interactions. At the core is a combination of matrix factorization and regression analysis implemented as geodesic regression on Grassmann manifolds. We call this latent embedding multivariate regression (LEMUR). For more details, see our preprint [1].

Schematic of the matrix decomposition at the core of LEMUR
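
To make the geometric idea a little more concrete: a point on a Grassmann manifold is a linear subspace, and two subspaces can be compared through their principal angles. The following toy sketch uses simulated data and NumPy only. The helpers `pca_basis` and `principal_angles` are made up for this illustration and are not part of pyLemur, and the code does not reproduce the LEMUR fitting procedure. It only shows how a condition can rotate the subspace in which the cells vary.

```python
import numpy as np

def pca_basis(Y, k):
    """Orthonormal basis (genes x k) of the top-k principal subspace of Y (cells x genes)."""
    Yc = Y - Y.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by explained variance
    _, _, Vt = np.linalg.svd(Yc, full_matrices=False)
    return Vt[:k].T

def principal_angles(A, B):
    """Principal angles (radians) between the column spaces of orthonormal A and B."""
    s = np.linalg.svd(A.T @ B, compute_uv=False)
    return np.arccos(np.clip(s, -1.0, 1.0))

rng = np.random.default_rng(0)
n_cells, n_genes, k = 500, 20, 2
eye = np.eye(n_genes)

# Simulated control condition: variation lives in the span of gene axes 0 and 1
basis_ctrl = eye[:, :2]
# Simulated stimulated condition: the second direction is rotated by 30 degrees
theta = np.deg2rad(30)
basis_stim = basis_ctrl.copy()
basis_stim[:, 1] = np.cos(theta) * eye[:, 1] + np.sin(theta) * eye[:, 2]

# The same latent positions are expressed through two different subspaces
Z = rng.normal(size=(n_cells, k)) * np.array([3.0, 2.0])
Y_ctrl = Z @ basis_ctrl.T + 0.01 * rng.normal(size=(n_cells, n_genes))
Y_stim = Z @ basis_stim.T + 0.01 * rng.normal(size=(n_cells, n_genes))

angles = principal_angles(pca_basis(Y_ctrl, k), pca_basis(Y_stim, k))
print(np.rad2deg(angles))
```

In this toy setup, one principal angle is close to zero (the shared direction) and the other is close to the 30 degree rotation that was built into the simulation.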

Data#

For demonstration, I will use a dataset of interferon-\(\beta\)-stimulated blood cells from Kang et al. [2].

# Standard imports
import numpy as np
import scanpy as sc
# pertpy is needed to download the Kang data
import pertpy 

# This will download the data to ./data/kang_2018.h5ad
adata = pertpy.data.kang_2018()
# Store counts separately in the layers
adata.layers["counts"] = adata.X.copy()

The data consists of \(24\,673\) cells and \(15\,706\) genes. The cells were measured in two conditions (label="ctrl" and label="stim"). The authors have annotated the cell type for each cell, which will be useful to analyze LEMUR’s results; however, note that the cell type labels are not used (and not needed) to fit the LEMUR model.

print(adata)
print(adata.obs)
AnnData object with n_obs × n_vars = 24673 × 15706
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters'
    var: 'name'
    obsm: 'X_pca', 'X_umap'
    layers: 'counts'
                  nCount_RNA  nFeature_RNA      tsne1      tsne2 label  cluster        cell_type     replicate  nCount_SCT  nFeature_SCT integrated_snn_res.0.4 seurat_clusters
index                                                                                                                                                                          
AAACATACATTTCC-1      3017.0           877 -27.640373  14.966629  ctrl        9  CD14+ Monocytes  patient_1016      1704.0           711                    1                 1
AAACATACCAGAAA-1      2481.0           713 -27.493646  28.924885  ctrl        9  CD14+ Monocytes  patient_1256      1614.0           662                    1                 1
AAACATACCATGCA-1       703.0           337 -10.468194  -5.984389  ctrl        3      CD4 T cells  patient_1488       908.0           337                    6                 6
AAACATACCTCGCT-1      3420.0           850 -24.367997  20.429285  ctrl        9  CD14+ Monocytes  patient_1256      1738.0           653                    1                 1
AAACATACCTGGTA-1      3158.0          1111  27.952170  24.159738  ctrl        4  Dendritic cells  patient_1039      1857.0           928                   12                12
...                      ...           ...        ...        ...   ...      ...              ...           ...         ...           ...                  ...               ...
TTTGCATGCCTGAA-2      1033.0           468  18.268321   1.058202  stim        6      CD4 T cells  patient_1244      1128.0           468                    2                 2
TTTGCATGCCTGTC-2      2116.0           819 -11.563067   2.574095  stim        4          B cells  patient_1256      1669.0           799                    3                 3
TTTGCATGCTAAGC-2      1522.0           523  25.142392   6.603815  stim        6      CD4 T cells   patient_107      1422.0           523                    0                 0
TTTGCATGGGACGA-2      1143.0           503  14.359657  10.965601  stim        6      CD4 T cells  patient_1488      1185.0           503                    0                 0
TTTGCATGTCTTAC-2      1031.0           421  14.572118  -4.713942  stim        5      CD4 T cells  patient_1016      1144.0           419                    2                 2

[24673 rows x 12 columns]

Preprocessing#

LEMUR expects that the input has been variance-stabilized. Here, I will use the log-transformation as a simple, yet effective approach. In addition, I will only work on the \(1\,000\) most variable genes to make the results easier to manage.

# This follows the standard recommendation from scanpy 
sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True)
sc.pp.log1p(adata)
adata.layers["logcounts"] = adata.X.copy()
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger")
adata = adata[:, adata.var.highly_variable]
adata
View of AnnData object with n_obs × n_vars = 24673 × 1000
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters'
    var: 'name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'
    obsm: 'X_pca', 'X_umap'
    layers: 'counts', 'logcounts'
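
The normalization above can be sketched in plain NumPy to show what ends up in the `logcounts` layer. This is a minimal illustration for a small dense matrix. The function name `log_normalize` is hypothetical, and scanpy's functions offer further options that are not reproduced here.

```python
import numpy as np

def log_normalize(counts, target_sum=1e4):
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # avoid dividing by zero for empty cells
    scaled = counts / totals * target_sum
    return np.log1p(scaled)

# Toy count matrix: 3 cells x 3 genes with very different library sizes
toy_counts = np.array([[10, 0, 90],
                       [ 1, 1,  2],
                       [ 0, 5,  5]])
logcounts = log_normalize(toy_counts)

# Undoing the log shows that every cell now has the same library size
print(np.expm1(logcounts).sum(axis=1))
```

Zero counts stay at zero after the transformation, which keeps sparse data sparse.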

If we make a 2D plot of the data using UMAP, we see that the cell types separate by treatment status.

sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
Figure: UMAP of the log-normalized data, colored by label and cell_type.

LEMUR#

First, we import pyLemur; then, we fit the LEMUR model by providing the AnnData object, a specification of the experimental design, and the number of latent dimensions.

import pylemur
model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15)
model.fit()
model.align_with_harmony()
print(model)
Centering the data using linear regression.
Find base point
Fit regression on latent spaces
Find shared embedding coordinates
Alignment iteration 0
Alignment iteration 1
Alignment iteration 2
Alignment iteration 3
Alignment iteration 4
Alignment iteration 5
Converged
LEMUR model with 15 dimensions
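
The formula `~ label` is shorthand for a design matrix with an intercept and one indicator column. The sketch below builds such a matrix by hand for five toy cells. It assumes the usual treatment (dummy) coding with "ctrl" as the reference level, which is an assumption about the defaults and not something taken from the pyLemur source.

```python
import numpy as np

# Condition label of five toy cells
labels = np.array(["ctrl", "stim", "stim", "ctrl", "stim"])

# Column 0: intercept (all ones); column 1: indicator for label == "stim".
# "ctrl" acts as the reference level in this coding.
design = np.column_stack([
    np.ones(len(labels)),
    (labels == "stim").astype(float),
])
print(design)
```

With this coding, the second column switches the treatment effect on for stimulated cells and off for control cells.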

To assess if the model was fit successfully, we plot a UMAP representation of the 15-dimensional embedding calculated by LEMUR. We want to see that the two conditions are well mixed in the embedding space because that means that LEMUR was able to disentangle the treatment effect from the cell type effect and that the residual variation is driven by the cell states.

# Recalculate the UMAP on the embedding calculated by LEMUR
adata.obsm["embedding"] = model.embedding
sc.pp.neighbors(adata, use_rep="embedding")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
Figure: UMAP of the LEMUR embedding, colored by label and cell_type.
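
Judging the mixing by eye can be complemented with a simple number: for each cell, the fraction of its nearest neighbors that come from the other condition. The sketch below is one possible way to compute this on toy data. The function `neighbor_mixing` is a hypothetical helper and not part of pyLemur or scanpy, and the brute-force distance computation is only suitable for small inputs.

```python
import numpy as np

def neighbor_mixing(embedding, labels, k=10):
    """For each cell, the fraction of its k nearest neighbors that carry a different label."""
    embedding = np.asarray(embedding, dtype=float)
    labels = np.asarray(labels)
    # Brute-force squared Euclidean distances; fine for small toy data only
    sq = (embedding ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * embedding @ embedding.T
    np.fill_diagonal(d2, np.inf)  # a cell is not its own neighbor
    nn = np.argsort(d2, axis=1)[:, :k]
    return (labels[nn] != labels[:, None]).mean(axis=1)

rng = np.random.default_rng(1)
labels = np.repeat(["ctrl", "stim"], 200)

mixed = rng.normal(size=(400, 5))        # both conditions share one cloud
separated = mixed.copy()
separated[labels == "stim"] += 10.0      # conditions form two distant clouds

mixed_score = neighbor_mixing(mixed, labels).mean()
separated_score = neighbor_mixing(separated, labels).mean()
print(mixed_score, separated_score)
```

With two equally sized conditions, a well-mixed embedding gives a score near 0.5, whereas a score near 0 indicates that the conditions still separate.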

The LEMUR model is fully parametric, which means that we can predict for each cell what its expression would have been in any condition (i.e., for a cell observed in the control condition, we can predict its expression under treatment) as a function of its low-dimensional embedding.

# The model.cond(**kwargs) call specifies the condition for the prediction
ctrl_pred = model.predict(new_condition=model.cond(label="ctrl"))
stim_pred = model.predict(new_condition=model.cond(label="stim"))
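
Conceptually, such a prediction combines a cell's position in the embedding with condition-specific parameters. The following toy model illustrates the idea with a condition-specific offset and basis. All names and parameter values are made up for this sketch, and it is a simplification rather than a description of pyLemur's internals.

```python
import numpy as np

rng = np.random.default_rng(2)
n_cells, n_genes, k = 6, 4, 2

# Hypothetical fitted parameters of a toy linear latent model (not pyLemur internals)
offset = {"ctrl": np.zeros(n_genes), "stim": np.array([0.0, 1.0, 0.0, 2.0])}
basis = {c: np.linalg.qr(rng.normal(size=(n_genes, k)))[0] for c in ("ctrl", "stim")}

def predict(embedding, condition):
    """Expression predicted for every cell (row of `embedding`) under `condition`."""
    return offset[condition] + embedding @ basis[condition].T

Z = rng.normal(size=(n_cells, k))    # one latent position per cell
toy_ctrl_pred = predict(Z, "ctrl")
toy_stim_pred = predict(Z, "stim")
toy_diff = toy_stim_pred - toy_ctrl_pred  # per-cell, per-gene predicted change
print(toy_diff.shape)
```

Because the same latent position is plugged into both conditions, the difference is defined for every cell, regardless of the condition in which it was actually observed.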

We can now check the predicted differential expression against the underlying observed expression patterns for individual genes. Here, I chose TSC22D3 as an example. The blue cells in the first plot are in neighborhoods with higher expression in the control condition than in the stimulated condition. The two other plots show the underlying gene expression for the control and stimulated cells and confirm LEMUR’s inference.

import matplotlib.pyplot as plt
adata.layers["diff"] = stim_pred - ctrl_pred
# Also try CXCL10, IL8, and FBXO40
sel_gene = "TSC22D3"

fsize = plt.rcParams['figure.figsize']
fig = plt.figure(figsize=(fsize[0] * 3, fsize[1])) 
axs = [fig.add_subplot(1, 3, i+1) for i in range(3)]
for ax in axs:
    ax.set_aspect('equal')
sc.pl.umap(adata, layer="diff", color=[sel_gene], cmap = plt.get_cmap("seismic"), vcenter=0,
    vmin=-4, vmax=4, title="Pred diff (stim - ctrl)", ax=axs[0], show=False)
sc.pl.umap(adata[adata.obs["label"]=="ctrl"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
    title="Ctrl expr", ax=axs[1], show=False)
sc.pl.umap(adata[adata.obs["label"]=="stim"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
    title="Stim expr", ax=axs[2])
Figure: UMAPs for TSC22D3 showing the predicted difference (stim - ctrl), the control expression, and the stimulated expression.

To assess the overall accuracy of LEMUR’s predictions, I will compare the average observed and predicted expression per cell type between conditions. The next plot simply shows the observed expression values. Genes on the diagonal don’t change expression much between conditions within a cell type, whereas all off-diagonal genes are differentially expressed:

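def rowMeans_per_group(X, group):
    # Average the rows of X (cells) within each group (here: per cell type)
    group = np.asarray(group)
    uniq = np.unique(group)
    res = np.zeros((len(uniq), X.shape[1]))
    for i, e in enumerate(uniq):
        mask = group == e
        res[i, :] = X[mask, :].sum(axis=0) / mask.sum()
    return res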

adata_ctrl = adata[adata.obs["label"] == "ctrl",:]
adata_stim = adata[adata.obs["label"] == "stim",:]
ctrl_expr_per_cell_type = rowMeans_per_group(adata_ctrl.layers["logcounts"], adata_ctrl.obs["cell_type"])
stim_expr_per_cell_type = rowMeans_per_group(adata_stim.layers["logcounts"], adata_stim.obs["cell_type"])
obs_diff = stim_expr_per_cell_type - ctrl_expr_per_cell_type
plt.scatter(ctrl_expr_per_cell_type, stim_expr_per_cell_type, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title("IFN-b stim. increases gene expression for many genes")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
Figure: Observed mean expression per cell type, control (x-axis) versus stimulated (y-axis), colored by the observed difference.
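
The per-group averaging used above can also be written without an explicit loop by multiplying with a cells-by-groups indicator matrix. This is a sketch for dense arrays only. The name `group_means` is hypothetical, and the toy data are chosen so that the result is easy to check by hand.

```python
import numpy as np

def group_means(X, group):
    """Mean of the rows of X within each group; result rows follow np.unique(group)."""
    X = np.asarray(X, dtype=float)
    uniq, codes = np.unique(np.asarray(group), return_inverse=True)
    codes = codes.ravel()
    onehot = np.zeros((len(codes), len(uniq)))
    onehot[np.arange(len(codes)), codes] = 1.0   # cells x groups indicator
    sums = onehot.T @ X                          # groups x genes
    return sums / onehot.sum(axis=0)[:, None]

# Toy data: 4 cells x 2 genes, two cell types
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0],
              [7.0, 8.0]])
cell_type = np.array(["B", "T", "B", "T"])
means = group_means(X, cell_type)
print(means)
```

The first row holds the means of the "B" cells and the second row those of the "T" cells, because `np.unique` returns the group names in sorted order.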

To demonstrate that LEMUR learned the underlying expression relations, I predict what the expression of cells from the control condition would have been had they been stimulated and compare the results against the observed expression in the stimulated condition. The closer the points are to the diagonal, the better the predictions.

stim_pred_per_cell_type = rowMeans_per_group(stim_pred[adata.obs["label"]=="ctrl"], adata_ctrl.obs["cell_type"])

plt.scatter(stim_expr_per_cell_type, stim_pred_per_cell_type, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title( "LEMUR's expression predictions are accurate")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
Figure: Observed (x-axis) versus predicted (y-axis) mean expression under stimulation per cell type.

Lastly, I directly compare the average predicted differential expression against the average observed differential expression per cell type. Again, the closer the points are to the diagonal, the better the predictions.

pred_diff = rowMeans_per_group(adata.layers["diff"], adata.obs["cell_type"])

plt.scatter(obs_diff, pred_diff, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title( "LEMUR's DE predictions are accurate")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
Figure: Observed (x-axis) versus predicted (y-axis) mean differential expression per cell type.
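
The visual impression from such a scatter plot can be summarized with two numbers, the Pearson correlation and the root-mean-square error. The sketch below uses simulated stand-ins for the observed and predicted differences. The helper `agreement` and the toy arrays are assumptions for illustration, so the printed values say nothing about the Kang data.

```python
import numpy as np

def agreement(observed, predicted):
    """Pearson correlation and root-mean-square error between two equally shaped arrays."""
    obs = np.asarray(observed, dtype=float).ravel()
    pred = np.asarray(predicted, dtype=float).ravel()
    r = np.corrcoef(obs, pred)[0, 1]
    rmse = np.sqrt(np.mean((obs - pred) ** 2))
    return r, rmse

rng = np.random.default_rng(3)
toy_obs_diff = rng.normal(size=(8, 50))                         # cell types x genes
toy_pred_diff = toy_obs_diff + 0.1 * rng.normal(size=(8, 50))   # prediction with small error

r, rmse = agreement(toy_obs_diff, toy_pred_diff)
print(f"Pearson r = {r:.3f}, RMSE = {rmse:.3f}")
```

In the walkthrough, the same function could be applied to `obs_diff` and `pred_diff`.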

Another advantage of LEMUR’s parametricity is that you could train the model on a subset of the data and then apply it to the full data.

I will demonstrate this by training the same LEMUR model on 5% of the original data, transforming the full data with it, and finally comparing the first three dimensions of the resulting embedding against the embedding from the model trained on the full data.

# Sample 5% of the cells without replacement so that no cell is drawn twice
adata_subset = adata[np.random.choice(adata.shape[0], size=round(adata.shape[0] * 0.05), replace=False), :]
model_small = pylemur.tl.LEMUR(adata_subset, design = "~ label", n_embedding=15)
model_small.fit().align_with_harmony()
emb_proj = model_small.transform(adata)
plt.scatter(emb_proj[:,0:3], model.embedding[:,0:3], s = 0.1)
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
plt.axline((0, 0), (-1, 1), linewidth=1, color='black')
Centering the data using linear regression.
Find base point
Fit regression on latent spaces
Find shared embedding coordinates
Alignment iteration 0
Alignment iteration 1
Alignment iteration 2
Alignment iteration 3
Converged
Figure: First three embedding dimensions from the model trained on 5% of the data (x-axis) versus the model trained on the full data (y-axis).
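
The sign of a latent dimension is arbitrary in factorization models, which is presumably why both diagonals are drawn in the comparison above. Before comparing two embeddings dimension by dimension, the signs can be matched. The sketch below does this on simulated embeddings. The helper `align_signs` is hypothetical and not part of pyLemur, and it handles sign flips only, not rotations between dimensions.

```python
import numpy as np

def align_signs(reference, other):
    """Flip the sign of each column of `other` so it correlates positively with `reference`."""
    reference = np.asarray(reference, dtype=float)
    other = np.asarray(other, dtype=float)
    # The sign of the centered inner product equals the sign of the per-column correlation
    ref_c = reference - reference.mean(axis=0)
    oth_c = other - other.mean(axis=0)
    signs = np.sign((ref_c * oth_c).sum(axis=0))
    signs[signs == 0] = 1.0
    return other * signs, signs

rng = np.random.default_rng(4)
emb_full = rng.normal(size=(300, 3))
# Toy "small model" embedding: same coordinates, second dimension flipped, plus a little noise
emb_small = emb_full * np.array([1.0, -1.0, 1.0]) + 0.05 * rng.normal(size=(300, 3))

emb_aligned, signs = align_signs(emb_full, emb_small)
print(signs)
```

After the alignment, all points of a well-matched dimension fall along the main diagonal.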

We see that the small model still captures most of the relevant variation.

adata.obsm["embedding_from_small_fit"] = emb_proj
sc.pp.neighbors(adata, use_rep="embedding_from_small_fit")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
Figure: UMAP of the embedding projected from the small model, colored by label and cell_type.

Session Info#

import session_info
session_info.show()
-----
anndata             0.10.7
matplotlib          3.8.4
numpy               1.26.4
pandas              2.2.2
pertpy              0.7.0
pylemur             0.2.1
scanpy              1.10.1
session_info        1.0.0
-----
Modules imported as dependencies
PIL                 10.3.0
absl                NA
adjustText          1.1.1
arrow               1.3.0
arviz               0.18.0
asttokens           NA
attr                23.2.0
blitzgsea           NA
certifi             2024.02.02
chardet             5.2.0
charset_normalizer  3.3.2
chex                0.1.86
comm                0.2.2
contextlib2         NA
custom_inherit      2.4.1
cycler              0.12.1
cython_runtime      NA
dateutil            2.9.0.post0
debugpy             1.8.1
decorator           5.1.1
decoupler           1.6.0
docrep              0.3.2
etils               1.7.0
exceptiongroup      1.2.1
executing           2.0.1
flax                0.8.2
formulaic           1.0.1
fsspec              2024.3.1
graphlib            NA
h5py                3.11.0
harmonypy           NA
idna                3.7
igraph              0.11.4
importlib_resources NA
interface_meta      1.3.0
ipykernel           6.29.4
jax                 0.4.26
jaxlib              0.4.26
jaxopt              NA
jedi                0.19.1
joblib              1.4.0
kiwisolver          1.4.5
legacy_api_wrap     NA
leidenalg           0.10.2
lightning           2.1.4
lightning_fabric    2.2.2
lightning_utilities 0.11.2
llvmlite            0.42.0
matplotlib_inline   0.1.7
ml_collections      NA
ml_dtypes           0.4.0
mpl_toolkits        NA
mpmath              1.3.0
msgpack             1.0.8
mudata              0.2.3
multipledispatch    0.6.0
natsort             8.4.0
numba               0.59.1
numpyro             0.14.0
opt_einsum          v3.3.0
optax               0.2.2
ott                 0.4.6
packaging           24.0
parso               0.8.4
patsy               0.5.6
pkg_resources       NA
platformdirs        4.2.0
ply                 3.11
png                 0.20220715.0
prompt_toolkit      3.0.43
psutil              5.9.8
pubchempy           1.0.4
pure_eval           0.2.2
pyarrow             16.0.0
pydev_ipython       NA
pydevconsole        NA
pydevd              2.9.5
pydevd_file_utils   NA
pydevd_plugins      NA
pydevd_tracing      NA
pygments            2.17.2
pynndescent         0.5.12
pyomo               6.7.1
pyparsing           3.1.2
pyro                1.9.0
pytorch_lightning   2.2.2
pytz                2024.1
reportlab           4.2.0
requests            2.31.0
rich                NA
scipy               1.13.0
scvi                1.1.2
seaborn             0.13.2
six                 1.16.0
sklearn             1.4.2
skmisc              0.3.1
sparsecca           0.3.1
sphinxcontrib       NA
stack_data          0.6.3
statsmodels         0.14.2
texttable           1.7.0
threadpoolctl       3.4.0
toolz               0.12.1
torch               2.2.2+cu121
torchgen            NA
torchmetrics        1.3.2
tornado             6.4
toyplot             1.0.3
toytree             2.0.5
tqdm                4.66.2
traitlets           5.14.3
tree                0.1.8
typing_extensions   NA
umap                0.5.6
urllib3             2.2.1
wcwidth             0.2.13
wrapt               1.16.0
xarray              2024.3.0
xarray_einstats     0.7.0
yaml                6.0.1
zmq                 26.0.2
zoneinfo            NA
-----
IPython             8.23.0
jupyter_client      8.6.1
jupyter_core        5.7.2
-----
Python 3.10.13 (main, Feb  1 2024, 17:18:41) [GCC 9.4.0]
Linux-5.19.0-1028-aws-x86_64-with-glibc2.31
-----
Session information updated at 2024-04-23 11:52

References#

[1]

Constantin Ahlmann-Eltze and Wolfgang Huber. Analysis of multi-condition single-cell data with latent embedding multivariate regression. bioRxiv, 02 2024. URL: http://dx.doi.org/10.1101/2023.03.06.531268.

[2]

Hyun Min Kang, Meena Subramaniam, Sasha Targ, Michelle Nguyen, Lenka Maliskova, Elizabeth McCarthy, Eunice Wan, Simon Wong, Lauren Byrnes, Cristina M Lanata, Rachel E Gate, Sara Mostafavi, Alexander Marson, Noah Zaitlen, Lindsey A Criswell, and Chun Jimmie Ye. Multiplexed droplet single-cell RNA-sequencing using natural genetic variation. Nature Biotechnology, 36(1):89–94, 01 2018. URL: http://www.nature.com/articles/nbt.4042, doi:10.1038/nbt.4042.