Using a LKJ prior in Stan
2016-03-07
There are two ways to use an LKJ prior distribution for a correlation matrix in Stan. The first one assigns the distribution to the correlation matrix itself, whereas the second one assigns it to the lower Cholesky factor of the correlation matrix. I am going to show an example for a trivariate normal sample with a fixed mean:
\[
y_i \sim_{\text{iid}} {\cal N}\left( \begin{pmatrix} 0 \\ 0 \\ 0 \end{pmatrix}, \Sigma\right).
\]
Recall the relation between the covariance matrix \(\Sigma\) and the correlation matrix \(\Omega\):
\[
\begin{pmatrix} \sigma_{1}^2 & \sigma_{12} & \sigma_{13} \\ \sigma_{12} & \sigma_2^2 & \sigma_{23} \\ \sigma_{13} & \sigma_{23} & \sigma_3^2 \end{pmatrix}
= \begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix}
\Omega
\begin{pmatrix} \sigma_{1} & 0 & 0 \\ 0 & \sigma_2 & 0 \\ 0 & 0 & \sigma_3 \end{pmatrix}.
\]
This operation is performed in Stan by the function `quad_form_diag`.

I do not assume the reader is familiar with Stan or the `rstan` package, so I will comment on each step. First I load the `rstan` package with the usual options:

```r
library(rstan)
rstan_options(auto_write = TRUE)              # save compiled models to disk to avoid recompiling
options(mc.cores = parallel::detectCores())   # run the chains in parallel
```

I will run Stan on these simulated data:
```r
set.seed(666)
Omega <- rbind(
  c(1, 0.3, 0.2),
  c(0.3, 1, 0.1),
  c(0.2, 0.1, 1)
)
sigma <- c(1, 2, 3)
Sigma <- diag(sigma) %*% Omega %*% diag(sigma)
N <- 100
y <- mvtnorm::rmvnorm(N, c(0, 0, 0), Sigma)
```

Prior on the correlation matrix
Below is the Stan code for the Bayesian model assigning an LKJ prior to the correlation matrix \(\Omega\). I use the LKJ distribution with shape parameter \(1\), which is the uniform distribution on the space of correlation matrices.
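A short aside before that code: the LKJ density with shape parameter \(\eta\) is proportional to \(\det(\Omega)^{\eta-1}\), so for \(\eta = 1\) it takes the same value at every correlation matrix, which is what makes it uniform. The minimal base-R sketch below evaluates this unnormalized log density; `lkj_logdens_unnorm` is a hypothetical helper written for illustration, not a function from `rstan`.

```r
# Unnormalized log density of the LKJ(eta) distribution:
#   log p(Omega | eta) = (eta - 1) * log(det(Omega)) + constant
# (hypothetical helper for illustration; not part of rstan)
lkj_logdens_unnorm <- function(Omega, eta) {
  (eta - 1) * as.numeric(determinant(Omega, logarithm = TRUE)$modulus)
}

Omega_a <- diag(3)                       # identity: no correlation at all
Omega_b <- rbind(c(1, 0.3, 0.2),
                 c(0.3, 1, 0.1),
                 c(0.2, 0.1, 1))         # the matrix used to simulate the data

# With eta = 1 both matrices get the same (zero) log density: a uniform prior
lkj_logdens_unnorm(Omega_a, eta = 1)
lkj_logdens_unnorm(Omega_b, eta = 1)

# With eta = 2 the identity (largest possible determinant) is favoured
lkj_logdens_unnorm(Omega_a, eta = 2) > lkj_logdens_unnorm(Omega_b, eta = 2)
```

For \(\eta > 1\) the prior favours matrices with a larger determinant, that is weaker correlations, while \(\eta < 1\) favours stronger correlations.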
```r
stancode1 <- '
data {
  int<lower=1> N;             // number of observations
  int<lower=1> J;             // dimension of observations
  array[N] vector[J] y;       // observations
  vector[J] Zero;             // a vector of zeros (fixed means of observations)
}
parameters {
  corr_matrix[J] Omega;       // correlation matrix
  vector<lower=0>[J] sigma;   // standard deviations
}
transformed parameters {
  cov_matrix[J] Sigma;
  Sigma = quad_form_diag(Omega, sigma);  // diag(sigma) * Omega * diag(sigma)
}
model {
  y ~ multi_normal(Zero, Sigma);  // sampling distribution of the observations
  sigma ~ cauchy(0, 5);           // half-Cauchy prior on the standard deviations
  Omega ~ lkj_corr(1);            // LKJ prior on the correlation matrix
}
'
```

The `stan_model` function of the `rstan` package compiles the Stan model:

```r
stanmodel1 <- stan_model(model_code = stancode1, model_name = "stanmodel1")
```

Note that this function only takes the Stan code as input, not the data. Once the model is compiled, the `stanmodel1` object can be used to run the model on different datasets with the `sampling` function.

The data must be passed to the `sampling` function in a list:

```r
standata <- list(J = ncol(y), N = N, y = y, Zero = rep(0, ncol(y)))
```

The algorithms used by Stan to sample from the posterior distribution require initial values of the parameters. One can let the `sampling` function generate random initial values, or pass them in its `init` argument. I prefer to supply my own initial values. More precisely, I pass to the `init` argument a function that returns the initial values in a named list (see `?sampling`).

Here is how I generate initial values. First, I write a function that returns some estimates of the parameters, taking the observations as input and optionally perturbing them at random:

```r
# Moment estimates of sigma and Omega, optionally computed from noisy data
estimates <- function(y, perturb = FALSE){
  if(perturb) y <- y + rnorm(length(y), 0, 1)   # add N(0, 1) noise to every observation
  sigma <- sqrt(diag(var(y)))                   # sample standard deviations
  Omega <- cor(y)                               # sample correlation matrix
  return(list(sigma = sigma, Omega = Omega))
}
```

I run Stan with several chains, for instance four chains. The function passed to the `init` argument of the `sampling` function can take an optional argument `chain_id`. For the first chain, I use the estimates calculated from the original data as initial values, and for the other chains I use the estimates calculated from the perturbed data:

```r
inits <- function(chain_id){
  values <- estimates(standata$y, perturb = chain_id > 1)
  return(values)
}
```

Now we are ready to run Stan:

```r
samples1 <- sampling(stanmodel1, data = standata, init = inits,
                     iter = 10000, warmup = 1000, chains = 4)
```

Some numerical problems occur, but they are benign. It is not unusual to get messages such as:

```
validate transformed params: Sigma is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = -nan, but Sigma[2,1] = -nan
Exception thrown at line 23: lkj_corr_log: y is not positive definite
validate transformed params: Sigma is not symmetric. Sigma[1,2] = inf, but Sigma[2,1] = inf
```

These problems will not occur with the LKJ prior on the Cholesky factor.
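These messages come from the checks Stan performs on constrained types: a `cov_matrix` must be symmetric and positive definite, and a `corr_matrix` must additionally have a unit diagonal. When a proposal yields a matrix that fails these checks (for instance after an overflow to `inf`), Stan rejects that proposal. The base-R sketch below mimics those checks; `is_cov_matrix` and `is_corr_matrix` are hypothetical helpers and the tolerance is an assumption, not Stan's internal value. The same helpers can be used to check initial values such as those returned by `estimates`.

```r
# Rough analogue of Stan's cov_matrix / corr_matrix validation (hypothetical helpers).
# tol is an assumed tolerance, not the value Stan uses internally.
is_cov_matrix <- function(S, tol = 1e-8) {
  if (!is.matrix(S) || nrow(S) != ncol(S)) return(FALSE)
  if (any(!is.finite(S))) return(FALSE)          # NaN or Inf entries, as in the messages
  if (max(abs(S - t(S))) > tol) return(FALSE)    # symmetry
  # positive definiteness: a Cholesky factorization must exist
  !inherits(try(chol(S), silent = TRUE), "try-error")
}

is_corr_matrix <- function(S, tol = 1e-8) {
  is_cov_matrix(S, tol) && all(abs(diag(S) - 1) <= tol)   # plus a unit diagonal
}

# An initial value of the kind produced by estimates(): a sample correlation matrix
set.seed(1)
y_demo <- matrix(rnorm(300), ncol = 3)
is_corr_matrix(cor(y_demo))

# A matrix with non-finite entries, like Sigma in the messages
bad <- diag(3)
bad[1, 2] <- bad[2, 1] <- NaN
is_cov_matrix(bad)

# Symmetric with a unit diagonal, but not positive definite
not_pd <- matrix(c( 1,   0.9, -0.9,
                    0.9, 1,    0.9,
                   -0.9, 0.9,  1), 3, 3)
is_corr_matrix(not_pd)
```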
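I like to use the `coda` package for output analysis. This is how I store the samples in a `coda` object:

```r
library(coda)
# extract(..., permuted = FALSE) returns an iterations x chains x parameters array;
# alply() splits it along the chains (margin 2) and mcmc() converts each chain
codasamples1 <- do.call(mcmc.list,
  plyr::alply(rstan::extract(samples1,
                             pars = c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
                             permuted = FALSE),
              2, mcmc))
```

Prior on the Cholesky factor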
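The correlation matrix \(\Omega\) has a Cholesky factorization \(\Omega = LL'\), where \(L\) is a lower triangular matrix. Instead of assigning a prior distribution to \(\Omega\), one can assign a prior distribution to \(L\): putting `lkj_corr_cholesky(1)` on \(L\) corresponds to putting `lkj_corr(1)` on \(\Omega\), so the prior is the same. This way, the numerical problems encountered with the previous approach are avoided, and it is also faster, since the likelihood can be written with `multi_normal_cholesky`, which uses a Cholesky factor of \(\Sigma\) directly instead of factorizing \(\Sigma\) at every evaluation.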
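The matrix identities the second model relies on can be checked numerically in base R. In the sketch below, `diag_pre_multiply_r` and `quad_form_diag_r` are hypothetical R stand-ins for the Stan functions `diag_pre_multiply` and `quad_form_diag`, written from their definitions \(\operatorname{diag}(v)\,M\) and \(\operatorname{diag}(v)\,M\,\operatorname{diag}(v)\); they assume `v` has length at least 2, since `diag()` of a single number builds an identity matrix.

```r
# Hypothetical R stand-ins for Stan's diag_pre_multiply() and quad_form_diag()
diag_pre_multiply_r <- function(v, M) diag(v) %*% M
quad_form_diag_r    <- function(M, v) diag(v) %*% M %*% diag(v)

Omega <- rbind(c(1, 0.3, 0.2),
               c(0.3, 1, 0.1),
               c(0.2, 0.1, 1))
sigma <- c(1, 2, 3)

# chol() returns the upper triangular factor R with Omega = R'R,
# so the lower Cholesky factor L is its transpose
L <- t(chol(Omega))
all.equal(L %*% t(L), Omega)                 # Omega = L L'
all.equal(rowSums(L^2), rep(1, 3))           # unit diagonal of Omega: each row of L has length 1

# Cholesky factor of Sigma, as passed to multi_normal_cholesky in the second model
L_Sigma <- diag_pre_multiply_r(sigma, L)
Sigma   <- quad_form_diag_r(Omega, sigma)
all.equal(L_Sigma %*% t(L_Sigma), Sigma)     # L_Sigma is a Cholesky factor of Sigma
all.equal(L_Sigma, t(chol(Sigma)))           # and it is the lower one with positive diagonal
```

Because \(\operatorname{diag}(\sigma) L\) is lower triangular with a positive diagonal, it is the unique lower Cholesky factor of \(\Sigma\), which is why the model never needs to factorize \(\Sigma\) itself.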
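```r
stancode2 <- '
data {
  int<lower=1> N;             // number of observations
  int<lower=1> J;             // dimension of observations
  array[N] vector[J] y;       // observations
  vector[J] Zero;             // a vector of zeros (fixed means of observations)
}
parameters {
  cholesky_factor_corr[J] Lcorr;  // lower Cholesky factor of the correlation matrix
  vector<lower=0>[J] sigma;       // standard deviations
}
model {
  // diag_pre_multiply(sigma, Lcorr) = diag(sigma) * Lcorr is a Cholesky factor of Sigma
  y ~ multi_normal_cholesky(Zero, diag_pre_multiply(sigma, Lcorr));
  sigma ~ cauchy(0, 5);
  Lcorr ~ lkj_corr_cholesky(1);   // LKJ prior expressed on the Cholesky factor
}
generated quantities {
  matrix[J, J] Omega;
  matrix[J, J] Sigma;
  Omega = multiply_lower_tri_self_transpose(Lcorr);  // Omega = Lcorr * transpose(Lcorr)
  Sigma = quad_form_diag(Omega, sigma);
}
'
```

Note the `generated quantities` block, as compared to the `transformed parameters` block in the first code. Objects declared in the `transformed parameters` block are recomputed at every evaluation of the log density and can be used in the `model` block (here, in the likelihood of the data), whereas objects declared in the `generated quantities` block are computed once per iteration, after the sampling step, and have no effect on the posterior.

Now we only have to change the `estimates` function:

```r
estimates <- function(y, perturb = FALSE){
  if(perturb) y <- y + rnorm(length(y), 0, 1)
  Lcorr <- t(chol(cor(y)))        # chol() returns the upper factor; transpose it to get the lower one
  sigma <- sqrt(diag(var(y)))
  return(list(Lcorr = Lcorr, sigma = sigma))
}
```

The `inits` function does not need to change, since it calls `estimates`. Then compile and run as before: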
```r
stanmodel2 <- stan_model(model_code = stancode2, model_name = "stanmodel2")
samples2 <- sampling(stanmodel2, data = standata, init = inits,
                     iter = 10000, warmup = 1000, chains = 4)
codasamples2 <- do.call(mcmc.list,
  plyr::alply(rstan::extract(samples2,
                             pars = c("sigma", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]"),
                             permuted = FALSE),
              2, mcmc))
```

Comparison of results
Results are almost identical:
```r
summary(codasamples1)
```

```
## 
## Iterations = 1:9000
## Thinning interval = 1 
## Number of chains = 4 
## Sample size per chain = 9000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9132 0.06627 0.0003493      0.0004186
## sigma[2]   2.2080 0.16003 0.0008434      0.0009858
## sigma[3]   3.5181 0.25318 0.0013344      0.0016278
## Omega[1,2] 0.2118 0.09544 0.0005030      0.0005972
## Omega[1,3] 0.2380 0.09379 0.0004943      0.0005830
## Omega[2,3] 0.1477 0.09771 0.0005150      0.0006090
## 
## 2. Quantiles for each variable:
## 
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79449 0.86688 0.9094 0.9549 1.0545
## sigma[2]    1.92073 2.09639 2.1982 2.3093 2.5496
## sigma[3]    3.06211 3.34134 3.5048 3.6785 4.0492
## Omega[1,2]  0.01965 0.14766 0.2138 0.2779 0.3921
## Omega[1,3]  0.04763 0.17587 0.2404 0.3032 0.4136
## Omega[2,3] -0.04694 0.08168 0.1484 0.2149 0.3347
```

```r
summary(codasamples2)
```

```
## 
## Iterations = 1:9000
## Thinning interval = 1 
## Number of chains = 4 
## Sample size per chain = 9000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##              Mean      SD  Naive SE Time-series SE
## sigma[1]   0.9129 0.06625 0.0003492      0.0004088
## sigma[2]   2.2075 0.16108 0.0008489      0.0009963
## sigma[3]   3.5213 0.25203 0.0013283      0.0015606
## Omega[1,2] 0.2118 0.09509 0.0005012      0.0005955
## Omega[1,3] 0.2386 0.09420 0.0004965      0.0005896
## Omega[2,3] 0.1476 0.09670 0.0005096      0.0006101
## 
## 2. Quantiles for each variable:
## 
##                2.5%     25%    50%    75%  97.5%
## sigma[1]    0.79501 0.86661 0.9087 0.9550 1.0549
## sigma[2]    1.91973 2.09503 2.1978 2.3093 2.5487
## sigma[3]    3.07313 3.34663 3.5053 3.6820 4.0570
## Omega[1,2]  0.02205 0.14742 0.2140 0.2783 0.3910
## Omega[1,3]  0.04772 0.17610 0.2412 0.3040 0.4161
## Omega[2,3] -0.04386 0.08225 0.1490 0.2142 0.3310
```
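To put a number on "almost identical", one can compare the two posterior means of each parameter relative to their Monte Carlo standard errors. The sketch below copies the Mean and Time-series SE columns from the two summaries and computes, for each parameter, the difference of the means divided by their combined standard error, assuming the two runs are independent. Values well below 2 or 3 in absolute value are what Monte Carlo noise alone would produce.

```r
# Posterior means and time-series (Monte Carlo) standard errors,
# copied from summary(codasamples1) and summary(codasamples2)
params <- c("sigma[1]", "sigma[2]", "sigma[3]", "Omega[1,2]", "Omega[1,3]", "Omega[2,3]")
mean1 <- c(0.9132, 2.2080, 3.5181, 0.2118, 0.2380, 0.1477)
se1   <- c(0.0004186, 0.0009858, 0.0016278, 0.0005972, 0.0005830, 0.0006090)
mean2 <- c(0.9129, 2.2075, 3.5213, 0.2118, 0.2386, 0.1476)
se2   <- c(0.0004088, 0.0009963, 0.0015606, 0.0005955, 0.0005896, 0.0006101)

# Standardized difference between the two runs, assuming they are independent
z <- setNames((mean1 - mean2) / sqrt(se1^2 + se2^2), params)
round(z, 2)
```

With these numbers every standardized difference is below 1.5 in absolute value, the largest one being for `sigma[3]`, so the two parameterizations agree up to Monte Carlo error.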
