# Call libraries
# For mixture models
library(mclust)
library(mixtools)
library(mdir)
# For visualisation
library(pheatmap)
library(ggplot2)
# For ``gather`` function
library(tidyr)
# For the pipe and some extensions
library(magrittr)
# Set seed for reproducibility
set.seed(1)

# Diverging blue-white-orange palette with 100 shades; passed as `color` to
# the pheatmap calls below.
col_pal <- grDevices::colorRampPalette(c("#146EB4", "white", "#FF9900"))(100)

# Default ggplot2 theme for every plot: theme_bw with a dark facet strip
# background and white strip text.
ggplot2::theme_set(
  ggplot2::theme_bw() +
    ggplot2::theme(strip.background = ggplot2::element_rect(fill = "#21677e")) +
    ggplot2::theme(strip.text = ggplot2::element_text(colour = "white"))
)
This vignette explores clustering count data and the impact of some data transforms on this.
We will consider a \(\log\)-transform and the mean-centring and scaling (i.e., standardisation). For a vector of data \(X\), these are:
\[\begin{align} \text{Log-transform: } X \to \log(1 + X), \\ \text{Standardisation: } X \to \frac{X - \bar{X}}{\bar{\sigma}}, \end{align}\]
where \(\bar{X}\) is the empirical mean of \(X\), and \(\bar{\sigma}\) is the empirical standard deviation.
Generate count data
First we create 5 subpopulations with some perturbation about a mean.
# Example of transforms on Poisson data
# Number of observations drawn from each regression model
n <- 100
# Intercepts and slopes; position j of each vector defines model j
beta_0 <- c(1, 1, 5, 6, 2, 2)
beta_1 <- c(0.2, 0.2, 1, 2.5, 2, 2)

# Generate random data
x <- runif(n = n, min = 0, max = 2.0)

# Generate data from 6 Poisson regression models (one per coefficient pair):
# count ~ Poisson(exp(beta_0 + beta_1 * x))
# n x 6 matrix of slope terms, column j is beta_1[j] * x
beta_1_mat <- vapply(beta_1, function(slope) slope * x, numeric(n))
# Add the intercept of model j to column j
exponent_mat <- sweep(beta_1_mat, 2, beta_0, `+`)
# One column of n Poisson draws per model
poisson_data <- apply(exp(exponent_mat), 2, function(lambda) {
  rpois(n = n, lambda = lambda)
})

# Put this data into a data.frame; each variable stacks the draws from two
# models, giving 2 * n rows.
poisson_df <- data.frame(
  Count_1 = c(poisson_data[, 1], poisson_data[, 3]),
  Count_2 = c(poisson_data[, 2], poisson_data[, 4]),
  Count_3 = c(poisson_data[, 5], poisson_data[, 6])
)
Create some data that follows a sigmoidal curve:
# Some continuous variables; 2 * n draws each, one per row of poisson_df
x1 <- runif(2 * n, -2, -1)
x2 <- runif(2 * n, 1, 2)
x3 <- runif(2 * n, -1, 1)

# Linear combination with a bias
z <- 1 + 8 * x1 + 7.5 * x2 + 5 * x3

# Pass through an inverse logit function and move to a scale similar to a
# count (rounded values between 0 and 1000)
pr <- round(1 / (1 + exp(-z)) * 1000)

# Sigmoid curve of the generated values, then their marginal distribution
plot(z, pr, main = "Discere sigmoidal data")
hist(pr, main = "Discere sigmoidal data")
We add a feature generated from a different model; this follows a sigmoidal curve. We combine this with our previously generated data.
# Arbitrary step to have the high counts align in each dataset imperfectly but
# well enough to have fewer sub-populations emerge from the combined dataset
fractions <- 8
# NOTE(review): new_order is not used anywhere in the visible code.
new_order <- order(pr)

# Label each sigmoidal value with the quantile bin (1..fractions) it falls in
flag <- as.numeric(cut(pr,
  breaks = quantile(pr, probs = seq(0, 1, 1 / fractions)),
  include.lowest = TRUE,
  labels = 1:fractions
))
# Relabel bins 3 and 4 so that they sort after every other bin
flag[flag %in% 3:4] <- fractions + 1

# Combine the generated data; the sigmoidal values are reordered by bin label
my_data <- as.data.frame(cbind(poisson_df, pr[order(flag)]))

# Assign row and column names
colnames(my_data) <- paste0("Count_", seq_len(ncol(poisson_df) + 1))
# paste() (not paste0()) is kept so row names stay of the form "Person_ 1"
row.names(my_data) <- paste("Person_", seq_len(nrow(my_data)))

# Of some use later: the sigmoidal variable is the last column
n_var <- ncol(my_data)

head(my_data)
Count_1 Count_2 Count_3 Count_4
Person_ 1 4 2 18 0
Person_ 2 2 2 30 3
Person_ 3 2 3 89 11
Person_ 4 9 3 269 0
Person_ 5 3 1 23 2
Person_ 6 2 4 238 0
Now we apply our transforms.
# Log transform, log(1 + x), applied elementwise
log_data <- as.data.frame(log(1 + my_data))

# Mean centre and standardise the standard deviation within each variable.
# scale() works column-wise, so no apply() over the data frame is needed.
scaled_data <- my_data |>
  scale() |>
  as.data.frame() |>
  set_rownames(row.names(my_data))

# Let's now try combining these two
scaled_log_data <- log(1 + my_data) |>
  scale() |>
  as.data.frame() |>
  set_rownames(row.names(my_data))
Let us look at the distributions described by each variable for each dataset. We expect there to be two subpopulations present under variables “Count_1”, “Count_2” and “Count_4”, and a single population under “Count_3”.
# Histogram of each variable in the original data, one facet per variable.
# pivot_longer() replaces the superseded gather(); bins is set explicitly to
# the value ggplot2 was already using, which silences the stat_bin() message.
my_data |>
  pivot_longer(cols = everything(), names_to = "key", values_to = "value") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(~key, scales = "free_x") +
  labs(
    title = "Distribution of original data",
    x = "Count",
    y = "Frequency"
  )
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Heatmap of the generated counts, coloured with col_pal; column clustering is
# switched off and row names are hidden.
pheatmap(
  my_data,
  main = "Generated data",
  color = col_pal,
  cluster_cols = FALSE,
  show_rownames = FALSE
)
# Histogram of each variable in the log-transformed data, one facet per
# variable. pivot_longer() replaces the superseded gather(); bins is set
# explicitly to the value ggplot2 was already using.
log_data |>
  pivot_longer(cols = everything(), names_to = "key", values_to = "value") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(~key, scales = "free_x") +
  labs(
    title = "Distribution of log-transformed data",
    x = "Count",
    y = "Frequency"
  )
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Heatmap of the log-transformed data, coloured with col_pal; column
# clustering is switched off and row names are hidden.
pheatmap(
  log_data,
  main = "Log-transformed data",
  color = col_pal,
  cluster_cols = FALSE,
  show_rownames = FALSE
)
# Histogram of each variable in the standardised data, one facet per variable.
# pivot_longer() replaces the superseded gather(); bins is set explicitly to
# the value ggplot2 was already using.
scaled_data |>
  pivot_longer(cols = everything(), names_to = "key", values_to = "value") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(~key, scales = "free_x") +
  labs(
    title = "Distribution of standardised data",
    x = "Count",
    y = "Frequency"
  )
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Heatmap of the standardised data, coloured with col_pal; column clustering
# is switched off and row names are hidden.
pheatmap(
  scaled_data,
  main = "Standardised data",
  color = col_pal,
  cluster_cols = FALSE,
  show_rownames = FALSE
)
# Histogram of each variable in the standardised log-transformed data, one
# facet per variable. pivot_longer() replaces the superseded gather(); bins is
# set explicitly to the value ggplot2 was already using.
scaled_log_data |>
  pivot_longer(cols = everything(), names_to = "key", values_to = "value") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 30) +
  facet_wrap(~key, scales = "free_x") +
  labs(
    title = "Distribution of standardised log-transformed data",
    x = "Count",
    y = "Frequency"
  )
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Heatmap of the standardised log-transformed data, coloured with col_pal;
# column clustering is switched off and row names are hidden.
pheatmap(
  scaled_log_data,
  main = "Standardised log-transformed data",
  color = col_pal,
  cluster_cols = FALSE,
  show_rownames = FALSE
)
Note that standardising has not changed how well separated the groups in the data are; it has merely changed the scale and location of the data.
In terms of separating out the subpopulations, it appears that the \(\log\)-transform has worked most effectively for the first two variables, but does not succeed with the sigmoidal data. This reminds us that a single transform might not be appropriate for an entire dataset. However, many datasets are too large to check each feature manually; the point here is simply that we often lose some signal, and there is a limit to how far this loss can feasibly be minimised.
We will now attempt to infer the latent clustering labels using mixture models. I use the `Mclust` function from the `mclust` package and `mvnormalmixEM` from `mixtools` to create Gaussian mixture models. We will look at the clustering predicted for the data using `pheatmap`.
I create some labels for the sigmoidal data to keep track of the tails - we hope these are allocated correctly, and they are probably the hardest sub-populations to untangle for the \(\log\)-transformed data. These labels are based on which tertile each sigmoidal value falls into and are intended as a rough guide to how well the models deconstruct the sigmoidal data.
# Keep track of the sigmoidal data by assigning a label based on quantiles
n_labels <- 3
# Bin the last column of my_data (the sigmoidal variable) into n_labels
# quantile groups labelled 1..n_labels
sig_labels <- cut(my_data[, n_var],
  breaks = quantile(my_data[, n_var], probs = seq(0, 1, 1 / n_labels)),
  include.lowest = TRUE,
  labels = 1:n_labels
)

# Create a data.frame to annotate the heatmaps; its row names must match those
# of the plotted data
annotation_row <- data.frame(Sig_pop = as.factor(sig_labels))
row.names(annotation_row) <- row.names(my_data)
Model fitting
Now attempt to fit models. I find that `mixtools` is less robust than `mclust` (it struggles to solve the same datasets, even with the `hclust` initialisation). For this reason I comment out the code for fitting the `mixtools` model.
# Names of the model-fitting functions and of the transforms being compared
model_functions <- c(
  "Mclust",
  "mvnormalmixEM"
)

transforms <- c(
  "Original",
  "Log",
  "Standardise"
)

# One dataset per transform, named accordingly
datasets <- list(my_data, log_data, scaled_data) |>
  set_names(transforms)

n_datasets <- length(datasets)
n_models <- length(model_functions)

# Nested result lists indexed as [[model function]][[transform]]
model_out <- vector("list", n_models) |>
  set_names(model_functions)

model_bic <- model_out

for (i in seq_len(n_models)) {
  model_out[[i]] <- vector("list", n_datasets) |>
    set_names(transforms)

  model_bic[[i]] <- model_out[[i]]

  if (model_functions[i] == "Mclust") {
    for (j in seq_len(n_datasets)) {
      # do.call(model_functions[i], list(datasets[[j]])
      # Fit Gaussian mixtures with between 2 and 15 components
      model_out[[i]][[j]] <- Mclust(datasets[[j]], G = 2:15)
      # NOTE(review): mclustBIC() is called without G, so it does not use the
      # 2:15 range given to Mclust() above - confirm this is intended.
      model_bic[[i]][[j]] <- mclustBIC(datasets[[j]])
    }
  }

  # The mixtools fit is deliberately left commented out; the "mvnormalmixEM"
  # entries of model_out and model_bic therefore stay NULL.
  # if (model_functions[i] == "mvnormalmixEM") {
  #   for (j in 1:n_datasets) {
  #     for(k in seq(2, 15)) {
  #       initial_clusterings <- datasets[[j]] |>
  #         dist() |>
  #         hclust() |>
  #         cutree(k = k)
  #       initial_means <- vector("list", k)
  #       for(l in seq(1, k)) {
  #         cluster_inds <- which(initial_clusterings == l)
  #         initial_means[[l]] <- colMeans(datasets[[j]][cluster_inds, ])
  #       }
  #
  #       model_out[[i]][[j]] <- mvnormalmixEM(datasets[[j]], k = k, mu = initial_means)
  #     }
  #   }
  # }
}
We can now inspect the model using several different visualisations. We can investigate the optimal number of components under the Bayesian Information Criterion (BIC) and we can look at the clusterings as defined by pairs of variables. I will look at the models defined on the scaled data and the log-transformed data. The BIC vs \(k\) plot also shows which of the possible types of covariance structure allowed by `Mclust` is optimal (these are the different models listed and plotted).
# Summary of the Mclust fit on the log-transformed data
summary(model_out[["Mclust"]][["Log"]])
----------------------------------------------------
Gaussian finite mixture model fitted by EM algorithm
----------------------------------------------------
Mclust VVE (ellipsoidal, equal orientation) model with 8 components:
log-likelihood n df BIC ICL
-456.7326 200 77 -1321.436 -1331.303
Clustering table:
1 2 3 4 5 6 7 8
21 29 30 20 49 17 7 27
# Pairwise classification plot, then BIC against number of components, for
# the Mclust fit on the log-transformed data
log_fit <- model_out[["Mclust"]][["Log"]]
plot(log_fit, what = "classification")
plot(log_fit, what = "BIC")
# Summary of the Mclust fit on the standardised data
summary(model_out[["Mclust"]][["Standardise"]])
----------------------------------------------------
Gaussian finite mixture model fitted by EM algorithm
----------------------------------------------------
Mclust VVV (ellipsoidal, varying volume, shape, and orientation) model with 8
components:
log-likelihood n df BIC ICL
990.4389 200 119 1350.378 1344.508
Clustering table:
1 2 3 4 5 6 7 8
20 30 25 25 23 24 23 30
# Pairwise classification plot, then BIC against number of components, for
# the Mclust fit on the standardised data
scaled_fit <- model_out[["Mclust"]][["Standardise"]]
plot(scaled_fit, what = "classification")
plot(scaled_fit, what = "BIC")
We can see that by using a vector of possible components (the input `G = 2:15` in the call to `Mclust`), we have captured the optimal value. This is the global maximum in the plot comparing BIC to number of components for each model allowed by `Mclust`. One thing to notice here is that the two models do agree on the optimal number of components (8), but that this is not expected to occur in general, particularly for datasets of greater dimension. Another point is that the type of model that best fits the scaled data is the “EVE” model in comparison to the “VVV” model. This is a simpler model where each cluster has more restrictions on its parameters - thus the EVE model is easier to run (this is worth remembering if log-transformed data demands too complex a model computationally).
In the plots comparing clusterings across variables, we can see that the model defined on the log-transformed data separates the two sub-populations in \(Count_1\) and \(Count_2\) with much greater confidence than the model defined on the standardised data. One can also see that \(Count_3\), which has no sub-population structure, is not contributing significantly to the cluster allocations - the data is pretty uniformly distributed across the axis defined by \(Count_3\), with no clear partitions emerging. One can see that the sigmoidal structure in \(Count_4\) is captured quite well by the log model.
Results
Let us now inspect the inferred clusterings. I ignore the original data as its scale makes inspecting the data infeasible.
# Collect the classification from each Mclust fit into a data.frame with one
# column per dataset and one row per observation
labelling <- model_out$Mclust |>
  lapply(function(fit) fit$classification) |>
  as.data.frame() |>
  set_rownames(row.names(my_data))

# Use the inferred clusterings as the heatmap row annotation.
# NOTE(review): this replaces the Sig_pop annotation_row created earlier.
annotation_row <- labelling
# Draw the log-transformed and then the standardised data three times each,
# with rows ordered by the clustering in column 1, 2 and 3 of annotation_row.
# Row and column clustering are disabled so the imposed ordering is what is
# displayed.
heatmap_data <- list(
  "Log-transformed data" = log_data,
  "Standardised data" = scaled_data
)
clustering_source <- c("original data", "log-transformed data", "scaled data")

for (data_name in names(heatmap_data)) {
  for (k in seq_along(clustering_source)) {
    row_order <- order(annotation_row[, k])
    pheatmap(heatmap_data[[data_name]][row_order, ],
      cluster_rows = FALSE,
      cluster_cols = FALSE,
      annotation_row = annotation_row,
      main = paste0(
        data_name, ":\nOrdered by clustering of ", clustering_source[k]
      )
    )
  }
}
We compare the similarity of the inferred clusterings using the Adjusted Rand Index. This scores clustering similarities, with 0 meaning the two partitions are no more similar than a random pair of clusterings is expected to be and 1 meaning they are identical.
# Adjusted Rand Index between each pair of inferred clusterings
# (columns of annotation_row: 1, 2 and 3)
ari_12 <- mcclust::arandi(annotation_row[, 1], annotation_row[, 2])
ari_13 <- mcclust::arandi(annotation_row[, 1], annotation_row[, 3])
ari_23 <- mcclust::arandi(annotation_row[, 2], annotation_row[, 3])

# Symmetric 3 x 3 matrix of pairwise ARIs with ones on the diagonal
ari_mat <- matrix(
  c(1.0, ari_12, ari_13, ari_12, 1.0, ari_23, ari_13, ari_23, 1.0),
  nrow = 3,
  ncol = 3
) |>
  set_colnames(c("Original", "Log-transformed", "Standardised")) |>
  set_rownames(c("Original", "Log-transformed", "Standardised"))

knitr::kable(ari_mat, digits = 3)
| | Original | Log-transformed | Standardised |
|---|---|---|---|
| Original | 1.000 | 0.702 | 1.000 |
| Log-transformed | 0.702 | 1.000 | 0.702 |
| Standardised | 1.000 | 0.702 | 1.000 |
Summary
We see that from the Gaussian mixture model's perspective the original and standardised data are interchangeable, leading to identical inference. Both datasets lead to a surprisingly similar clustering to the log-transformed data; 0.7 is not a low ARI. Furthermore, we can see in the annotated heatmaps that a large source of the contention is that one cluster found in the log-transformed data is considered as two separate groups in the original (cluster 6 in the log-transformed inference approximately captures clusters 6 and 5 from the other point estimates), and conversely cluster 8 in the original data splits into two clusters in the log-transformed data. Deciding which of these is more useful, or whether we should use 7 or 9 clusters rather than 8, involves further thought and ideally conversation with a domain expert.
Reuse
Citation
@online{coleman,
author = {Stephen Coleman},
title = {Clustering Count Data},
url = {https://github.com/stcolema/stcolema.github.io/posts/ClusteringCountData/clustering_count_data.html},
langid = {en}
}