### Using an LKJ prior in Stan

2016-03-07

There are two ways to use an LKJ prior distribution for a correlation matrix in Stan. The first assigns the distribution to the correlation matrix itself, whereas the second assigns it to the lower Cholesky factor of the correlation matrix. I am going to show an example for a trivariate normal sample with a fixed mean:

$$y_i \sim_{\text{iid}} {\cal N}\left( \begin{pmatrix} 0 \\ 0 \\ 0 \end{pmatrix}, \Sigma\right).$$

Recall the relation between the covariance matrix $\Sigma$ and the correlation matrix $\Omega$, in which the diagonal matrices hold the standard deviations:

$$\begin{pmatrix} \sigma_{1}^2 & \sigma_{12} & \sigma_{13} \\ \sigma_{12} & \sigma_2^2 & \sigma_{23} \\ \sigma_{13} & \sigma_{23} & \sigma_3^2 \end{pmatrix} = \begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix} \Omega \begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix}$$

This operation is performed in Stan by the function quad_form_diag.

I do not assume that the reader is familiar with Stan or the rstan package, so I will comment on each step.

First I load the rstan package with the usual options:

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

I will run Stan on these simulated data:

set.seed(666)
Omega <- rbind(
c(1, 0.3, 0.2),
c(0.3, 1, 0.1),
c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)
N <- 100
y <- mvtnorm::rmvnorm(N, c(0,0,0), Sigma)
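As a quick sanity check on the relation between $\Sigma$, $\Omega$ and the standard deviations, the decomposition can be verified in both directions with base R. This is only a sketch that rebuilds the true parameters of the simulation; it does not touch the simulated data.

```r
# True parameters, as in the simulation
Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)

# Elementwise, Sigma[i,j] = sigma[i] * sigma[j] * Omega[i,j]
check_elementwise <- all.equal(Sigma, outer(sigma, sigma) * Omega)

# Going back: standard deviations and correlation matrix recovered from Sigma
check_sd  <- all.equal(sqrt(diag(Sigma)), sigma)
check_cor <- all.equal(cov2cor(Sigma), Omega)

c(check_elementwise, check_sd, check_cor)
```

All three checks should return TRUE, which confirms that the diagonal matrices in the decomposition carry standard deviations rather than variances.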

## Prior on the correlation matrix

Below is the Stan code for the Bayesian model that assigns an LKJ prior to the correlation matrix $\Omega$. I use the LKJ distribution with shape parameter $1$, which is the uniform distribution over the space of correlation matrices.
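To see why shape parameter $1$ gives a uniform prior, recall that the LKJ density with shape $\eta$ is proportional to $\det(\Omega)^{\eta - 1}$. The sketch below evaluates the logarithm of this unnormalized kernel in R; the helper name `lkj_log_kernel` is mine and is not part of rstan.

```r
# Unnormalized LKJ log-density: (eta - 1) * log(det(Omega))
# (hypothetical helper, for illustration only)
lkj_log_kernel <- function(Omega, eta) {
  logdet <- as.numeric(determinant(Omega, logarithm = TRUE)$modulus)
  (eta - 1) * logdet
}

Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)

# With eta = 1 the kernel is constant: every correlation matrix gets the same value
k1_omega    <- lkj_log_kernel(Omega, eta = 1)
k1_identity <- lkj_log_kernel(diag(3), eta = 1)

# With eta > 1 the identity matrix is favoured over correlated matrices
k2_omega    <- lkj_log_kernel(Omega, eta = 2)
k2_identity <- lkj_log_kernel(diag(3), eta = 2)
```

With `eta = 1` both values are zero, whereas with `eta = 2` the value at `Omega` is negative and the value at the identity is zero, so larger shape parameters pull the prior towards weaker correlations.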

stancode1 <- 'data {
int<lower=1> N; // number of observations
int<lower=1> J; // dimension of observations
vector[J] y[N]; // observations
vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
corr_matrix[J] Omega;
vector<lower=0>[J] sigma;
}
transformed parameters {
cov_matrix[J] Sigma;
Sigma = quad_form_diag(Omega, sigma); // Sigma = diag(sigma) * Omega * diag(sigma)
}
model {
y ~ multi_normal(Zero,Sigma); // sampling distribution of the observations
sigma ~ cauchy(0, 5); // prior on the standard deviations
Omega ~ lkj_corr(1); // LKJ prior on the correlation matrix
}'

The stan_model function of the rstan package runs the Stan compilation of the model.

stanmodel1 <- stan_model(model_code = stancode1, model_name="stanmodel1")

Note that this function only takes the Stan code as input. Once the model is compiled, the stanmodel object can be used to run the model on different datasets with the sampling function.

The data must be passed to the sampling function as a list:

standata <- list(J = ncol(y), N=N, y = y, Zero=rep(0, ncol(y)))

The algorithms used by Stan to generate the posterior distributions require initial values of the parameters. One can let the sampling function generate random initial values, or pass them in its init argument. I prefer to give my initial values. More precisely, one must pass to the init argument a function that returns the initial values in a list (see ?sampling).

Here is how I generate initial values. First, I write a function that returns some estimates of the parameters, taking the observations as input, with an option to randomly perturb these observations:

estimates <- function(y, perturb=FALSE){
if(perturb) y <- y + rnorm(length(y), 0, 1)
sigma <- sqrt(diag(var(y)))
Omega <- cor(y)
return(list(sigma=sigma, Omega=Omega))
}

I run Stan with several chains, for instance four chains. The function passed to the init argument of the sampling function takes an optional argument chain_id. For the first chain, I use the estimates calculated from the original data as initial values, and for the other chains I use the estimates calculated from the perturbed data:

inits <- function(chain_id){
values <- estimates(standata$y, perturb = chain_id > 1)
return(values)
}
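Before running Stan, it can be worth checking that every chain receives a valid starting point. The following self-contained sketch mimics the functions above on toy data; `toy_y`, `toy_estimates` and `toy_inits` are hypothetical stand-ins, not objects used elsewhere in this post.

```r
set.seed(1)
toy_y <- matrix(rnorm(300), ncol = 3)  # hypothetical stand-in for the observations

# Same logic as the estimates function: moment estimates, optionally on perturbed data
toy_estimates <- function(y, perturb = FALSE) {
  if (perturb) y <- y + rnorm(length(y), 0, 1)
  list(sigma = sqrt(diag(var(y))), Omega = cor(y))
}
toy_inits <- function(chain_id) toy_estimates(toy_y, perturb = chain_id > 1)

# One set of initial values per chain
init_list <- lapply(1:4, toy_inits)

# Each starting point must satisfy the parameter constraints:
# positive standard deviations and a positive definite correlation matrix
valid <- vapply(init_list, function(v) {
  all(v$sigma > 0) && all(eigen(v$Omega, symmetric = TRUE)$values > 0)
}, logical(1))
valid
```

According to ?sampling, the init argument also accepts a list with one element per chain, so a list built this way could be passed instead of the function.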

Now we are ready to run Stan:

samples1 <- sampling(stanmodel1, data = standata, init=inits,
iter = 10000, warmup = 1000, chains = 4)

Some numerical problems occur but they are benign. It is not abnormal to get some messages such as

validate transformed params: Sigma is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = -nan, but Sigma[2,1] = -nan
Exception thrown at line 23: lkj_corr_log: y is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = inf, but Sigma[2,1] = inf

These problems will not occur with the LKJ prior on the Cholesky factor.

I like to use the coda package for output analysis. This is how I store the samples in a coda object:

library(coda)
codasamples1 <- do.call(mcmc.list,
plyr::alply(rstan::extract(samples1,
pars=c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
permuted=FALSE), 2, mcmc))

## Prior on the Cholesky factor

The correlation matrix $\Omega$ has a Cholesky factorization $\Omega = LL'$ where $L$ is a lower triangular matrix. Instead of assigning a prior distribution to $\Omega$, one can assign a prior distribution to $L$. This approach avoids the numerical problems encountered with the previous one, and it is also faster.
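The factorization is easy to illustrate in R with the true parameters of the simulation. Note that chol() returns the upper triangular factor, hence the transpose. The sketch also shows that scaling the rows of $L$ by the standard deviations yields the Cholesky factor of $\Sigma$, which is what diag_pre_multiply(sigma, Lcorr) computes in the Stan code.

```r
Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)

# Lower Cholesky factor of the correlation matrix
L <- t(chol(Omega))
check_factor <- all.equal(L %*% t(L), Omega)

# Each row of L has unit Euclidean norm, since diag(Omega) = 1
check_rows <- all.equal(rowSums(L^2), c(1, 1, 1))

# diag(sigma) %*% L is the lower Cholesky factor of Sigma
L_Sigma <- diag(sigma) %*% L
check_scaled <- all.equal(L_Sigma, t(chol(Sigma)))

c(check_factor, check_rows, check_scaled)
```

Working with the factor directly means the multivariate normal density never needs a matrix decomposition inside the sampler, which is one plausible reason for the gain in speed.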

stancode2 <- 'data {
int<lower=1> N; // number of observations
int<lower=1> J; // dimension of observations
vector[J] y[N]; // observations
vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
cholesky_factor_corr[J] Lcorr;
vector<lower=0>[J] sigma;
}
model {
y ~ multi_normal_cholesky(Zero, diag_pre_multiply(sigma, Lcorr));
sigma ~ cauchy(0, 5);
Lcorr ~ lkj_corr_cholesky(1);
}
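generated quantities {
matrix[J,J] Omega;
matrix[J,J] Sigma;
Omega = multiply_lower_tri_self_transpose(Lcorr); // Omega = Lcorr * Lcorr'
Sigma = quad_form_diag(Omega, sigma); // Sigma = diag(sigma) * Omega * diag(sigma)
}'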

Note the generated quantities block, as compared to the transformed parameters block in the first code. The objects declared in the transformed parameters block are intended to be used in the likelihood of the data, whereas the objects declared in the generated quantities block are not: they are only computed for the output.

Now we only have to change the estimates function:

estimates <- function(y, perturb=FALSE){
if(perturb) y <- y + rnorm(length(y), 0, 1)
Lcorr <- t(chol(cor(y)))
sigma <- sqrt(diag(var(y)))
return(list(Lcorr=Lcorr, sigma=sigma))
}

Then compile and run as before:

stanmodel2 <- stan_model(model_code = stancode2, model_name="stanmodel2")
samples2 <- sampling(stanmodel2, data = standata, init=inits,
iter = 10000, warmup = 1000, chains = 4)
library(coda)
codasamples2 <- do.call(mcmc.list,
plyr::alply(rstan::extract(samples2,
pars=c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
permuted=FALSE), 2, mcmc))

## Comparison of results

Results are almost identical:

summary(codasamples1)
##
## Iterations = 1:9000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 9000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9132 0.06627 0.0003493      0.0004186
## sigma[2]   2.2080 0.16003 0.0008434      0.0009858
## sigma[3]   3.5181 0.25318 0.0013344      0.0016278
## Omega[1,2] 0.2118 0.09544 0.0005030      0.0005972
## Omega[1,3] 0.2380 0.09379 0.0004943      0.0005830
## Omega[2,3] 0.1477 0.09771 0.0005150      0.0006090
##
## 2. Quantiles for each variable:
##
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79449 0.86688 0.9094 0.9549 1.0545
## sigma[2]    1.92073 2.09639 2.1982 2.3093 2.5496
## sigma[3]    3.06211 3.34134 3.5048 3.6785 4.0492
## Omega[1,2]  0.01965 0.14766 0.2138 0.2779 0.3921
## Omega[1,3]  0.04763 0.17587 0.2404 0.3032 0.4136
## Omega[2,3] -0.04694 0.08168 0.1484 0.2149 0.3347
summary(codasamples2)
##
## Iterations = 1:9000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 9000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9129 0.06625 0.0003492      0.0004088
## sigma[2]   2.2075 0.16108 0.0008489      0.0009963
## sigma[3]   3.5213 0.25203 0.0013283      0.0015606
## Omega[1,2] 0.2118 0.09509 0.0005012      0.0005955
## Omega[1,3] 0.2386 0.09420 0.0004965      0.0005896
## Omega[2,3] 0.1476 0.09670 0.0005096      0.0006101
##
## 2. Quantiles for each variable:
##
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79501 0.86661 0.9087 0.9550 1.0549
## sigma[2]    1.91973 2.09503 2.1978 2.3093 2.5487
## sigma[3]    3.07313 3.34663 3.5053 3.6820 4.0570
## Omega[1,2]  0.02205 0.14742 0.2140 0.2783 0.3910
## Omega[1,3]  0.04772 0.17610 0.2412 0.3040 0.4161
## Omega[2,3] -0.04386 0.08225 0.1490 0.2142 0.3310
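To put a number on "almost identical", one can compare the posterior means of the two runs relative to their Monte Carlo error. The sketch below copies the means and time-series standard errors from the two summaries above; treating the two runs as independent is an assumption of mine.

```r
pars <- c("sigma[1]", "sigma[2]", "sigma[3]",
          "Omega[1,2]", "Omega[1,3]", "Omega[2,3]")

# Posterior means and time-series SEs, copied from summary(codasamples1)
mean1 <- c(0.9132, 2.2080, 3.5181, 0.2118, 0.2380, 0.1477)
tsse1 <- c(0.0004186, 0.0009858, 0.0016278, 0.0005972, 0.0005830, 0.0006090)

# ... and from summary(codasamples2)
mean2 <- c(0.9129, 2.2075, 3.5213, 0.2118, 0.2386, 0.1476)
tsse2 <- c(0.0004088, 0.0009963, 0.0015606, 0.0005955, 0.0005896, 0.0006101)

# Difference of the means, scaled by the combined Monte Carlo standard error
delta <- mean2 - mean1
z <- delta / sqrt(tsse1^2 + tsse2^2)

data.frame(parameter = pars, difference = delta, z = round(z, 2))
```

The largest difference is for sigma[3], and all scaled differences stay below 2 in absolute value, so the discrepancies between the two parameterizations are of the size expected from Monte Carlo error alone.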