# The LKJ prior vs the Wishart prior

2016-03-11

(latest update : 2016-03-11 16:58:18)

As I noted at the end of the previous article, JAGS returned an overestimate of the between-group standard deviation $\sigma_{b_2}$. This is how I simulated the data, with $I=3$ groups and $J=4$ replicates per group:
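In formulas, the data-generating model reads as follows: writing $M=(\mu_{t_1}, \mu_{t_2})$ for the vector of overall means, the group means are drawn as

$$\mu_i \sim \mathcal{N}_2(M, \Sigma), \qquad \Sigma = \begin{pmatrix} \sigma_{b_1}^2 & \rho\,\sigma_{b_1}\sigma_{b_2} \\ \rho\,\sigma_{b_1}\sigma_{b_2} & \sigma_{b_2}^2 \end{pmatrix}, \qquad i=1,\dots,I,$$

and the observations within group $i$ as $y_{ij}^{(t)} \sim \mathcal{N}\bigl(\mu_i^{(t)}, \sigma_{w_t}^2\bigr)$ for $j=1,\dots,J$ and timepoints $t \in \{1,2\}$.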

simdata <- function(I, J){
  set.seed(444)
  ### simulation of the overall means ###
  Mu.t1 <- 20
  Mu.t2 <- 5
  Mu <- c(Mu.t1, Mu.t2)
  names(Mu) <- c("t1", "t2")
  sigmab.t1 <- 8
  sigmab.t2 <- 1
  rho <- 0.2
  Sigma <- rbind(
    c(sigmab.t1^2, rho*sigmab.t1*sigmab.t2),
    c(rho*sigmab.t1*sigmab.t2, sigmab.t2^2)
  )
  mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
  ### simulation within groups ###
  sigmaw.t1 <- 2
  sigmaw.t2 <- 0.5
  y.t1 <- c(sapply(mu[,"t1"], function(m) rnorm(J, m, sigmaw.t1)))
  y.t2 <- c(sapply(mu[,"t2"], function(m) rnorm(J, m, sigmaw.t2)))
  ### construct the dataset ###
  dat <- data.frame(
    Timepoint = factor(rep(c("t1", "t2"), each=I*J), levels=c("t1", "t2")),
    Group = factor(paste0("grp", rep(gl(I, J), times=2))),
    Repeat = rep(1:J, times=2*I),
    y = c(y.t1, y.t2)
  )
  return(dat)
}

Let us try JAGS on the data simulated with $I=100$:

dat <- simdata(I=100, J=4)

First note that the lme estimates are quite good:

library(nlme)
lme(y ~ Timepoint, data = dat,
    random = list(Group = pdSymm(~ 0 + Timepoint)),
    weights = varIdent(form = ~ Group:Timepoint | Timepoint))
## Linear mixed-effects model fit by REML
##   Data: dat
##   Log-restricted-likelihood: -1491.099
##   Fixed: y ~ Timepoint
## (Intercept) Timepointt2
##    18.74254   -13.70368
##
## Random effects:
##  Formula: ~0 + Timepoint | Group
##  Structure: General positive-definite
##             StdDev   Corr
## Timepointt1 7.168704 Tmpnt1
## Timepointt2 1.087907 0.314
## Residual    1.926711
##
## Variance function:
##  Structure: Different standard deviations per stratum
##  Formula: ~Group:Timepoint | Timepoint
##  Parameter estimates:
##        t1        t2
## 1.0000000 0.2813927
## Number of Observations: 800
## Number of Groups: 100
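As a quick side-by-side, here are the true simulation values next to the REML estimates printed above (the $t_2$ mean is the intercept plus the Timepointt2 effect; the numbers are copied from the output):

rbind(true = c(Mu.t1 = 20, Mu.t2 = 5, sigmab.t1 = 8, sigmab.t2 = 1, rho = 0.2),
      REML = c(18.74, 5.04, 7.17, 1.09, 0.31))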

Now let us run JAGS (see the previous article for the code not shown here):
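As a reminder of what we are contrasting, here is a minimal sketch of such a JAGS model — not the exact code of the previous article — whose key feature is the Wishart prior (dwish) on the between-group precision matrix, a single prior that jointly controls both standard deviations and the correlation:

## minimal sketch, not the exact model of the previous article;
## R (a 2x2 scale matrix), I and J would be passed in the data list
jagscode_sketch <- "model {
  Omegab ~ dwish(R[,], 2)        # Wishart prior on the 2x2 precision matrix
  Sigmab <- inverse(Omegab)      # between-group covariance matrix
  sigmab1 <- sqrt(Sigmab[1,1])
  sigmab2 <- sqrt(Sigmab[2,2])
  rhob <- Sigmab[1,2]/(sigmab1*sigmab2)
  Mu[1] ~ dnorm(0, 0.001)
  Mu[2] ~ dnorm(0, 0.001)
  sigmaw1 ~ dunif(0, 100)
  sigmaw2 ~ dunif(0, 100)
  for (i in 1:I) {
    mu[i, 1:2] ~ dmnorm(Mu[], Omegab[,])
    for (j in 1:J) {
      y1[i,j] ~ dnorm(mu[i,1], pow(sigmaw1, -2))
      y2[i,j] ~ dnorm(mu[i,2], pow(sigmaw2, -2))
    }
  }
}"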

library(rjags)
jagsmodel <- jags.model(jagsfile,
                        data = jagsdata,
                        inits = inits,
                        n.chains = length(inits))
update(jagsmodel, 5000) # warm-up
jagssamples <- coda.samples(jagsmodel,
                            c("Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"),
                            n.iter = 10000)

Below are the summary statistics of the posterior samples:

summary(jagssamples)
##
## Iterations = 5001:15000
## Thinning interval = 1
## Number of chains = 3
## Sample size per chain = 10000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##            Mean      SD  Naive SE Time-series SE
## Mu[1]   18.7307 0.73406 0.0042381      0.0045800
## Mu[2]    5.0384 0.15318 0.0008844      0.0009678
## rhob     0.2220 0.09633 0.0005562      0.0005876
## sigmab1  7.2574 0.52933 0.0030561      0.0031555
## sigmab2  1.5072 0.11030 0.0006368      0.0006672
## sigmaw1  1.9248 0.07806 0.0004507      0.0005689
## sigmaw2  0.5403 0.02192 0.0001266      0.0001617
##
## 2. Quantiles for each variable:
##
##             2.5%     25%     50%     75%   97.5%
## Mu[1]   17.27665 18.2403 18.7312 19.2174 20.1729
## Mu[2]    4.73784  4.9363  5.0383  5.1424  5.3386
## rhob     0.02904  0.1575  0.2243  0.2890  0.4042
## sigmab1  6.29967  6.8906  7.2245  7.5924  8.3763
## sigmab2  1.30734  1.4308  1.5012  1.5765  1.7407
## sigmaw1  1.78034  1.8711  1.9220  1.9754  2.0862
## sigmaw2  0.49930  0.5251  0.5395  0.5545  0.5854

Again, $\sigma_{b_2}$ is overestimated: its true value ($=1$) is less than the lower bound of the $95\%$-credible interval ($\approx 1.31$). The other estimates are quite good.

## Using the LKJ prior

The above problem is possibly due to the Wishart prior on the covariance matrix: a Wishart is a single prior that jointly controls the standard deviations and the correlation. Stan allows us to use an LKJ prior on the correlation matrix alone, combined with separate priors on the standard deviations. We will run it on the small dataset:
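The idea is to decompose the between-group covariance matrix into a scale part and a correlation part, and to put the LKJ prior on the latter:

$$\Sigma_b = \operatorname{diag}(\sigma_b)\,\Omega\,\operatorname{diag}(\sigma_b), \qquad \Omega \sim \operatorname{LKJ}(\eta), \qquad \sigma_{b_1}, \sigma_{b_2} \sim \text{half-Cauchy}(0, 5),$$

where $\eta=1$ gives a uniform prior over correlation matrices. In the Stan code below, $\Omega$ is handled through its Cholesky factor $L$.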

dat <- simdata(I=3, J=4)
dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
stancode <- 'data {
  int<lower=1> N; // number of observations
  real y[N]; // observations
  int<lower=1> ngroups; // number of groups
  int<lower=1> group[N]; // group indices
  int<lower=1> timepoint[N]; // timepoint indices
}
parameters {
  vector[2] Mu;
  vector[2] mu[ngroups]; // group means
  cholesky_factor_corr[2] L; // Cholesky factor of the correlation matrix
  vector<lower=0>[2] sigma_b;
  vector<lower=0>[2] sigma_w;
}
model {
  sigma_w ~ cauchy(0, 5);
  for(k in 1:N){
    y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
  }
  sigma_b ~ cauchy(0, 5);
  L ~ lkj_corr_cholesky(1); // uniform prior over correlation matrices
  Mu ~ normal(0, 25);
  for(j in 1:ngroups){
    mu[j] ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
  }
}
generated quantities {
  matrix[2,2] Omega; // correlation matrix
  matrix[2,2] Sigma; // covariance matrix
  real rho_b;
  Omega <- multiply_lower_tri_self_transpose(L);
  Sigma <- quad_form_diag(Omega, sigma_b);
  rho_b <- Sigma[1,2]/(sigma_b[1]*sigma_b[2]);
}'
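Note that the model never builds the full correlation matrix during sampling: lkj_corr_cholesky and multi_normal_cholesky work directly with the Cholesky factor L, which is numerically stabler and cheaper than reconstructing and factorizing the matrix at every iteration. The correlation matrix Omega and the covariance matrix Sigma are only recovered in the generated quantities block.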

### compile Stan model
stanmodel <- stan_model(model_code = stancode, model_name="stanmodel")

### Stan data
standata <- list(y = dat$y, N = nrow(dat), ngroups = nlevels(dat$Group),
                 timepoint = dat$timepoint, group = dat$group)

### Stan initial values
estimates <- function(dat, perturb=FALSE){
  if(perturb) dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
  mu <- matrix(aggregate(y~timepoint:group, data=dat, FUN=mean)$y, ncol=2, byrow=TRUE)
  Mu <- colMeans(mu)
  sigma_b <- sqrt(diag(var(mu)))
  L <- t(chol(cor(mu)))
  sigmaw1 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t1"), FUN=sd)$y)
  sigmaw2 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t2"), FUN=sd)$y)
  return(list(mu=mu, Mu=Mu, L=L, sigma_b=sigma_b, sigma_w = c(sigmaw1, sigmaw2)))
}
inits <- function(chain_id){
  values <- estimates(dat, perturb = chain_id > 1)
  return(values)
}

We are ready to run the Stan sampler. After getting some warning messages when I first ran it with the default values of the control argument, I increased adapt_delta and max_treedepth:

### run Stan
stansamples <- sampling(stanmodel, data = standata, init = inits,
                        iter = 15000, warmup = 5000, chains = 4,
                        # both increased from their defaults (adapt_delta 0.8, max_treedepth 10);
                        # these particular values are illustrative
                        control = list(adapt_delta = 0.99, max_treedepth = 15))

### outputs
library(coda)
codasamples <- do.call(mcmc.list,
                       plyr::alply(rstan::extract(stansamples, permuted=FALSE,
                                                  pars = c("Mu", "sigma_b", "sigma_w", "rho_b")),
                                   2, mcmc))
summary(codasamples)
##
## Iterations = 1:10000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 10000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##               Mean     SD  Naive SE Time-series SE
## Mu[1]      13.8167 8.2269 0.0411345       0.069493
## Mu[2]       4.6488 0.8297 0.0041487       0.011824
## sigma_b[1] 14.0441 8.7839 0.0439194       0.078022
## sigma_b[2]  0.8254 1.2307 0.0061533       0.016818
## sigma_w[1]  2.0004 0.5599 0.0027996       0.004327
## sigma_w[2]  0.6394 0.1753 0.0008763       0.001268
## rho_b       0.1663 0.5171 0.0025855       0.003793
##
## 2. Quantiles for each variable:
##
##               2.5%     25%     50%     75%   97.5%
## Mu[1]      -4.6870  9.7532 14.3161 18.4474 29.0210
## Mu[2]       3.1807  4.4245  4.6577  4.8903  6.1337
## sigma_b[1]  5.6294  8.7652 11.6931 16.3625 36.9594
## sigma_b[2]  0.0319  0.2221  0.4600  0.9382  3.9622
## sigma_w[1]  1.2280  1.6105  1.8975  2.2683  3.3762
## sigma_w[2]  0.3984  0.5181  0.6067  0.7225  1.0733
## rho_b      -0.8399 -0.2404  0.2152  0.6100  0.9502

Compared to the JAGS estimates (given at the end of the previous article), the estimates of $\sigma_{b_2}$ and $\rho_b$ obtained with Stan are clearly better. Note also that JAGS returned a huge credible interval for $\mu_2$.
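As a quick check against the truth, the true values are well inside the Stan credible intervals; here are the posterior medians from the table above next to the simulation values:

## true simulation values vs the Stan posterior medians shown above
rbind(true = c(sigma_b2 = 1, rho_b = 0.2),
      Stan.median = c(0.46, 0.22))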