---
file_format: mystnb
mystnb:
    execution_timeout: 600
kernelspec:
  name: python3
---

# pyLemur Walkthrough


The goal of `pyLemur` is to simplify the analysis of multi-condition single-cell data. If you have collected a single-cell RNA-seq dataset with more than one condition, LEMUR predicts for each cell and gene how much the expression would change if the cell had been in the other condition.

`pyLemur` is a Python implementation of the LEMUR model; there is also an `R` package called [lemur](https://bioconductor.org/packages/lemur/), which provides additional functionality: identifying neighborhoods of cells that show consistent differential expression values and a pseudo-bulk test to validate the findings.

`pyLemur` implements a novel framework to disentangle the effects of known covariates, latent cell states, and their interactions. At the core is a combination of matrix factorization and regression analysis implemented as geodesic regression on Grassmann manifolds. We call this latent embedding multivariate regression (LEMUR). For more details, see our [preprint](https://www.biorxiv.org/content/10.1101/2023.03.06.531268) {cite:p}`Ahlmann-Eltze2024`.

<img src="../_static/images/equation_schematic.png" alt="Schematic of the matrix decomposition at the core of LEMUR" />


## Data

For demonstration, I will use a dataset of interferon-$\beta$ stimulated blood cells from {cite:t}`kang2018`.

```{code-cell} ipython3
---
output_stderr: remove
---
# Standard imports
import numpy as np
import scanpy as sc
# pertpy is needed to download the Kang data
import pertpy

# This will download the data to ./data/kang_2018.h5ad
adata = pertpy.data.kang_2018()
# Store counts separately in the layers
adata.layers["counts"] = adata.X.copy()
```

The data consists of $24\,673$ cells and $15\,706$ genes. The cells were measured in two conditions (`label="ctrl"` and `label="stim"`). The authors have annotated the cell type for each cell, which will be useful to analyze LEMUR's results; however, note that the cell type labels are not used (and not needed) to fit the LEMUR model.

```{code-cell} ipython3
:tags: ["remove-cell"]
import pandas as pd
pd.options.display.width = 200
pd.options.display.max_colwidth = 20
```

```{code-cell} ipython3
print(adata)
print(adata.obs)
```

## Preprocessing

LEMUR expects that the input has been variance-stabilized. Here, I use a log-transformation as a simple yet effective approach.
In addition, I restrict the analysis to the $1\,000$ most variable genes to make the results easier to manage.

```{code-cell} ipython3
# This follows the standard recommendation from scanpy
sc.pp.normalize_total(adata, target_sum = 1e4, inplace=True)
sc.pp.log1p(adata)
adata.layers["logcounts"] = adata.X.copy()
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger")
# Copy the subset so that later steps modify a real AnnData object instead of a view
adata = adata[:, adata.var.highly_variable].copy()
adata
```
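To make the two preprocessing steps concrete, here is a minimal NumPy sketch of what total-count normalization followed by `log1p` does to a small count matrix. The toy counts and all variable names are made up for illustration; the real `scanpy` functions additionally handle sparse matrices and metadata.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 4 genes (columns); values are made up.
toy_counts = np.array(
    [
        [10, 0, 5, 5],
        [100, 50, 25, 25],
        [1, 1, 1, 1],
    ],
    dtype=float,
)

target_sum = 1e4
# Scale every cell so that its counts sum to target_sum.
size_factors = toy_counts.sum(axis=1, keepdims=True) / target_sum
toy_normalized = toy_counts / size_factors
# log(1 + x) dampens the dependence of the variance on the mean.
toy_logcounts = np.log1p(toy_normalized)

print(toy_normalized.sum(axis=1))
print(toy_logcounts.round(2))
```

After normalization every cell has the same total, so the remaining differences between cells reflect relative rather than absolute expression.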

If we make a 2D plot of the data using UMAP, we see that the cell types separate by treatment status.
```{code-cell} ipython3
sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
```


## LEMUR

First, we import `pyLemur`; then, we fit the LEMUR model by providing the `AnnData` object, a specification of the experimental design, and the number of latent dimensions.

```{code-cell} ipython3
import pylemur
model = pylemur.tl.LEMUR(adata, design = "~ label", n_embedding=15)
model.fit()
model.align_with_harmony()
print(model)
```

To assess whether the model was fit successfully, we plot a UMAP representation of the 15-dimensional embedding calculated by LEMUR. We want the two conditions to be well mixed in the embedding space. Good mixing means that LEMUR was able to disentangle the treatment effect from the cell type effect and that the residual variation is driven by the cell states.
```{code-cell} ipython3
# Recalculate the UMAP on the embedding calculated by LEMUR
adata.obsm["embedding"] = model.embedding
sc.pp.neighbors(adata, use_rep="embedding")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
```
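Judging mixing by eye can be complemented with a number. The sketch below computes, for each cell, the fraction of its nearest neighbors that come from the same condition. The function `same_condition_fraction` is a hypothetical helper written for this illustration, not part of `pyLemur` or `scanpy`, and it runs on simulated data rather than on the Kang embedding.

```python
import numpy as np

def same_condition_fraction(embedding, labels, k=10):
    """Mean fraction of each cell's k nearest neighbors that share its label."""
    embedding = np.asarray(embedding, dtype=float)
    labels = np.asarray(labels)
    # Brute-force pairwise squared Euclidean distances (fine for small inputs).
    sq_norms = (embedding ** 2).sum(axis=1)
    dists = sq_norms[:, None] + sq_norms[None, :] - 2 * embedding @ embedding.T
    np.fill_diagonal(dists, np.inf)  # a cell is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]
    return (labels[neighbors] == labels[:, None]).mean()

rng = np.random.default_rng(0)
toy_labels = np.repeat(["ctrl", "stim"], 200)
toy_mixed = rng.normal(size=(400, 5))
toy_separated = toy_mixed.copy()
toy_separated[toy_labels == "stim"] += 10  # shift one condition far away

frac_mixed = same_condition_fraction(toy_mixed, toy_labels)
frac_separated = same_condition_fraction(toy_separated, toy_labels)
print(frac_mixed, frac_separated)
```

With two equally sized conditions, a value near 0.5 indicates good mixing and a value near 1 indicates that the conditions form separate clusters. The brute-force distance matrix is quadratic in the number of cells, so for a full dataset one would subsample or use a proper nearest-neighbor index.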

The LEMUR model is fully parametric, which means that we can predict for each cell what its expression would have been in any condition (i.e., for a cell observed in the control condition, we can predict its expression under treatment) as a function of its low-dimensional embedding.

```{code-cell} ipython3
# The model.cond(**kwargs) call specifies the condition for the prediction
ctrl_pred = model.predict(new_condition=model.cond(label="ctrl"))
stim_pred = model.predict(new_condition=model.cond(label="stim"))
```
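The idea of predicting a cell's expression in a condition it was not observed in can be illustrated with a deliberately simplified toy model. This is not how `pyLemur` parameterizes LEMUR internally; `toy_params`, `toy_embedding`, and `toy_predict` are hypothetical names, and the parameters are random numbers. The sketch only shows that, once each condition has its own parameters and each cell has a fixed embedding, a prediction for any condition is a plain function evaluation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_genes, n_embedding = 6, 4, 2

# Hypothetical toy parameters: every condition has its own gene-wise offset
# and its own linear map from embedding space to gene space.
toy_params = {
    cond: {
        "offset": rng.normal(size=n_genes),
        "basis": rng.normal(size=(n_embedding, n_genes)),
    }
    for cond in ("ctrl", "stim")
}
# One fixed embedding per cell, shared across conditions.
toy_embedding = rng.normal(size=(n_cells, n_embedding))

def toy_predict(embedding, condition):
    """Predict the expression of every cell as if it were in `condition`."""
    p = toy_params[condition]
    return p["offset"] + embedding @ p["basis"]

toy_ctrl = toy_predict(toy_embedding, "ctrl")
toy_stim = toy_predict(toy_embedding, "stim")
toy_diff = toy_stim - toy_ctrl  # cells x genes matrix of predicted changes
print(toy_diff.shape)
```

Because the same embedding is plugged into both sets of parameters, the difference is defined for every cell and every gene, regardless of the condition in which the cell was actually measured.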

We can now check the predicted differential expression against the underlying observed expression patterns for individual genes. Here, I chose _TSC22D3_ as an example. The blue cells in the first plot are in neighborhoods with higher expression in the control condition than in the stimulated condition. The two other plots show the underlying gene expression for the control and stimulated cells and confirm LEMUR's inference.
```{code-cell} ipython3
import matplotlib.pyplot as plt
adata.layers["diff"] = stim_pred - ctrl_pred
# Also try CXCL10, IL8, and FBXO40
sel_gene = "TSC22D3"

fsize = plt.rcParams['figure.figsize']
fig = plt.figure(figsize=(fsize[0] * 3, fsize[1]))
axs = [fig.add_subplot(1, 3, i+1) for i in range(3)]
for ax in axs:
    ax.set_aspect('equal')
sc.pl.umap(adata, layer="diff", color=[sel_gene], cmap = plt.get_cmap("seismic"), vcenter=0,
    vmin=-4, vmax=4, title="Pred diff (stim - ctrl)", ax=axs[0], show=False)
sc.pl.umap(adata[adata.obs["label"]=="ctrl"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
    title="Ctrl expr", ax=axs[1], show=False)
sc.pl.umap(adata[adata.obs["label"]=="stim"], layer="logcounts", color=[sel_gene], vmin = 0, vmax =4,
    title="Stim expr", ax=axs[2])
```

To assess the overall accuracy of LEMUR's predictions, I will compare the average observed and predicted expression per cell type between conditions. The next plot simply shows the observed expression values. Genes on the diagonal don't change expression much between conditions within a cell type, whereas all off-diagonal genes are differentially expressed:
```{code-cell} ipython3
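def rowMeans_per_group(X, group):
    # Average the rows of X within each group.
    # The rows of the result follow the sorted order of np.unique(group).
    group = np.asarray(group)
    uniq = np.unique(group)
    res = np.zeros((len(uniq), X.shape[1]))
    for i, e in enumerate(uniq):
        mask = group == e
        res[i,:] = X[mask,:].sum(axis=0) / mask.sum()
    return res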

adata_ctrl = adata[adata.obs["label"] == "ctrl",:]
adata_stim = adata[adata.obs["label"] == "stim",:]
ctrl_expr_per_cell_type = rowMeans_per_group(adata_ctrl.layers["logcounts"], adata_ctrl.obs["cell_type"])
stim_expr_per_cell_type = rowMeans_per_group(adata_stim.layers["logcounts"], adata_stim.obs["cell_type"])
obs_diff = stim_expr_per_cell_type - ctrl_expr_per_cell_type
plt.scatter(ctrl_expr_per_cell_type, stim_expr_per_cell_type, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title("IFN-β stim. increases expression of many genes")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
```
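The per-group averaging used above is easy to check on a tiny example. The sketch below uses a standalone variant, `group_means`, which is a hypothetical name chosen so that it does not clash with the helper in this notebook; it also returns the group labels so that the row order is explicit.

```python
import numpy as np

def group_means(X, group):
    """Return the sorted unique labels and the per-group column means of X."""
    group = np.asarray(group)
    uniq = np.unique(group)
    means = np.vstack([X[group == g].mean(axis=0) for g in uniq])
    return uniq, means

toy_X = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
toy_cell_type = ["B", "B", "T"]
toy_groups, toy_means = group_means(toy_X, toy_cell_type)
print(toy_groups)
print(toy_means)
```

Since `np.unique` returns the labels in sorted order, the row order of two such matrices matches only if both inputs contain the same set of cell types. This matters when subtracting the control matrix from the stimulated one.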

To demonstrate that LEMUR learned the underlying expression relations, I predict what the expression of cells from the control condition would have been had they been stimulated and compare the results against the observed expression in the stimulated condition. The closer the points are to the diagonal, the better the predictions.
```{code-cell} ipython3
stim_pred_per_cell_type = rowMeans_per_group(stim_pred[adata.obs["label"]=="ctrl"], adata_ctrl.obs["cell_type"])

plt.scatter(stim_expr_per_cell_type, stim_pred_per_cell_type, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title( "LEMUR's expression predictions are accurate")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
```

Lastly, I directly compare the average predicted differential expression against the average observed differential expression per cell type. Again, the closer the points are to the diagonal, the better the predictions.

```{code-cell} ipython3
pred_diff = rowMeans_per_group(adata.layers["diff"], adata.obs["cell_type"])

plt.scatter(obs_diff, pred_diff, c = obs_diff,
    cmap = plt.get_cmap("seismic"), vmin=-5, vmax=5, marker="o",edgecolors= "black")
plt.colorbar()
plt.title( "LEMUR's DE predictions are accurate")
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
```
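The scatter plots can be complemented with summary numbers. The following sketch computes the Pearson correlation and the root-mean-square error between two matrices; `agreement` is a hypothetical helper, and the toy matrices are made up rather than taken from the analysis above.

```python
import numpy as np

def agreement(observed, predicted):
    """Pearson correlation and RMSE between two equally shaped arrays."""
    observed = np.asarray(observed, dtype=float).ravel()
    predicted = np.asarray(predicted, dtype=float).ravel()
    pearson = np.corrcoef(observed, predicted)[0, 1]
    rmse = np.sqrt(np.mean((observed - predicted) ** 2))
    return pearson, rmse

toy_obs = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
toy_pred = toy_obs + 0.5  # predictions with a constant bias
toy_r, toy_rmse = agreement(toy_obs, toy_pred)
print(toy_r, toy_rmse)
```

Reporting both values is deliberate: the correlation is unaffected by a constant bias, whereas the RMSE captures it. In this notebook the same function could be applied to the observed and predicted per-cell-type differences.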

Another advantage of LEMUR's parametricity is that you can train the model on a subset of the data and then apply it to the full data.

I will demonstrate this by training the same LEMUR model on 5% of the original data, then using `transform` to project the full data, and finally comparing the first three dimensions of the resulting embedding against the embedding from the model trained on the full data.

```{code-cell} ipython3
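# Sample 5% of the cells without replacement so that no cell is drawn twice
subset_idx = np.random.choice(adata.shape[0], size=round(adata.shape[0] * 0.05), replace=False)
adata_subset = adata[subset_idx, :]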
model_small = pylemur.tl.LEMUR(adata_subset, design = "~ label", n_embedding=15)
model_small.fit().align_with_harmony()
emb_proj = model_small.transform(adata)
plt.scatter(emb_proj[:,0:3], model.embedding[:,0:3], s = 0.1)
plt.axline((0, 0), (1, 1), linewidth=1, color='black')
plt.axline((0, 0), (-1, 1), linewidth=1, color='black')
```
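The fit-on-a-subset, project-everything idea can be illustrated without LEMUR. The sketch below uses plain PCA (via an SVD) on simulated data as a stand-in; all names and the data are made up. It also shows one reason to compare against both diagonals: in a factorization like PCA the sign of each latent axis is arbitrary, so two fits can agree up to a sign flip. That the same applies to the LEMUR embedding is an assumption here.

```python
import numpy as np

rng = np.random.default_rng(42)
n_cells, n_genes = 2000, 20
# Toy data with one dominant direction of variation plus noise.
toy_scores = rng.normal(scale=5.0, size=(n_cells, 1))
toy_loading = rng.normal(size=(1, n_genes))
toy_Y = toy_scores @ toy_loading + rng.normal(size=(n_cells, n_genes))

def first_pc(Y_fit):
    """Return the column means and the first principal axis of Y_fit."""
    center = Y_fit.mean(axis=0)
    _, _, vt = np.linalg.svd(Y_fit - center, full_matrices=False)
    return center, vt[0]

# Draw 5% of the rows without replacement so that no row appears twice.
toy_idx = rng.choice(n_cells, size=round(n_cells * 0.05), replace=False)

center_full, axis_full = first_pc(toy_Y)
center_small, axis_small = first_pc(toy_Y[toy_idx])

# Project all cells with the parameters of each fit.
toy_emb_full = (toy_Y - center_full) @ axis_full
toy_emb_small = (toy_Y - center_small) @ axis_small

toy_corr = np.corrcoef(toy_emb_full, toy_emb_small)[0, 1]
print(abs(toy_corr))
```

The absolute correlation is used because a value close to -1 is as good as a value close to +1 when the sign of an axis carries no meaning.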

A UMAP of the projected embedding shows that the small model still captures most of the relevant variation.
```{code-cell} ipython3
adata.obsm["embedding_from_small_fit"] = emb_proj
sc.pp.neighbors(adata, use_rep="embedding_from_small_fit")
sc.tl.umap(adata)
sc.pl.umap(adata, color=["label", "cell_type"])
```

### Session Info

```{code-cell} ipython3
import session_info
session_info.show()
```


### References

```{bibliography}
:style: plain
:filter: docname in docnames
```
