The LKJ prior vs the Wishart prior

    2016-03-11

    (latest update: 2016-03-11 16:58:18)

    As I noted at the end of the previous article, JAGS returned an overestimate of the between-group standard deviation \(\sigma_{b_2}\). This is how I simulated the data, with I=3 (number of groups) and J=4 (number of replicates per group):

    simdata <- function(I, J){
      set.seed(444) 
      ### simulation of overall means ###
      Mu.t1 <- 20
      Mu.t2 <- 5
      Mu <- c(Mu.t1, Mu.t2)
      names(Mu) <- c("t1", "t2")
      sigmab.t1 <-  8
      sigmab.t2 <- 1
      rho <- 0.2
      Sigma <- rbind(
        c(sigmab.t1^2, rho*sigmab.t1*sigmab.t2),
        c(rho*sigmab.t1*sigmab.t2, sigmab.t2^2)
      )
      mu <- mvtnorm::rmvnorm(I, Mu, Sigma)
      ### simulation within groups ###
      sigmaw.t1 <- 2
      sigmaw.t2 <- 0.5
      y.t1 <- c(sapply(mu[,"t1"], function(m) rnorm(J, m, sigmaw.t1)))
      y.t2 <- c(sapply(mu[,"t2"], function(m) rnorm(J, m, sigmaw.t2)))
      ### construct the dataset ###
      Timepoint <- rep(c("t1", "t2"), each=I*J)
      Group <- paste0("grp", rep(gl(I,J), times=2))
      Repeat <- rep(1:J, times=2*I) 
      dat <- data.frame(   # explicit factors so that relevel() and nlevels() work on any R version
        Timepoint=factor(Timepoint),
        Group=factor(Group),
        Repeat=Repeat, 
        y=c(y.t1,y.t2)
      )
      dat$Timepoint <- relevel(dat$Timepoint, "t1") 
      return(dat)
    }
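
    In other words, writing \(y_{ijt}\) for replicate \(j\) of group \(i\) at timepoint \(t\), the function draws from the hierarchical model
    \[
    (\mu_{i1}, \mu_{i2}) \sim \mathcal{N}\bigl((\mu_1, \mu_2), \Sigma_b\bigr), \qquad
    \Sigma_b = \begin{pmatrix} \sigma_{b_1}^2 & \rho_b\,\sigma_{b_1}\sigma_{b_2} \\ \rho_b\,\sigma_{b_1}\sigma_{b_2} & \sigma_{b_2}^2 \end{pmatrix},
    \]
    \[
    y_{ijt} \mid \mu_{it} \sim \mathcal{N}(\mu_{it}, \sigma_{w_t}^2), \qquad i=1,\dots,I, \quad j=1,\dots,J, \quad t \in \{1,2\},
    \]
    with true values \(\mu_1=20\), \(\mu_2=5\), \(\sigma_{b_1}=8\), \(\sigma_{b_2}=1\), \(\rho_b=0.2\), \(\sigma_{w_1}=2\) and \(\sigma_{w_2}=0.5\).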

    Let us try JAGS on the data simulated with I=100:

    dat <- simdata(I=100, J=4)

    First note that the lme estimates are quite good:

    library(nlme)
    lme(y ~ Timepoint, data=dat, random= list(Group = pdSymm(~ 0+Timepoint )), 
        weights = varIdent(form = ~ Group:Timepoint | Timepoint) )
    ## Linear mixed-effects model fit by REML
    ##   Data: dat 
    ##   Log-restricted-likelihood: -1491.099
    ##   Fixed: y ~ Timepoint 
    ## (Intercept) Timepointt2 
    ##    18.74254   -13.70368 
    ## 
    ## Random effects:
    ##  Formula: ~0 + Timepoint | Group
    ##  Structure: General positive-definite
    ##             StdDev   Corr  
    ## Timepointt1 7.168704 Tmpnt1
    ## Timepointt2 1.087907 0.314 
    ## Residual    1.926711       
    ## 
    ## Variance function:
    ##  Structure: Different standard deviations per stratum
    ##  Formula: ~Group:Timepoint | Timepoint 
    ##  Parameter estimates:
    ##        t1        t2 
    ## 1.0000000 0.2813927 
    ## Number of Observations: 800
    ## Number of Groups: 100
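
    For later comparison with the Bayesian credible intervals, it can be useful to store this fit and ask nlme for approximate confidence intervals. A minimal sketch (the object name fitlme is mine, and intervals() can occasionally fail for complicated variance structures):

    fitlme <- lme(y ~ Timepoint, data=dat, random= list(Group = pdSymm(~ 0+Timepoint )), 
                  weights = varIdent(form = ~ Group:Timepoint | Timepoint))
    intervals(fitlme)   # approximate 95% confidence intervals (output not shown)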

    Now let us run JAGS (see the previous article for the code not shown here):

    library(rjags)
    jagsmodel <- jags.model(jagsfile,
                       data = jagsdata, 
                       inits = inits, 
                       n.chains = length(inits))
    update(jagsmodel, 5000) # warm-up
    jagssamples <- coda.samples(jagsmodel, 
                c("Mu", "sigmaw1", "sigmaw2", "sigmab1", "sigmab2", "rhob"), 
                n.iter= 10000)
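
    The JAGS model itself is in the previous article; it amounts to a bivariate normal model for the group means with a Wishart prior on their precision matrix. A rough sketch of such a model is given below; the scale matrix R, the degrees of freedom, and the priors on the other parameters are assumptions of mine, not necessarily those of the original file:

    model {
      # between-group level: Wishart prior on the precision matrix of the group means
      Omegab[1:2,1:2] ~ dwish(R[1:2,1:2], 2)  # R: 2x2 scale matrix assumed to be supplied in jagsdata
      Sigmab[1:2,1:2] <- inverse(Omegab[,])
      for (i in 1:ngroups) {
        mu[i,1:2] ~ dmnorm(Mu[], Omegab[,])
      }
      # within-group level
      for (k in 1:N) {
        y[k] ~ dnorm(mu[group[k], timepoint[k]], tauw[timepoint[k]])
      }
      # vague priors on the remaining parameters (assumed)
      for (t in 1:2) {
        Mu[t] ~ dnorm(0, 1.0E-4)
        tauw[t] ~ dgamma(0.001, 0.001)
      }
      # monitored quantities, matching the names passed to coda.samples()
      sigmab1 <- sqrt(Sigmab[1,1])
      sigmab2 <- sqrt(Sigmab[2,2])
      rhob <- Sigmab[1,2]/(sigmab1*sigmab2)
      sigmaw1 <- 1/sqrt(tauw[1])
      sigmaw2 <- 1/sqrt(tauw[2])
    }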

    Below are the summary statistics of the posterior samples:

    summary(jagssamples)
    ## 
    ## Iterations = 5001:15000
    ## Thinning interval = 1 
    ## Number of chains = 3 
    ## Sample size per chain = 10000 
    ## 
    ## 1. Empirical mean and standard deviation for each variable,
    ##    plus standard error of the mean:
    ## 
    ##            Mean      SD  Naive SE Time-series SE
    ## Mu[1]   18.7307 0.73406 0.0042381      0.0045800
    ## Mu[2]    5.0384 0.15318 0.0008844      0.0009678
    ## rhob     0.2220 0.09633 0.0005562      0.0005876
    ## sigmab1  7.2574 0.52933 0.0030561      0.0031555
    ## sigmab2  1.5072 0.11030 0.0006368      0.0006672
    ## sigmaw1  1.9248 0.07806 0.0004507      0.0005689
    ## sigmaw2  0.5403 0.02192 0.0001266      0.0001617
    ## 
    ## 2. Quantiles for each variable:
    ## 
    ##             2.5%     25%     50%     75%   97.5%
    ## Mu[1]   17.27665 18.2403 18.7312 19.2174 20.1729
    ## Mu[2]    4.73784  4.9363  5.0383  5.1424  5.3386
    ## rhob     0.02904  0.1575  0.2243  0.2890  0.4042
    ## sigmab1  6.29967  6.8906  7.2245  7.5924  8.3763
    ## sigmab2  1.30734  1.4308  1.5012  1.5765  1.7407
    ## sigmaw1  1.78034  1.8711  1.9220  1.9754  2.0862
    ## sigmaw2  0.49930  0.5251  0.5395  0.5545  0.5854

    Again, \(\sigma_{b_2}\) is overestimated: its true value (\(=1\)) is less than the lower bound of the \(95\%\)-credible interval (\(\approx 1.31\)). The other estimates are quite good.
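
    This lower bound can be read directly off the coda summary object, for instance:

    summary(jagssamples)$quantiles["sigmab2", "2.5%"]  # about 1.31, above the true value 1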

    Using the LKJ prior

    The above problem is possibly due to the Wishart prior on the covariance matrix. Stan allows one to put an LKJ prior on the correlation matrix instead, with separate priors on the standard deviations. We will run it on the small dataset:

    dat <- simdata(I=3, J=4)
    dat <- transform(dat, timepoint=as.integer(Timepoint), group=as.integer(Group))
    library(rstan)
    rstan_options(auto_write = TRUE)
    options(mc.cores = parallel::detectCores())
    stancode <- 'data {
      int<lower=1> N; // number of observations
      real y[N]; // observations 
      int<lower=1> ngroups; // number of groups
      int<lower=1> group[N]; // group indices
      int<lower=1> timepoint[N]; // timepoint indices
    }
    parameters {
      vector[2] Mu;
      vector[2] mu[ngroups]; // group means
      cholesky_factor_corr[2] L; 
      vector<lower=0>[2] sigma_b; 
      vector<lower=0>[2] sigma_w; 
    }
    model {
      sigma_w ~ cauchy(0, 5);
      for(k in 1:N){
        y[k] ~ normal(mu[group[k], timepoint[k]], sigma_w[timepoint[k]]);
      }
      sigma_b ~ cauchy(0, 5);
      L ~ lkj_corr_cholesky(1);
      Mu ~ normal(0, 25);
      for(j in 1:ngroups){
        mu[j] ~ multi_normal_cholesky(Mu, diag_pre_multiply(sigma_b, L));
      }
    }
    generated quantities {
      matrix[2,2] Omega;
      matrix[2,2] Sigma;
      real rho_b;
      Omega <- multiply_lower_tri_self_transpose(L);
      Sigma <- quad_form_diag(Omega, sigma_b); 
      rho_b <- Sigma[1,2]/(sigma_b[1]*sigma_b[2]);
    }'
    
    ### compile Stan model
    stanmodel <- stan_model(model_code = stancode, model_name="stanmodel")
    
    ### Stan data
    standata <- list(y=dat$y, N=nrow(dat), ngroups=nlevels(dat$Group),  
                     timepoint=dat$timepoint, group=dat$group)
    
    ### Stan initial values
    estimates <- function(dat, perturb=FALSE){
      if(perturb) dat$y <- dat$y + rnorm(length(dat$y), 0, 1)
      mu <-  matrix(aggregate(y~timepoint:group, data=dat, FUN=mean)$y, ncol=2, byrow=TRUE)
      Mu <- colMeans(mu)
      sigma_b <- sqrt(diag(var(mu)))
      L <- t(chol(cor(mu)))
      sigmaw1 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t1"), FUN=sd)$y)
      sigmaw2 <- mean(aggregate(y~Group, data=subset(dat, Timepoint=="t2"), FUN=sd)$y)
      return(list(mu=mu, Mu=Mu, L=L, sigma_b=sigma_b, sigma_w = c(sigmaw1, sigmaw2)))
    }
    inits <- function(chain_id){
      values <- estimates(dat, perturb = chain_id > 1)
      return(values)
    }
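
    In the Stan model above, L is the Cholesky factor of the between-group correlation matrix and diag_pre_multiply(sigma_b, L) is the Cholesky factor of the covariance matrix, so the generated quantities block simply recovers
    \[
    \Omega = L L^\top, \qquad
    \Sigma_b = \operatorname{diag}(\sigma_b)\,\Omega\,\operatorname{diag}(\sigma_b), \qquad
    \rho_b = \frac{(\Sigma_b)_{12}}{\sigma_{b_1}\sigma_{b_2}} = \Omega_{12}.
    \]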

    We are ready to run the Stan sampler. Following some warning messages when I first ran it with the default values of the control argument, I increased adapt_delta and max_treedepth:

    ### run Stan
    stansamples <- sampling(stanmodel, data = standata, init=inits, 
                        iter = 15000, warmup = 5000, chains = 4,
                        control=list(adapt_delta=0.999, max_treedepth=15))
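    
    ### a quick check one can add here: count the post-warmup divergent transitions per chain
    ### (with adapt_delta = 0.999 these should all be zero, or very nearly so)
    sapply(get_sampler_params(stansamples, inc_warmup = FALSE),
           function(x) sum(x[, "divergent__"]))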
    
    ### outputs
    library(coda)
    codasamples <- do.call(mcmc.list, 
      plyr::alply(rstan::extract(stansamples, permuted=FALSE, 
                    pars = c("Mu", "sigma_b", "sigma_w", "rho_b")), 
                  2, mcmc))
    summary(codasamples)
    ## 
    ## Iterations = 1:10000
    ## Thinning interval = 1 
    ## Number of chains = 4 
    ## Sample size per chain = 10000 
    ## 
    ## 1. Empirical mean and standard deviation for each variable,
    ##    plus standard error of the mean:
    ## 
    ##               Mean     SD  Naive SE Time-series SE
    ## Mu[1]      13.8167 8.2269 0.0411345       0.069493
    ## Mu[2]       4.6488 0.8297 0.0041487       0.011824
    ## sigma_b[1] 14.0441 8.7839 0.0439194       0.078022
    ## sigma_b[2]  0.8254 1.2307 0.0061533       0.016818
    ## sigma_w[1]  2.0004 0.5599 0.0027996       0.004327
    ## sigma_w[2]  0.6394 0.1753 0.0008763       0.001268
    ## rho_b       0.1663 0.5171 0.0025855       0.003793
    ## 
    ## 2. Quantiles for each variable:
    ## 
    ##               2.5%     25%     50%     75%   97.5%
    ## Mu[1]      -4.6870  9.7532 14.3161 18.4474 29.0210
    ## Mu[2]       3.1807  4.4245  4.6577  4.8903  6.1337
    ## sigma_b[1]  5.6294  8.7652 11.6931 16.3625 36.9594
    ## sigma_b[2]  0.0319  0.2221  0.4600  0.9382  3.9622
    ## sigma_w[1]  1.2280  1.6105  1.8975  2.2683  3.3762
    ## sigma_w[2]  0.3984  0.5181  0.6067  0.7225  1.0733
    ## rho_b      -0.8399 -0.2404  0.2152  0.6100  0.9502
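
    Since the samples are stored as a coda mcmc.list, convergence of the four chains can also be checked with the Gelman-Rubin diagnostic, for instance:

    gelman.diag(codasamples)  # potential scale reduction factors; values close to 1 indicate convergence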

    Compared to the JAGS estimates (given at the end), the estimates of \(\sigma_{b_2}\) and \(\rho_b\) obtained with Stan are much better. Note also that JAGS returned a huge credible interval for \(\mu_2\).