A wrapper to run trVAE on a multi-layered Seurat V5 object.
Requires a conda environment with scArches and the necessary dependencies.
Recommendations: use raw counts (except for recon.loss = "mse") and all features (features = Features(object), layers = "counts", scale.layer = NULL).
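One way to see why the recommendation depends on recon.loss is to compare what each loss scores. The base-R sketch below evaluates a toy reconstruction `mu` of observed values `x` with mean squared error and with a negative binomial negative log-likelihood. It is only an illustration of the two loss families, not scArches' torch implementation, and the dispersion `theta` is an arbitrary value assumed for the example. The NB term is a likelihood over counts, which fits raw counts (recon.loss = "nb" or "zinb"), whereas MSE simply compares real values and pairs naturally with normalized data.

```r
# Toy observed counts for one cell and a hypothetical decoder output (means)
x  <- c(0, 2, 5)
mu <- c(1, 2, 3)

# recon.loss = "mse": mean squared error between observed and reconstructed values
mse_loss <- function(x, mu) mean((x - mu)^2)

# recon.loss = "nb": negative log-likelihood of a negative binomial with
# mean mu and dispersion theta (assumption: theta = 10, chosen arbitrarily)
nb_nll <- function(x, mu, theta = 10) {
  -sum(dnbinom(x, size = theta, mu = mu, log = TRUE))
}

mse <- mse_loss(x, mu)  # (1 + 0 + 4) / 3
nll <- nb_nll(x, mu)
print(c(mse = mse, nb = nll))
```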
Usage
trVAEIntegration(
object,
orig = NULL,
groups = NULL,
groups.name = NULL,
surgery.name = NULL,
surgery.sort = TRUE,
features = NULL,
layers = ifelse(recon.loss == "mse", "data", "counts"),
scale.layer = "scale.data",
conda_env = NULL,
new.reduction = "integrated.trVAE",
reduction.key = "trVAElatent_",
torch.intraop.threads = 4L,
torch.interop.threads = NULL,
model.save.dir = NULL,
ndims.out = 10L,
recon.loss = c("nb", "zinb", "mse"),
hidden_layer_sizes = c(256L, 64L),
dr_rate = 0.05,
use_mmd = TRUE,
mmd_on = c("z", "y"),
mmd_boundary = NULL,
beta = 1,
use_bn = FALSE,
use_ln = TRUE,
n_epochs = 400L,
lr = 0.001,
eps = 0.01,
hide.py.warn = T,
seed.use = 42L,
verbose = TRUE,
...
)
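In the signature above, the default of layers is an expression of recon.loss. Because R evaluates default arguments lazily, that expression sees whatever value recon.loss holds when layers is first used. The hypothetical function `resolve_defaults()` below is a minimal sketch of that pattern, assuming recon.loss is reduced to one value with `match.arg()` before layers is accessed; it does not show the wrapper's actual evaluation order.

```r
# Hypothetical stand-in reproducing only the two interdependent defaults
resolve_defaults <- function(recon.loss = c("nb", "zinb", "mse"),
                             layers = ifelse(recon.loss == "mse", "data", "counts")) {
  # match.arg() keeps a single value; with no input it returns the first, "nb"
  recon.loss <- match.arg(recon.loss)
  # 'layers' is forced here, after recon.loss has been reduced to one value
  list(recon.loss = recon.loss, layers = layers)
}

resolve_defaults()                         # recon.loss "nb",  layers "counts"
resolve_defaults(recon.loss = "mse")       # recon.loss "mse", layers "data"
resolve_defaults("zinb", layers = "data")  # an explicit layers value wins
```

Passing layers explicitly, as in the examples, avoids relying on this default altogether.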
Arguments
- object
A Seurat object (or an Assay5 object if not called by IntegrateLayers)
- orig
DimReduc object. Not to be set directly when called with IntegrateLayers; use the orig.reduction argument instead
- groups
A named data frame with grouping information. Can also contain surgery groups to perform surgery integration.
- groups.name
Column name from the groups data frame that stores grouping information. If groups.name = NULL, the first column is used
- surgery.name
Column name from the groups data frame that stores surgery information. If surgery.name = NULL, a one-shot integration is performed
- surgery.sort
Change the order in which surgery groups are integrated. By default (surgery.sort = TRUE), surgery groups are ordered by name. When FALSE, each group is integrated in the order of its first occurrence in the surgery.name column
- features
Vector of feature names to input to the integration method. When features = NULL (default), the VariableFeatures are used. To pass all features, use the output of Features()
- layers
Name of the layers to use in the integration
- scale.layer
Name of the scaled layer in Assay
- conda_env
Path to the conda environment to run trVAE (should also contain the scipy Python module). By default, uses the conda environment registered for trVAE in the conda environment manager
- new.reduction
Name of the new integrated dimensional reduction
- reduction.key
Key for the new integrated dimensional reduction
- torch.intraop.threads
Number of intra-op threads available to torch when training on CPU instead of GPU. Set via torch.set_num_threads().
- torch.interop.threads
Number of inter-op threads available to torch when training on CPU instead of GPU. Set via torch.set_num_interop_threads(). Can only be changed once, on first call.
- model.save.dir
Path to a directory to save the model(s) to. Uses TRVAE.save(). Does not save anndata. model.save.dir = NULL (default) disables saving the model(s).
- ndims.out
Number of dimensions for the new.reduction output. Corresponds to the latent_dim argument in the original API of TRVAE from scArches
- recon.loss
Reconstruction loss method. One of 'mse', 'nb' or 'zinb' (mean squared error, negative binomial and zero-inflated negative binomial, respectively). Recommended to set layers = "data" for 'mse' (and layers = "counts" for (zi)nb)
- hidden_layer_sizes
Hidden layer sizes for the encoder network
- dr_rate
Dropout rate applied to all layers. dr_rate = 0 disables dropout.
- use_mmd
Whether an additional MMD loss is to be calculated on the latent dim (see next argument)
- mmd_on
Choose the layer on which the MMD loss is calculated. One of 'z' for the latent dim or 'y' for the first decoder layer. Only applies when use_mmd = TRUE
- mmd_boundary
Number of groups on which the MMD loss should be calculated. If mmd_boundary = NULL (default), MMD is calculated on all groups. Only applies when use_mmd = TRUE
- beta
Scaling factor for MMD loss (1 by default). Only applies when use_mmd = TRUE
- use_bn
Whether to apply batch normalization to layers
- use_ln
Whether to apply layer normalization to layers
- n_epochs
Maximum number of epochs to train the model
- lr
Learning rate for training
- eps
torch.optim.Adam eps parameter to improve numerical stability (see the torch.optim.Adam documentation)
- hide.py.warn
Disables some uninformative warnings from torch
- seed.use
An integer to generate reproducible outputs. Set seed.use = NULL to disable
- verbose
Print messages. Set to FALSE to disable
- ...
Additional arguments to be passed to scarches.models.TRVAE.train, TRVAE.load_query_data or TRVAE.get_latent (see Details section)
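The interplay of groups, groups.name, surgery.name and surgery.sort can be pictured with a small base-R sketch. The helper `describe_groups()` is hypothetical and only mirrors the documented rules: the first column is used when groups.name = NULL, no surgery happens when surgery.name = NULL, and surgery groups are taken in name order (surgery.sort = TRUE) or in order of first occurrence (FALSE). The column names and values are made-up examples.

```r
# Hypothetical helper mirroring the documented argument rules
describe_groups <- function(groups, groups.name = NULL,
                            surgery.name = NULL, surgery.sort = TRUE) {
  # groups.name = NULL: fall back to the first column of 'groups'
  if (is.null(groups.name)) groups.name <- colnames(groups)[1]
  # surgery.name = NULL: one-shot integration, so there is no surgery order
  if (is.null(surgery.name)) {
    return(list(batch = groups.name, surgery.order = NULL))
  }
  s <- as.character(groups[[surgery.name]])
  # TRUE: order surgery groups by name; FALSE: order of first occurrence
  ord <- if (surgery.sort) sort(unique(s)) else unique(s)
  list(batch = groups.name, surgery.order = ord)
}

# Made-up cell metadata
groups <- data.frame(
  Method     = c("10x", "Drop-seq", "10x", "Seq-Well"),
  Experiment = c("pbmc2", "pbmc1", "pbmc2", "pbmc1"),
  row.names  = paste0("cell", 1:4)
)

describe_groups(groups)$batch                                      # "Method"
describe_groups(groups, surgery.name = "Experiment")$surgery.order # "pbmc1" "pbmc2"
describe_groups(groups, surgery.name = "Experiment",
                surgery.sort = FALSE)$surgery.order                # "pbmc2" "pbmc1"
```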
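use_mmd, mmd_on and beta add a maximum mean discrepancy (MMD) penalty that pulls the representations of different groups together. The base-R functions below are a minimal sketch of a biased squared-MMD estimate between two groups of latent vectors, using a single Gaussian (RBF) kernel. The bandwidth `sigma` is an assumption, and scArches' own loss (computed in torch, possibly over several kernel scales) may differ in detail. During training, a term like this would be scaled by beta and added to the reconstruction loss.

```r
# Gaussian (RBF) kernel matrix between the rows of a and the rows of b
rbf_kernel <- function(a, b, sigma = 1) {
  # squared Euclidean distances between every row of a and every row of b
  d2 <- outer(rowSums(a^2), rowSums(b^2), "+") - 2 * a %*% t(b)
  exp(-d2 / (2 * sigma^2))
}

# Biased (V-statistic) estimate of squared MMD between samples x and y
# Assumption: a single kernel bandwidth 'sigma' (scArches may use several)
mmd2 <- function(x, y, sigma = 1) {
  mean(rbf_kernel(x, x, sigma)) + mean(rbf_kernel(y, y, sigma)) -
    2 * mean(rbf_kernel(x, y, sigma))
}

set.seed(1)
z_a <- matrix(rnorm(200), ncol = 2)            # group A latent codes
z_b <- matrix(rnorm(200), ncol = 2)            # group B, same distribution
z_c <- matrix(rnorm(200, mean = 3), ncol = 2)  # group C, shifted distribution

mmd2(z_a, z_b)  # small: similar distributions
mmd2(z_a, z_c)  # clearly larger: shifted distribution
```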
Value
A list containing:
- Without surgery groups: a new DimReduc of name new.reduction (key set to reduction.key) consisting of the latent space of the model with ndims.out dimensions.
- With surgery groups: one new DimReduc per surgery group of name new.reduction_[surgery.group] (key set to reduction.key[surgery.group]) consisting of the latent space of the corresponding model with ndims.out dimensions, as well as a 'full' latent representation of name new.reduction_[surgery1]_[surgery2]_... and key set to reduction.keyFull-.
When called via IntegrateLayers, a Seurat object with the new reduction and/or assay is returned.
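The naming scheme for surgery outputs can be reproduced with `paste()`. The helper below is a hypothetical illustration of the documented pattern (new.reduction_[surgery.group] for each model and new.reduction_[surgery1]_[surgery2]_... for the full representation); the surgery group names are invented, and reduction keys are not covered.

```r
# Hypothetical helper: expected reduction names for a surgery integration
surgery_reduction_names <- function(new.reduction, surgery.groups) {
  # one reduction per surgery group
  per_group <- paste(new.reduction, surgery.groups, sep = "_")
  # plus the 'full' representation joining all surgery groups
  full <- paste(c(new.reduction, surgery.groups), collapse = "_")
  c(per_group, full)
}

surgery_reduction_names("integrated.trVAE", c("pbmc1", "pbmc2"))
# "integrated.trVAE_pbmc1" "integrated.trVAE_pbmc2" "integrated.trVAE_pbmc1_pbmc2"
```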
Details
This wrapper calls three to four Python functions through reticulate. The trVAE-specific arguments are documented there:
- model initiation: scarches.models.TRVAE
- training: TRVAE.train, which relies on scarches.trainers.trvae.train.Trainer
- post-training: scarches.models.base._base.CVAELatentsMixin.get_latent
- surgery initiation: scarches.models.base._base.SurgeryMixin.load_query_data
Note that seed.use is passed to torch.manual_seed(). If this is not sufficient to achieve full reproducibility, set mean = TRUE or mean_var = TRUE (passed to get_latent via ...).
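Why mean = TRUE or mean_var = TRUE helps: in a variational autoencoder the encoder outputs a mean and a variance for each cell, and a latent code drawn from that distribution changes from one call to the next unless the random state is fixed. The base-R sketch below assumes the usual reparameterization z = mu + sigma * eps; the function `latent_sketch()` and its inputs are hypothetical and only mimic, rather than reproduce, scArches' get_latent.

```r
# Hypothetical stand-in for retrieving a latent code from encoder outputs
latent_sketch <- function(mu, log_var, mean = FALSE) {
  if (mean) return(mu)  # deterministic: posterior mean only
  # stochastic: one draw z = mu + sigma * eps with eps ~ N(0, 1)
  mu + exp(0.5 * log_var) * rnorm(length(mu))
}

mu      <- c(0.5, -1.2, 2.0)
log_var <- c(-1, -2, -0.5)

set.seed(42); z1 <- latent_sketch(mu, log_var)
set.seed(42); z2 <- latent_sketch(mu, log_var)
identical(z1, z2)                                        # TRUE: same seed, same draw
identical(latent_sketch(mu, log_var, mean = TRUE), mu)   # TRUE: no seed needed
```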
Note
This function requires the scArches package to be installed (along with scipy).
References
Lotfollahi, M., Naghipourfar, M., Theis, F. J. & Wolf, F. A. Conditional out-of-distribution generation for unpaired data using transfer VAE. Bioinformatics 36, i610–i617 (2020). DOI
Lotfollahi, M., Naghipourfar, M., Luecken, M. D., Khajavi, M., Büttner, M., Wagenstetter, M., Avsec, Ž., Gayoso, A., Yosef, N., Interlandi, M., Rybakov, S., Misharin, A. V. & Theis, F. J. Mapping single-cell data to reference atlases by transfer learning. Nat Biotechnol 40, 121–130 (2021). DOI
Examples
if (FALSE) { # \dontrun{
library(Seurat)
# Preprocessing
obj <- SeuratData::LoadData("pbmcsca")
obj[["RNA"]] <- split(obj[["RNA"]], f = obj$Method)
obj <- NormalizeData(obj)
obj <- FindVariableFeatures(obj)
obj <- ScaleData(obj)
obj <- RunPCA(obj)
# After preprocessing, we integrate layers:
obj <- IntegrateLayers(object = obj, method = trVAEIntegration,
features = Features(obj), scale.layer = NULL,
layers = 'counts', groups = obj[[]],
groups.name = 'Method')
# To enable surgery and full reproducibility, and to change the reconstruction loss method:
obj <- IntegrateLayers(object = obj, method = trVAEIntegration,
features = Features(obj), scale.layer = NULL,
layers = 'data', groups = obj[[]],
groups.name = 'Method', surgery.name = 'Experiment',
mean_var = TRUE, recon.loss = 'mse')
} # }