Using an LKJ prior in Stan
2016-03-07
There are two ways to use an LKJ prior distribution for a correlation matrix in Stan. The first one assigns the distribution on the correlation matrix, whereas the second one assigns the distribution on the lower Cholesky factor of the correlation matrix. I am going to show an example for a trivariate normal sample with a fixed mean:

\[ y_i \sim_{\text{iid}} {\cal N}\left( \begin{pmatrix} 0 \\ 0 \\ 0 \end{pmatrix}, \Sigma\right). \]

Recall the relation between the covariance matrix \(\Sigma\) and the correlation matrix \(\Omega\), where \(\sigma_1, \sigma_2, \sigma_3\) are the standard deviations:

\[ \begin{pmatrix} \sigma_{1}^2 & \sigma_{12} & \sigma_{13} \\ \sigma_{12} & \sigma_2^2 & \sigma_{23} \\ \sigma_{13} & \sigma_{23} & \sigma_3^2 \end{pmatrix} = \begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix} \Omega \begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix} \]

This operation is performed in Stan by the function
quad_form_diag.

I do not assume the reader is familiar with Stan or the rstan package, so I will comment on each step.

First I load the rstan package with the usual options:

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
I will run Stan on these simulated data:
set.seed(666)
Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)
N <- 100
y <- mvtnorm::rmvnorm(N, c(0,0,0), Sigma)
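As a quick sanity check on the relation Sigma = diag(sigma) Omega diag(sigma) used in the simulation code, here is a minimal base R sketch that builds Sigma in two equivalent ways and then recovers sigma and Omega from it. The names Sigma_elementwise, sigma_back and Omega_back are introduced here for illustration only.

```r
# Correlation matrix and standard deviations used in the simulation
Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)

# Covariance matrix as a product of matrices
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)

# Equivalent elementwise form: Sigma[i, j] = sigma[i] * sigma[j] * Omega[i, j]
Sigma_elementwise <- outer(sigma, sigma) * Omega

# Going back: standard deviations and correlation matrix from Sigma
sigma_back <- sqrt(diag(Sigma))
Omega_back <- cov2cor(Sigma)

all.equal(Sigma, Sigma_elementwise)
all.equal(sigma_back, sigma)
all.equal(Omega_back, Omega)
```

The diagonal matrices hold standard deviations, not variances, which is why the diagonal of Sigma ends up containing the squares of sigma.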
Prior on the correlation matrix
Below is the Stan code for the Bayesian model assigning an LKJ prior on the correlation matrix \(\Omega\). I use the LKJ distribution with shape parameter \(1\), which is the uniform distribution on the space of correlation matrices.
stancode1 <- 'data {
  int<lower=1> N; // number of observations
  int<lower=1> J; // dimension of observations
  vector[J] y[N]; // observations
  vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
  corr_matrix[J] Omega;
  vector<lower=0>[J] sigma;
}
transformed parameters {
  cov_matrix[J] Sigma;
  Sigma = quad_form_diag(Omega, sigma);
}
model {
  y ~ multi_normal(Zero, Sigma); // sampling distribution of the observations
  sigma ~ cauchy(0, 5); // prior on the standard deviations
  Omega ~ lkj_corr(1); // LKJ prior on the correlation matrix
}'
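For reference, the lkj_corr prior in the model block has a density that depends on \(\Omega\) only through its determinant. Writing \(\eta\) for the shape parameter (a symbol introduced here, not used elsewhere in this post):

\[ \mathrm{LKJ}(\Omega \mid \eta) \propto \det(\Omega)^{\eta - 1}, \qquad \eta > 0. \]

With \(\eta = 1\) the exponent vanishes and the density is constant, which is the uniform case used here. Values \(\eta > 1\) put more mass near the identity matrix, that is, on weaker correlations.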
The stan_model function of the rstan package runs the Stan compilation of the model.

stanmodel1 <- stan_model(model_code = stancode1, model_name="stanmodel1")
Note that this function only takes the Stan code as input. Once the model is compiled, the stanmodel object can be used to run the model on different datasets with the sampling function.

The data must be passed to the sampling function as a list:

standata <- list(J = ncol(y), N=N, y = y, Zero=rep(0, ncol(y)))
The algorithms used by Stan to generate the posterior distributions require initial values of the parameters. One can let the sampling function generate random initial values, or pass them in its init argument. I prefer to supply my own initial values. More precisely, one must pass to the init argument a function that returns the initial values in a list (see ?sampling).

Here is how I generate initial values. First, I write a function that returns some estimates of the parameters, taking the observations as input and optionally adding random noise to them:

estimates <- function(y, perturb=FALSE){
  if(perturb) y <- y + rnorm(length(y), 0, 1)
  sigma <- sqrt(diag(var(y)))
  Omega <- cor(y)
  return(list(sigma=sigma, Omega=Omega))
}
I run Stan with several chains, for instance four chains. The function passed to the init argument of the sampling function takes an optional argument chain_id. For the first chain, I use the estimates calculated from the original data as initial values, and for the other chains I use the estimates calculated from the perturbed data:

inits <- function(chain_id){
  values <- estimates(standata$y, perturb = chain_id > 1)
  return(values)
}
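Before handing such a function to sampling, it can be worth checking that every chain gets a valid starting point. The sketch below is self-contained and runs without Stan: it simulates the data with base R instead of mvtnorm, reads y directly rather than through standata, and uses a hypothetical helper valid_init that is not part of rstan.

```r
set.seed(666)

# Simulate data with base R only: if Z has iid N(0, 1) entries,
# then Z %*% chol(Sigma) has covariance matrix Sigma
Omega <- rbind(c(1, 0.3, 0.2), c(0.3, 1, 0.1), c(0.2, 0.1, 1))
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)
N <- 100
y <- matrix(rnorm(N * 3), N, 3) %*% chol(Sigma)

estimates <- function(y, perturb = FALSE) {
  if (perturb) y <- y + rnorm(length(y), 0, 1)
  list(sigma = sqrt(diag(var(y))), Omega = cor(y))
}
inits <- function(chain_id) estimates(y, perturb = chain_id > 1)

# Hypothetical helper: TRUE if sigma is positive and Omega is a
# symmetric, positive definite matrix with unit diagonal
valid_init <- function(init) {
  eig <- eigen(init$Omega, symmetric = TRUE, only.values = TRUE)$values
  all(init$sigma > 0) &&
    isSymmetric(init$Omega) &&
    isTRUE(all.equal(diag(init$Omega), rep(1, ncol(init$Omega)))) &&
    all(eig > 0)
}

# One list of initial values per chain
init_list <- lapply(1:4, inits)
sapply(init_list, valid_init)
```

Only the first chain starts from the unperturbed estimates, so the four starting points differ from each other.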
Now we are ready to run Stan:
samples1 <- sampling(stanmodel1, data = standata, init=inits, iter = 10000, warmup = 1000, chains = 4)
Some numerical problems occur but they are benign. It is not abnormal to get some messages such as
validate transformed params: Sigma is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = -nan, but Sigma[2,1] = -nan
Exception thrown at line 23: lkj_corr_log: y is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = inf, but Sigma[2,1] = inf
These problems will not occur with the LKJ prior on the Cholesky factor.
I like to use the coda package for output analysis. This is how I store the samples in a coda object:

library(coda)
codasamples1 <- do.call(
  mcmc.list,
  plyr::alply(
    rstan::extract(samples1,
                   pars=c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
                   permuted=FALSE),
    2, mcmc
  )
)
Prior on the Cholesky factor
The correlation matrix \(\Omega\) has a Cholesky factorization \(\Omega = LL'\) where \(L\) is a lower triangular matrix. Instead of assigning a prior distribution on \(\Omega\), one can assign a prior distribution on \(L\). This avoids the numerical problems encountered with the first approach, and it is also faster.
stancode2 <- 'data {
  int<lower=1> N; // number of observations
  int<lower=1> J; // dimension of observations
  vector[J] y[N]; // observations
  vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
  cholesky_factor_corr[J] Lcorr;
  vector<lower=0>[J] sigma;
}
model {
  y ~ multi_normal_cholesky(Zero, diag_pre_multiply(sigma, Lcorr));
  sigma ~ cauchy(0, 5);
  Lcorr ~ lkj_corr_cholesky(1);
}
generated quantities {
  matrix[J,J] Omega;
  matrix[J,J] Sigma;
  Omega = multiply_lower_tri_self_transpose(Lcorr);
  Sigma = quad_form_diag(Omega, sigma);
}'
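The model block relies on the fact that scaling the rows of the Cholesky factor of \(\Omega\) by the standard deviations gives the Cholesky factor of \(\Sigma\). Here is a small base R sketch of that identity, using the true Omega and sigma from the simulation. The name L_Sigma is introduced for illustration, and diag(sigma) %*% L is meant as the R counterpart of diag_pre_multiply(sigma, Lcorr).

```r
Omega <- rbind(c(1, 0.3, 0.2), c(0.3, 1, 0.1), c(0.2, 0.1, 1))
sigma <- c(1, 2, 3)

# chol() returns the upper factor, so transpose to get the lower factor L
L <- t(chol(Omega))

# L is lower triangular and L %*% t(L) gives back Omega
all(L[upper.tri(L)] == 0)
all.equal(L %*% t(L), Omega)

# Scaling row i of L by sigma[i] gives the lower Cholesky factor of Sigma
L_Sigma <- diag(sigma) %*% L
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)
all.equal(L_Sigma %*% t(L_Sigma), Sigma)
all.equal(L_Sigma, t(chol(Sigma)))
```

Because the sampler works with the factor directly, the covariance matrix never has to be factorized or checked for positive definiteness during sampling.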
Note the generated quantities block as compared to the transformed parameters block in the first code. The objects declared in the transformed parameters block are intended to be used in the likelihood of the data, whereas the objects declared in the generated quantities block are not.

Now we only have to change the estimates function:

estimates <- function(y, perturb=FALSE){
  if(perturb) y <- y + rnorm(length(y), 0, 1)
  Lcorr <- t(chol(cor(y)))
  sigma <- sqrt(diag(var(y)))
  return(list(Lcorr=Lcorr, sigma=sigma))
}
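An initial value for Lcorr has to satisfy the constraints of a Cholesky factor of a correlation matrix: lower triangular, positive diagonal, and rows of unit length (because the diagonal of \(LL'\) must be 1). The sketch below checks that t(chol(cor(y))) meets them, using an arbitrary simulated matrix as a stand-in for the data.

```r
set.seed(1)
y <- matrix(rnorm(300), 100, 3)  # stand-in data: 100 observations, 3 columns

Lcorr <- t(chol(cor(y)))

# Constraints of a Cholesky factor of a correlation matrix
lower_triangular  <- all(Lcorr[upper.tri(Lcorr)] == 0)
positive_diagonal <- all(diag(Lcorr) > 0)
unit_rows         <- isTRUE(all.equal(rowSums(Lcorr^2), rep(1, 3)))

c(lower_triangular, positive_diagonal, unit_rows)
```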
Then compile and run as before:
stanmodel2 <- stan_model(model_code = stancode2, model_name="stanmodel2")
samples2 <- sampling(stanmodel2, data = standata, init=inits, iter = 10000, warmup = 1000, chains = 4)

library(coda)
codasamples2 <- do.call(
  mcmc.list,
  plyr::alply(
    rstan::extract(samples2,
                   pars=c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
                   permuted=FALSE),
    2, mcmc
  )
)
Comparison of results
Results are almost identical:
summary(codasamples1)
##
## Iterations = 1:9000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 9000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9132 0.06627 0.0003493      0.0004186
## sigma[2]   2.2080 0.16003 0.0008434      0.0009858
## sigma[3]   3.5181 0.25318 0.0013344      0.0016278
## Omega[1,2] 0.2118 0.09544 0.0005030      0.0005972
## Omega[1,3] 0.2380 0.09379 0.0004943      0.0005830
## Omega[2,3] 0.1477 0.09771 0.0005150      0.0006090
##
## 2. Quantiles for each variable:
##
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79449 0.86688 0.9094 0.9549 1.0545
## sigma[2]    1.92073 2.09639 2.1982 2.3093 2.5496
## sigma[3]    3.06211 3.34134 3.5048 3.6785 4.0492
## Omega[1,2]  0.01965 0.14766 0.2138 0.2779 0.3921
## Omega[1,3]  0.04763 0.17587 0.2404 0.3032 0.4136
## Omega[2,3] -0.04694 0.08168 0.1484 0.2149 0.3347
summary(codasamples2)
##
## Iterations = 1:9000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 9000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9129 0.06625 0.0003492      0.0004088
## sigma[2]   2.2075 0.16108 0.0008489      0.0009963
## sigma[3]   3.5213 0.25203 0.0013283      0.0015606
## Omega[1,2] 0.2118 0.09509 0.0005012      0.0005955
## Omega[1,3] 0.2386 0.09420 0.0004965      0.0005896
## Omega[2,3] 0.1476 0.09670 0.0005096      0.0006101
##
## 2. Quantiles for each variable:
##
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79501 0.86661 0.9087 0.9550 1.0549
## sigma[2]    1.91973 2.09503 2.1978 2.3093 2.5487
## sigma[3]    3.07313 3.34663 3.5053 3.6820 4.0570
## Omega[1,2]  0.02205 0.14742 0.2140 0.2783 0.3910
## Omega[1,3]  0.04772 0.17610 0.2412 0.3040 0.4161
## Omega[2,3] -0.04386 0.08225 0.1490 0.2142 0.3310
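To put a number on "almost identical", one rough option is to compare the posterior means of the two runs against their Monte Carlo error. The sketch below copies the means and time-series standard errors from the two summaries and treats the two runs as independent, which is an assumption of mine and not something checked in this post.

```r
pars <- c("sigma[1]", "sigma[2]", "sigma[3]",
          "Omega[1,2]", "Omega[1,3]", "Omega[2,3]")

# Posterior means and time-series standard errors copied from the summaries
mean1 <- c(0.9132, 2.2080, 3.5181, 0.2118, 0.2380, 0.1477)
mean2 <- c(0.9129, 2.2075, 3.5213, 0.2118, 0.2386, 0.1476)
se1 <- c(0.0004186, 0.0009858, 0.0016278, 0.0005972, 0.0005830, 0.0006090)
se2 <- c(0.0004088, 0.0009963, 0.0015606, 0.0005955, 0.0005896, 0.0006101)

# Difference in means, scaled by the standard error of that difference
# (variances add under the independence assumption)
diff_mean <- mean2 - mean1
z <- diff_mean / sqrt(se1^2 + se2^2)

comparison <- data.frame(parameter = pars, diff = diff_mean, z = round(z, 2))
comparison
```

With these figures the largest gap is for sigma[3], and every difference stays within two combined standard errors, which is what one would expect if both parameterizations target the same posterior.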