  1. The LKJ prior vs the Wishart prior

    2016-03-11
    Source

    (latest update : 2016-03-11 16:58:18)

    As I noted at the end of the previous article, JAGS overestimated the between-group standard deviation $\sigma_{b_2}$ (at the second timepoint). This is how I simulated the data, with $I=3$ (number of groups) and $J=4$ (number of replicates per group):

    # Simulate a balanced dataset: I groups observed at two timepoints (t1, t2)
    # with J replicates per group and timepoint.
    #
    # The group means (t1, t2) are drawn from a bivariate normal distribution
    # with mean `Mu`, standard deviations `sigmab` and correlation `rho`; the
    # observations are then normal around their group mean with
    # timepoint-specific standard deviations `sigmaw`.
    #
    # Arguments:
    #   I       number of groups
    #   J       number of replicates per group and timepoint
    #   Mu      overall means at t1 and t2 (default 20 and 5)
    #   sigmab  between-group standard deviations at t1 and t2 (default 8, 1)
    #   rho     between-group correlation (default 0.2)
    #   sigmaw  within-group standard deviations at t1 and t2 (default 2, 0.5)
    #   seed    passed to set.seed() first; NULL keeps the current RNG state
    #
    # Returns a data.frame with 2*I*J rows and columns Timepoint (factor with
    # levels "t1", "t2"), Group (factor), Repeat (integer) and y (numeric).
    # Both grouping columns are built as factors explicitly: since R 4.0.0
    # data.frame() no longer converts strings to factors, and relevel() fails
    # on a character vector.
    simdata <- function(I, J, Mu = c(20, 5), sigmab = c(8, 1), rho = 0.2,
                        sigmaw = c(2, 0.5), seed = 444) {
      stopifnot(
        I >= 1, J >= 1, abs(rho) <= 1,
        length(Mu) == 2, length(sigmab) == 2, length(sigmaw) == 2
      )
      if (!is.null(seed)) {
        set.seed(seed)
      }
      ### simulation of the group means ###
      # the names "t1", "t2" become the column names of the rmvnorm() output
      Mu <- setNames(as.numeric(Mu), c("t1", "t2"))
      sb <- as.numeric(sigmab)
      Sigma <- rbind(
        c(sb[1]^2, rho * sb[1] * sb[2]),
        c(rho * sb[1] * sb[2], sb[2]^2)
      )
      mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
      ### simulation within groups ###
      # J draws per group, stacked group after group
      draw_within <- function(means, sd) {
        c(vapply(means, function(m) rnorm(J, m, sd), numeric(J)))
      }
      y_t1 <- draw_within(mu[, "t1"], sigmaw[[1]])
      y_t2 <- draw_within(mu[, "t2"], sigmaw[[2]])
      ### construct the dataset ###
      data.frame(
        Timepoint = factor(
          rep(c("t1", "t2"), each = I * J),
          levels = c("t1", "t2")
        ),
        Group = factor(paste0("grp", rep(gl(I, J), times = 2))),
        Repeat = rep(seq_len(J), times = 2 * I),
        y = c(y_t1, y_t2)
      )
    }

    Let us try JAGS on data simulated with $I=100$ groups:

    # 100 groups, 4 replicates per group and timepoint
    dat <- simdata(J = 4, I = 100)

    First note that the lme estimates are quite good:

    library(nlme)
    # Fixed effects: one mean per Timepoint.
    # Random effects: a Group-level vector of Timepoint effects with a general
    # positive-definite (pdSymm) 2x2 covariance matrix.
    # Residual variance: varIdent() stratified by Timepoint, i.e. a different
    # residual standard deviation for each Timepoint.
    lme(
      fixed = y ~ Timepoint,
      data = dat,
      random = list(Group = pdSymm(~ 0 + Timepoint)),
      weights = varIdent(form = ~ Group:Timepoint | Timepoint)
    )
    1. ## Linear mixed-effects model fit by REML
    2. ## Data: dat
    3. ## Log-restricted-likelihood: -1491.099
    4. ## Fixed: y ~ Timepoint
    5. ## (Intercept) Timepointt2
    6. ## 18.74254 -13.70368
    7. ##
    8. ## Random effects:
    9. ## Formula: ~0 + Timepoint | Group
    10. ## Structure: General positive-definite
    11. ## StdDev Corr
    12. ## Timepointt1 7.168704 Tmpnt1
    13. ## Timepointt2 1.087907 0.314
    14. ## Residual 1.926711
    15. ##
    16. ## Variance function:
    17. ## Structure: Different standard deviations per stratum
    18. ## Formula: ~Group:Timepoint | Timepoint
    19. ## Parameter estimates:
    20. ## t1 t2
    21. ## 1.0000000 0.2813927
    22. ## Number of Observations: 800
    23. ## Number of Groups: 100

    Now let us run JAGS (see the previous article for the code not shown here):

    library(rjags)
    # compile the model, with one chain per element of `inits`
    jagsmodel <- jags.model(
      file = jagsfile,
      data = jagsdata,
      inits = inits,
      n.chains = length(inits)
    )
    # warm-up: advance the chains 5000 iterations without recording them
    update(jagsmodel, n.iter = 5000)
    # record 10000 further iterations of the monitored nodes
    jagssamples <- coda.samples(
      jagsmodel,
      variable.names = c(
        "Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"
      ),
      n.iter = 10000
    )

    Below are the summary statistics of the posterior samples:

    # means, SDs, standard errors and quantiles of the recorded draws
    summary(object = jagssamples)
    1. ##
    2. ## Iterations = 5001:15000
    3. ## Thinning interval = 1
    4. ## Number of chains = 3
    5. ## Sample size per chain = 10000
    6. ##
    7. ## 1. Empirical mean and standard deviation for each variable,
    8. ## plus standard error of the mean:
    9. ##
    10. ## Mean SD Naive SE Time-series SE
    11. ## Mu[1] 18.7307 0.73406 0.0042381 0.0045800
    12. ## Mu[2] 5.0384 0.15318 0.0008844 0.0009678
    13. ## rhob 0.2220 0.09633 0.0005562 0.0005876
    14. ## sigmab1 7.2574 0.52933 0.0030561 0.0031555
    15. ## sigmab2 1.5072 0.11030 0.0006368 0.0006672
    16. ## sigmaw1 1.9248 0.07806 0.0004507 0.0005689
    17. ## sigmaw2 0.5403 0.02192 0.0001266 0.0001617
    18. ##
    19. ## 2. Quantiles for each variable:
    20. ##
    21. ## 2.5% 25% 50% 75% 97.5%
    22. ## Mu[1] 17.27665 18.2403 18.7312 19.2174 20.1729
    23. ## Mu[2] 4.73784 4.9363 5.0383 5.1424 5.3386
    24. ## rhob 0.02904 0.1575 0.2243 0.2890 0.4042
    25. ## sigmab1 6.29967 6.8906 7.2245 7.5924 8.3763
    26. ## sigmab2 1.30734 1.4308 1.5012 1.5765 1.7407
    27. ## sigmaw1 1.78034 1.8711 1.9220 1.9754 2.0862
    28. ## sigmaw2 0.49930 0.5251 0.5395 0.5545 0.5854

    Again, $\sigma_{b_2}$ is overestimated: its true value ($=1$) lies below the lower bound of the 95% credible interval ($\approx 1.31$). The other estimates are quite good.

    Using the LKJ prior

    The above problem is possibly due to the Wishart prior on the covariance matrix. Stan allows using an LKJ prior on the correlation matrix instead. We will run it on the small dataset ($I=3$, $J=4$):

    dat <- simdata(I = 3, J = 4)
    # Stan needs integer indices: timepoint in {1, 2}, group in 1..ngroups.
    # Going through factor() makes this work whether the columns are stored as
    # factors or as character vectors (as.integer() on a character vector
    # would give NA).
    dat <- transform(
      dat,
      timepoint = as.integer(factor(Timepoint, levels = c("t1", "t2"))),
      group = as.integer(factor(Group))
    )
    library(rstan)
    rstan_options(auto_write = TRUE)
    options(mc.cores = parallel::detectCores())
    # Hierarchical bivariate model: each group has a 2-vector of timepoint
    # means drawn from MVN(Mu, diag(sigma_b) Omega diag(sigma_b)), with an
    # LKJ(1) prior on the correlation matrix Omega through its Cholesky
    # factor L. Written with the current Stan syntax (array[] declarations,
    # `=` assignment); the old `real y[N]` and `<-` forms are removed in
    # recent Stan releases.
    stancode <- '
    data {
      int<lower=1> N;                              // number of observations
      array[N] real y;                             // observations
      int<lower=1> ngroups;                        // number of groups
      array[N] int<lower=1, upper=ngroups> group;  // group indices
      array[N] int<lower=1, upper=2> timepoint;    // timepoint indices
    }
    parameters {
      vector[2] Mu;                                // overall means
      array[ngroups] vector[2] mu;                 // group means
      cholesky_factor_corr[2] L;                   // Cholesky factor of Omega
      vector<lower=0>[2] sigma_b;                  // between-group SDs
      vector<lower=0>[2] sigma_w;                  // within-group SDs
    }
    model {
      sigma_w ~ cauchy(0, 5);
      for (k in 1:N) {
        y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
      }
      sigma_b ~ cauchy(0, 5);
      L ~ lkj_corr_cholesky(1);
      Mu ~ normal(0, 25);
      // vectorized over groups: the Cholesky factor is built only once
      mu ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
    }
    generated quantities {
      matrix[2, 2] Omega = multiply_lower_tri_self_transpose(L);
      matrix[2, 2] Sigma = quad_form_diag(Omega, sigma_b);
      // between-group correlation = off-diagonal of the correlation matrix
      real rho_b = Omega[1, 2];
    }'
    ### compile Stan model
    stanmodel <- stan_model(model_code = stancode, model_name = "stanmodel")
    ### Stan data
    # ngroups is taken from the integer group codes, so every group[k] is a
    # valid index into mu whatever the storage type of the Group column
    # (nlevels() returns 0 for a character vector)
    standata <- list(
      y = dat$y,
      N = nrow(dat),
      ngroups = max(dat$group),
      timepoint = dat$timepoint,
      group = dat$group
    )
    ### Stan initial values
    # Crude moment estimates of the model parameters, used as initial values.
    # `dat` must contain y, the integer indices timepoint and group, and the
    # columns Group and Timepoint. When `perturb` is TRUE, N(0, 1) noise is
    # first added to y so that chains start from different points.
    # NOTE(review): the reshaping into `mu` assumes every group is observed at
    # both timepoints.
    estimates <- function(dat, perturb = FALSE) {
      if (perturb) {
        dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
      }
      # cell means ordered with timepoint varying fastest, reshaped into a
      # matrix with one row per group and one column per timepoint
      cell_means <- aggregate(y ~ timepoint:group, data = dat, FUN = mean)
      group_means <- matrix(cell_means$y, ncol = 2, byrow = TRUE)
      # average of the within-group SDs at a given timepoint
      mean_within_sd <- function(tp) {
        rows <- dat[which(dat$Timepoint == tp), , drop = FALSE]
        mean(aggregate(y ~ Group, data = rows, FUN = sd)$y)
      }
      list(
        mu = group_means,
        Mu = colMeans(group_means),
        L = t(chol(cor(group_means))),
        sigma_b = sqrt(diag(var(group_means))),
        sigma_w = c(mean_within_sd("t1"), mean_within_sd("t2"))
      )
    }
    # Initial-value function passed to sampling(init = ...): chain 1 starts at
    # estimates(dat), the other chains at estimates(dat, perturb = TRUE).
    inits <- function(chain_id) {
      perturb <- chain_id > 1
      estimates(dat, perturb = perturb)
    }

    We are ready to run the Stan sampler. Following some messages I got when I first ran it with the default values of the `control` argument, I increased `adapt_delta` and `max_treedepth`:

    ### run Stan
    # adapt_delta and max_treedepth are raised above their defaults because
    # the default settings produced sampler messages on this small dataset
    stansamples <- sampling(
      stanmodel, data = standata, init = inits,
      iter = 15000, warmup = 5000, chains = 4,
      control = list(adapt_delta = 0.999, max_treedepth = 15)
    )
    ### outputs
    library(coda)
    # iterations x chains x parameters array of the recorded draws
    draws <- rstan::extract(
      stansamples, permuted = FALSE,
      pars = c("Mu", "sigma_b", "sigma_w", "rho_b")
    )
    # one coda::mcmc object per chain (base R instead of the retired plyr);
    # matrix() keeps the iterations x parameters shape even when a single
    # parameter is extracted and the chain slice would drop to a vector
    codasamples <- mcmc.list(lapply(seq_len(dim(draws)[2L]), function(chain) {
      mcmc(matrix(draws[, chain, ], nrow = dim(draws)[1L],
                  dimnames = dimnames(draws)[c(1L, 3L)]))
    }))
    summary(codasamples)
    1. ##
    2. ## Iterations = 1:10000
    3. ## Thinning interval = 1
    4. ## Number of chains = 4
    5. ## Sample size per chain = 10000
    6. ##
    7. ## 1. Empirical mean and standard deviation for each variable,
    8. ## plus standard error of the mean:
    9. ##
    10. ## Mean SD Naive SE Time-series SE
    11. ## Mu[1] 13.8167 8.2269 0.0411345 0.069493
    12. ## Mu[2] 4.6488 0.8297 0.0041487 0.011824
    13. ## sigma_b[1] 14.0441 8.7839 0.0439194 0.078022
    14. ## sigma_b[2] 0.8254 1.2307 0.0061533 0.016818
    15. ## sigma_w[1] 2.0004 0.5599 0.0027996 0.004327
    16. ## sigma_w[2] 0.6394 0.1753 0.0008763 0.001268
    17. ## rho_b 0.1663 0.5171 0.0025855 0.003793
    18. ##
    19. ## 2. Quantiles for each variable:
    20. ##
    21. ## 2.5% 25% 50% 75% 97.5%
    22. ## Mu[1] -4.6870 9.7532 14.3161 18.4474 29.0210
    23. ## Mu[2] 3.1807 4.4245 4.6577 4.8903 6.1337
    24. ## sigma_b[1] 5.6294 8.7652 11.6931 16.3625 36.9594
    25. ## sigma_b[2] 0.0319 0.2221 0.4600 0.9382 3.9622
    26. ## sigma_w[1] 1.2280 1.6105 1.8975 2.2683 3.3762
    27. ## sigma_w[2] 0.3984 0.5181 0.6067 0.7225 1.0733
    28. ## rho_b -0.8399 -0.2404 0.2152 0.6100 0.9502

    Compared to the JAGS estimates (given at the end), the Stan estimates of $\sigma_{b_2}$ and $\rho_b$ are much better. Note also that JAGS returned a huge credible interval for $\mu_2$.