batchmix: Bayesian mixture modelling for multi-batch data


My R package, batchmix, has finally landed on CRAN. It implements the models described in my paper (Coleman et al. 2022): mixture models designed to account for batch effects.

The model is very similar to the ComBat algorithm of Johnson, Li, and Rabinovic (2007), but it infers batch effects and class membership at the same time. This makes the method semi-supervised, as it uses the inferred labels to update the class parameters. It also offers some protection against unequal representation of classes across batches, which is nice.
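To make "batch effects" concrete, here is a minimal one-dimensional sketch of a location-and-scale batch effect, the kind of effect ComBat adjusts for. It uses base R only. All names and parameter values are illustrative assumptions, and this is not the generative model or code used inside batchmix.

```r
set.seed(42)

# Toy sizes and parameters; all values here are illustrative assumptions
n <- 200
class_means <- c(0, 3)        # one mean per class
batch_shift <- c(-0.5, 0.5)   # additive (location) effect per batch
batch_scale <- c(1, 1.5)      # multiplicative (scale) effect per batch

class_id <- sample(1:2, n, replace = TRUE)
batch_id <- sample(1:2, n, replace = TRUE)
noise <- rnorm(n)

# Batch-free measurement: class mean plus unit-variance noise
batch_free <- class_means[class_id] + noise

# Observed measurement: the batch shifts the mean and rescales the noise
observed <- class_means[class_id] + batch_shift[batch_id] +
  batch_scale[batch_id] * noise

# Average difference between observed and batch-free values in each batch
tapply(observed - batch_free, batch_id, mean)
```

If the classes are unevenly spread across batches, a shift like this can be mistaken for a class difference, which is why inferring both together is appealing.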

This post essentially recreates the package vignette, which shows how to use it.

Data generation

We simulate some data using the generateBatchData function.

library(ggplot2)
library(batchmix)
library(mdiHelpR)

Attaching package: 'mdiHelpR'
The following object is masked from 'package:batchmix':

    createSimilarityMat
The following object is masked from 'package:methods':

    show
set.seed(1)
setMyTheme()

# Data dimensions
N <- 600
P <- 4
K <- 5
B <- 7

# Generating model parameters
mean_dist <- 2.25
batch_dist <- 0.3
group_means <- seq(1, K) * mean_dist
batch_shift <- rnorm(B, mean = batch_dist, sd = batch_dist)
std_dev <- rep(2, K)
batch_var <- rep(1.2, B)
group_weights <- rep(1 / K, K)
batch_weights <- rep(1 / B, B)
dfs <- c(5, 10, 15, 20, 80)

my_data <- generateBatchData(
  N,
  P,
  group_means,
  std_dev,
  batch_shift,
  batch_var,
  group_weights,
  batch_weights,
  type = "MVT",
  group_dfs = dfs
)

This gives us a named list with two related datasets: the observed_data, which includes batch effects, and the corrected_data, which is batch-free. It also includes group_IDs, a vector indicating class membership for each item; batch_IDs, which indicates the batch of origin for each item; and fixed, which indicates which labels are observed and held fixed in the model. We pull these out of the named list in the format that the modelling functions expect.

X <- my_data$observed_data

true_labels <- my_data$group_IDs
fixed <- my_data$fixed
batch_vec <- my_data$batch_IDs

alpha <- 1
initial_labels <- generateInitialLabels(alpha, K, fixed, true_labels)
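The role of fixed is easiest to see on a toy example. The sketch below, in base R, keeps the observed labels and fills in the unobserved ones at random. The names (true_toy, fixed_toy, init_toy) are hypothetical, and the uniform draw is an assumption made for illustration; generateInitialLabels may well draw the unobserved labels differently.

```r
set.seed(1)

# Hypothetical toy data: 10 items, 3 classes
K_toy <- 3
true_toy <- c(1, 1, 2, 2, 3, 3, 1, 2, 3, 1)
fixed_toy <- c(1, 0, 1, 0, 1, 0, 0, 0, 0, 0)  # 1 = label observed

# Start from the true labels, then overwrite the unobserved ones
# with a uniform random draw (an assumption made for illustration)
init_toy <- true_toy
unobserved <- which(fixed_toy == 0)
init_toy[unobserved] <- sample(seq_len(K_toy), length(unobserved), replace = TRUE)

# Observed labels are left untouched
all(init_toy[fixed_toy == 1] == true_toy[fixed_toy == 1])

# Every class has at least one observed example
all(seq_len(K_toy) %in% true_toy[fixed_toy == 1])
```

The second check is worth running on real data too, since a class with no observed example gives the model nothing to anchor it.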

Modelling

Given some data, we are interested in modelling it. We assume here that the set of observed labels includes at least one example of each class in the data.

# Sampling parameters
R <- 1000
thin <- 50
n_chains <- 4

# Density choice
type <- "MVT"

# MCMC samples and BIC vector
mcmc_output <- runMCMCChains(
  X,
  n_chains,
  R,
  thin,
  batch_vec,
  type,
  initial_labels = initial_labels,
  fixed = fixed
)

We want to assess two things. First, how frequently the proposed parameters in the Metropolis-Hastings step are accepted:

plotAcceptanceRates(mcmc_output)
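For readers unfamiliar with the term, an acceptance rate is simply the fraction of proposed moves that the sampler keeps. The following is a minimal random-walk Metropolis sketch on a standard normal target, written in base R. It is a toy illustration with a hypothetical proposal width, and the samplers inside batchmix may be set up differently.

```r
set.seed(1)

# Random-walk Metropolis for a standard normal target (toy example)
n_iter <- 5000
proposal_sd <- 2.4   # hypothetical proposal width
x <- numeric(n_iter)
accepted <- logical(n_iter)

for (i in 2:n_iter) {
  proposal <- x[i - 1] + rnorm(1, sd = proposal_sd)
  # Log acceptance ratio; the symmetric proposal density cancels out
  log_ratio <- dnorm(proposal, log = TRUE) - dnorm(x[i - 1], log = TRUE)
  if (log(runif(1)) < log_ratio) {
    x[i] <- proposal
    accepted[i] <- TRUE
  } else {
    x[i] <- x[i - 1]
  }
}

# Proportion of proposals accepted (the first entry is the initial value)
acceptance_rate <- mean(accepted[-1])
acceptance_rate
```

A rate near 0 suggests the proposals are too ambitious and the chain rarely moves; a rate near 1 suggests the steps are too small to explore the posterior efficiently.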

Second, we want to assess how well our chains have converged. To do this we plot the complete_likelihood of each chain. This is the quantity most relevant to clustering/classification, as it depends on the labels. The observed_likelihood is independent of the labels and is more relevant for density estimation.

plotLikelihoods(mcmc_output)
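Eyeballing the traces can be complemented with a numeric check. As a sketch, the function below computes the basic (non-split) Gelman-Rubin statistic in base R for a matrix whose columns are chains of some scalar summary, such as a likelihood trace. This diagnostic is not part of batchmix, and the toy chains here are simulated rather than taken from mcmc_output.

```r
set.seed(1)

# Toy "chains": each column is one chain of a scalar summary,
# e.g. a log-likelihood trace. Chain 4 is deliberately offset.
n_samples <- 500
chains <- cbind(
  rnorm(n_samples),
  rnorm(n_samples),
  rnorm(n_samples),
  rnorm(n_samples, mean = 3)
)

# Basic (non-split) Gelman-Rubin potential scale reduction factor
gelman_rubin <- function(draws) {
  n <- nrow(draws)
  within_var <- mean(apply(draws, 2, var))   # W: mean within-chain variance
  between_var <- n * var(colMeans(draws))    # B: between-chain variance
  pooled_var <- (n - 1) / n * within_var + between_var / n
  sqrt(pooled_var / within_var)
}

gelman_rubin(chains)          # well above 1: the chains disagree
gelman_rubin(chains[, 1:3])   # close to 1: these chains agree
```

Values close to 1 are consistent with the chains sampling from the same distribution; values well above 1 indicate disagreement between chains.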

We see that our chains disagree. We have to run them for more iterations. We use the continueChains function for this.

R_new <- 9000

# Continue sampling from where the existing chains stopped
new_output <- continueChains(
  mcmc_output,
  X,
  fixed,
  batch_vec,
  R_new,
  keep_old_samples = TRUE
)

To see if the chains better agree we re-plot the likelihood.

plotLikelihoods(new_output)

We also re-check the acceptance rates.

plotAcceptanceRates(new_output)

It looks like several of the chains agree by the 5,000th iteration.

Process chains

We process the chains, acquiring point estimates of different quantities.

# Burn in
burn <- 5000

# Process the MCMC samples
processed_samples <- processMCMCChains(new_output, burn)
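It helps to know how many samples per chain survive thinning and burn-in. The arithmetic below is a sketch under two assumptions: that the 9,000 extra iterations are added to the original 1,000, and that one sample is stored every thin iterations. The package's own bookkeeping may differ slightly, for example by also storing the initial state.

```r
# Values used in this post
R_initial <- 1000
R_extra <- 9000
thin <- 50
burn <- 5000

# Assumption: one sample is stored every `thin` iterations
R_total <- R_initial + R_extra
stored_iterations <- seq(thin, R_total, by = thin)

# Keep only samples recorded after the burn-in iteration
kept_iterations <- stored_iterations[stored_iterations > burn]

length(stored_iterations)  # samples stored per chain under this assumption
length(kept_iterations)    # samples remaining after burn-in
```

Under these assumptions each chain contributes on the order of 100 samples to the point estimates, so heavier thinning or a longer burn-in quickly eats into the effective sample size.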

Visualisation

For multidimensional data we use a PCA plot.

chain_used <- processed_samples[[1]]

pc <- prcomp(X, scale. = TRUE)
pc_batch_corrected <- prcomp(chain_used$inferred_dataset)

plot_df <- data.frame(
  PC1 = pc$x[, 1],
  PC2 = pc$x[, 2],
  PC1_bf = pc_batch_corrected$x[, 1],
  PC2_bf = pc_batch_corrected$x[, 2],
  pred_labels = factor(chain_used$pred),
  true_labels = factor(true_labels),
  prob = chain_used$prob,
  batch = factor(batch_vec)
)

plot_df |> 
  ggplot(aes(
    x = PC1,
    y = PC2,
    colour = true_labels,
    alpha = prob
  )) +
  geom_point()

plot_df |> 
  ggplot(aes(
    x = PC1_bf,
    y = PC2_bf,
    colour = pred_labels,
    alpha = prob
  )) +
  geom_point()

test_inds <- which(fixed == 0)

sum(true_labels[test_inds] == chain_used$pred[test_inds])/length(test_inds)
[1] 0.8404255
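A single accuracy figure hides which classes are being confused. The sketch below shows one way to break it down with a confusion matrix in base R. The truth and predicted vectors are hypothetical toy labels; with the objects from this post you would substitute true_labels[test_inds] and chain_used$pred[test_inds].

```r
# Hypothetical toy labels standing in for the held-out items
truth <- c(1, 1, 2, 2, 3, 3, 3, 1)
predicted <- c(1, 2, 2, 2, 3, 3, 1, 1)

# Overall accuracy: proportion of matching labels
accuracy <- mean(truth == predicted)

# Confusion matrix: rows are true classes, columns are predictions
confusion <- table(truth = truth, predicted = predicted)

# Per-class recall: correct predictions divided by the row totals
recall <- diag(confusion) / rowSums(confusion)

accuracy
confusion
recall
```

Note that diag() only lines up with the classes when every class appears in both vectors, so that the table is square; otherwise convert both to factors with the same levels first.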

References

Coleman, Stephen, Xaquin Castro Dopico, Gunilla B. Karlsson Hedestam, Paul D. W. Kirk, and Chris Wallace. 2022. “A Semi-Supervised Bayesian Mixture Modelling Approach for Joint Batch Correction and Classification.” bioRxiv. https://doi.org/10.1101/2022.01.14.476352.
Johnson, W. Evan, Cheng Li, and Ariel Rabinovic. 2007. “Adjusting Batch Effects in Microarray Expression Data Using Empirical Bayes Methods.” Biostatistics 8 (1): 118–27. https://doi.org/10.1093/biostatistics/kxj037.

Citation

BibTeX citation:
@online{coleman,
  author = {Stephen Coleman},
  title = {Batchmix: {Bayesian} Mixture Modelling for Multi-Batch Data},
  url = {https://github.com/stcolema/stcolema.github.io/posts/batchmix},
  langid = {en}
}
For attribution, please cite this work as:
Stephen Coleman. n.d. “Batchmix: Bayesian Mixture Modelling for Multi-Batch Data.” https://github.com/stcolema/stcolema.github.io/posts/batchmix.