# The LKJ prior vs the Wishart prior

2016-03-11

(latest update : 2016-03-11 16:58:18)

As I noted at the end of the previous article, JAGS returned an overestimate of the between-group standard deviation $\sigma_{b_2}$. This is how I simulated the data, with $I=3$ groups and $J=4$ replicates per group:
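In formulas, the data-generating model reads as follows: writing $M=(\mu_{t_1}, \mu_{t_2})$ for the vector of overall means, the group means are drawn as

$$\mu_i \sim \mathcal{N}_2(M, \Sigma), \qquad \Sigma = \begin{pmatrix} \sigma_{b_1}^2 & \rho\,\sigma_{b_1}\sigma_{b_2} \\ \rho\,\sigma_{b_1}\sigma_{b_2} & \sigma_{b_2}^2 \end{pmatrix}, \qquad i=1,\dots,I,$$

and the observations within group $i$ as $y_{ij}^{(t)} \sim \mathcal{N}\bigl(\mu_i^{(t)}, \sigma_{w_t}^2\bigr)$ for $j=1,\dots,J$ and timepoints $t \in \{1,2\}$.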

simdata <- function(I, J){
  set.seed(444)
  ### simulation of the overall means ###
  Mu.t1 <- 20
  Mu.t2 <- 5
  Mu <- c(Mu.t1, Mu.t2)
  names(Mu) <- c("t1", "t2")
  sigmab.t1 <- 8
  sigmab.t2 <- 1
  rho <- 0.2
  Sigma <- rbind(
    c(sigmab.t1^2, rho*sigmab.t1*sigmab.t2),
    c(rho*sigmab.t1*sigmab.t2, sigmab.t2^2)
  )
  mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
  ### simulation within groups ###
  sigmaw.t1 <- 2
  sigmaw.t2 <- 0.5
  y.t1 <- c(sapply(mu[,"t1"], function(m) rnorm(J, m, sigmaw.t1)))
  y.t2 <- c(sapply(mu[,"t2"], function(m) rnorm(J, m, sigmaw.t2)))
  ### construct the dataset ###
  dat <- data.frame(
    Timepoint = factor(rep(c("t1", "t2"), each=I*J), levels=c("t1", "t2")),
    Group = factor(paste0("grp", rep(gl(I, J), times=2))),
    Repeat = rep(1:J, times=2*I),
    y = c(y.t1, y.t2)
  )
  return(dat)
}

Let us try JAGS on the data simulated with $I=100$:

dat <- simdata(I=100, J=4)

First note that the lme estimates are quite good:

library(nlme)
lme(y ~ Timepoint, data = dat,
    random = list(Group = pdSymm(~ 0 + Timepoint)),
    weights = varIdent(form = ~ Group:Timepoint | Timepoint))
## Linear mixed-effects model fit by REML
##   Data: dat
##   Log-restricted-likelihood: -1491.099
##   Fixed: y ~ Timepoint
## (Intercept) Timepointt2
##    18.74254   -13.70368
##
## Random effects:
##  Formula: ~0 + Timepoint | Group
##  Structure: General positive-definite
##             StdDev   Corr
## Timepointt1 7.168704 Tmpnt1
## Timepointt2 1.087907 0.314
## Residual    1.926711
##
## Variance function:
##  Structure: Different standard deviations per stratum
##  Formula: ~Group:Timepoint | Timepoint
##  Parameter estimates:
##        t1        t2
## 1.0000000 0.2813927
## Number of Observations: 800
## Number of Groups: 100
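As a quick side-by-side, here are the true simulation values next to the REML estimates printed above (the $t_2$ mean is the intercept plus the Timepointt2 effect; the numbers are copied from the output):

rbind(true = c(Mu.t1 = 20, Mu.t2 = 5, sigmab.t1 = 8, sigmab.t2 = 1, rho = 0.2),
      REML = c(18.74, 5.04, 7.17, 1.09, 0.31))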

Now let us run JAGS (see the previous article for the code not shown here):
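As a reminder of what we are contrasting, here is a minimal sketch of such a JAGS model — not the exact code of the previous article — whose key feature is the Wishart prior (dwish) on the between-group precision matrix, a single prior that jointly controls both standard deviations and the correlation:

## minimal sketch, not the exact model of the previous article;
## R (a 2x2 scale matrix), I and J would be passed in the data list
jagscode_sketch <- "model {
  Omegab ~ dwish(R[,], 2)        # Wishart prior on the 2x2 precision matrix
  Sigmab <- inverse(Omegab)      # between-group covariance matrix
  sigmab1 <- sqrt(Sigmab[1,1])
  sigmab2 <- sqrt(Sigmab[2,2])
  rhob <- Sigmab[1,2]/(sigmab1*sigmab2)
  Mu[1] ~ dnorm(0, 0.001)
  Mu[2] ~ dnorm(0, 0.001)
  sigmaw1 ~ dunif(0, 100)
  sigmaw2 ~ dunif(0, 100)
  for (i in 1:I) {
    mu[i, 1:2] ~ dmnorm(Mu[], Omegab[,])
    for (j in 1:J) {
      y1[i,j] ~ dnorm(mu[i,1], pow(sigmaw1, -2))
      y2[i,j] ~ dnorm(mu[i,2], pow(sigmaw2, -2))
    }
  }
}"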

library(rjags)
jagsmodel <- jags.model(jagsfile,
                        data = jagsdata,
                        inits = inits,
                        n.chains = length(inits))
update(jagsmodel, 5000) # warm-up
jagssamples <- coda.samples(jagsmodel,
                            c("Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"),
                            n.iter = 10000)

Below are the summary statistics of the posterior samples:

summary(jagssamples)
##
## Iterations = 5001:15000
## Thinning interval = 1
## Number of chains = 3
## Sample size per chain = 10000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##            Mean      SD  Naive SE Time-series SE
## Mu[1]   18.7307 0.73406 0.0042381      0.0045800
## Mu[2]    5.0384 0.15318 0.0008844      0.0009678
## rhob     0.2220 0.09633 0.0005562      0.0005876
## sigmab1  7.2574 0.52933 0.0030561      0.0031555
## sigmab2  1.5072 0.11030 0.0006368      0.0006672
## sigmaw1  1.9248 0.07806 0.0004507      0.0005689
## sigmaw2  0.5403 0.02192 0.0001266      0.0001617
##
## 2. Quantiles for each variable:
##
##             2.5%     25%     50%     75%   97.5%
## Mu[1]   17.27665 18.2403 18.7312 19.2174 20.1729
## Mu[2]    4.73784  4.9363  5.0383  5.1424  5.3386
## rhob     0.02904  0.1575  0.2243  0.2890  0.4042
## sigmab1  6.29967  6.8906  7.2245  7.5924  8.3763
## sigmab2  1.30734  1.4308  1.5012  1.5765  1.7407
## sigmaw1  1.78034  1.8711  1.9220  1.9754  2.0862
## sigmaw2  0.49930  0.5251  0.5395  0.5545  0.5854

Again, $\sigma_{b_2}$ is overestimated: its true value ($=1$) is less than the lower bound of the $95\%$-credible interval ($\approx 1.31$). The other estimates are quite good.

## Using the LKJ prior

The above problem is possibly due to the Wishart prior on the covariance matrix: a Wishart is a single prior that jointly controls the standard deviations and the correlation. Stan allows us to use an LKJ prior on the correlation matrix alone, combined with separate priors on the standard deviations. We will run it on the small dataset:
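The idea is to decompose the between-group covariance matrix into a scale part and a correlation part, and to put the LKJ prior on the latter:

$$\Sigma_b = \operatorname{diag}(\sigma_b)\,\Omega\,\operatorname{diag}(\sigma_b), \qquad \Omega \sim \operatorname{LKJ}(\eta), \qquad \sigma_{b_1}, \sigma_{b_2} \sim \text{half-Cauchy}(0, 5),$$

where $\eta=1$ gives a uniform prior over correlation matrices. In the Stan code below, $\Omega$ is handled through its Cholesky factor $L$.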

dat <- simdata(I=3, J=4)
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
stancode <- 'data {
  int<lower=1> N; // number of observations
  real y[N]; // observations
  int<lower=1> ngroups; // number of groups
  int<lower=1> group[N]; // group indices
  int<lower=1> timepoint[N]; // timepoint indices
}
parameters {
  vector[2] Mu;
  vector[2] mu[ngroups]; // group means
  cholesky_factor_corr[2] L; // Cholesky factor of the correlation matrix
  vector<lower=0>[2] sigma_b;
  vector<lower=0>[2] sigma_w;
}
model {
  sigma_w ~ cauchy(0, 5);
  for(k in 1:N){
    y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
  }
  sigma_b ~ cauchy(0, 5);
  L ~ lkj_corr_cholesky(1); // uniform prior over correlation matrices
  Mu ~ normal(0, 25);
  for(j in 1:ngroups){
    mu[j] ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
  }
}
generated quantities {
  matrix[2,2] Omega; // correlation matrix
  matrix[2,2] Sigma; // covariance matrix
  real rho_b;
  Omega <- multiply_lower_tri_self_transpose(L);
  Sigma <- quad_form_diag(Omega, sigma_b);
  rho_b <- Sigma[1,2]/(sigma_b[1]*sigma_b[2]);
}'
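Note that the model never builds the full correlation matrix during sampling: lkj_corr_cholesky and multi_normal_cholesky work directly with the Cholesky factor L, which is numerically stabler and cheaper than reconstructing and factorizing the matrix at every iteration. The correlation matrix Omega and the covariance matrix Sigma are only recovered in the generated quantities block.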

### compile Stan model
stanmodel <- stan_model(model_code = stancode, model_name="stanmodel")

### Stan data
standata <- list(y = dat$y, N = nrow(dat), ngroups = nlevels(dat$Group),
                 timepoint = dat$timepoint, group = dat$group)

### Stan initial values
estimates <- function(dat, perturb=FALSE){
  if(perturb) dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  mu <- matrix(aggregate(y~timepoint:group, data=dat, FUN=mean)$y, ncol=2, byrow=TRUE)
  Mu <- colMeans(mu)
  sigma_b <- sqrt(diag(var(mu)))
  L <- t(chol(cor(mu)))
  sigmaw1 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t1"), FUN=sd)$y)
  sigmaw2 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t2"), FUN=sd)$y)
  return(list(mu=mu, Mu=Mu, L=L, sigma_b=sigma_b, sigma_w = c(sigmaw1, sigmaw2)))
}
inits <- function(chain_id){
  values <- estimates(dat, perturb = chain_id > 1)
  return(values)
}

We are ready to run the Stan sampler. After getting some warning messages when I first ran it with the default values of the control argument, I increased adapt_delta and max_treedepth:

### run Stan
stansamples <- sampling(stanmodel, data = standata, init = inits,
                        iter = 15000, warmup = 5000, chains = 4,
                        # both increased from their defaults (adapt_delta 0.8, max_treedepth 10);
                        # these particular values are illustrative
                        control = list(adapt_delta = 0.99, max_treedepth = 15))

### outputs
library(coda)
codasamples <- do.call(mcmc.list,
                       plyr::alply(rstan::extract(stansamples, permuted=FALSE,
                                                  pars = c("Mu", "sigma_b", "sigma_w", "rho_b")),
                                   2, mcmc))
summary(codasamples)
##
## Iterations = 1:10000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 10000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##               Mean     SD  Naive SE Time-series SE
## Mu[1]      13.8167 8.2269 0.0411345       0.069493
## Mu[2]       4.6488 0.8297 0.0041487       0.011824
## sigma_b[1] 14.0441 8.7839 0.0439194       0.078022
## sigma_b[2]  0.8254 1.2307 0.0061533       0.016818
## sigma_w[1]  2.0004 0.5599 0.0027996       0.004327
## sigma_w[2]  0.6394 0.1753 0.0008763       0.001268
## rho_b       0.1663 0.5171 0.0025855       0.003793
##
## 2. Quantiles for each variable:
##
##               2.5%     25%     50%     75%   97.5%
## Mu[1]      -4.6870  9.7532 14.3161 18.4474 29.0210
## Mu[2]       3.1807  4.4245  4.6577  4.8903  6.1337
## sigma_b[1]  5.6294  8.7652 11.6931 16.3625 36.9594
## sigma_b[2]  0.0319  0.2221  0.4600  0.9382  3.9622
## sigma_w[1]  1.2280  1.6105  1.8975  2.2683  3.3762
## sigma_w[2]  0.3984  0.5181  0.6067  0.7225  1.0733
## rho_b      -0.8399 -0.2404  0.2152  0.6100  0.9502

Compared to the JAGS estimates (given at the end of the previous article), the estimates of $\sigma_{b_2}$ and $\rho_b$ obtained with Stan are clearly better. Note also that JAGS returned a huge credible interval for $\mu_2$.
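As a quick check against the truth, the true values are well inside the Stan credible intervals; here are the posterior medians from the table above next to the simulation values:

## true simulation values vs the Stan posterior medians shown above
rbind(true = c(sigma_b2 = 1, rho_b = 0.2),
      Stan.median = c(0.46, 0.22))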