Computational Tricks with Turing (Non-Centered Parametrization and QR Decomposition)

There are some computational tricks that we can employ with Turing. Here I will cover two of them:

  1. QR Decomposition

  2. Non-Centered Parametrization

QR Decomposition

Back in "Linear Algebra 101" we've learned that any matrix (even rectangular ones) can be factored into the product of two matrices:

  • \(\mathbf{Q}\): an orthogonal matrix (its columns are orthogonal unit vectors, meaning \(\mathbf{Q}^T = \mathbf{Q}^{-1}\)).

  • \(\mathbf{R}\): an upper triangular matrix.

This is commonly known as the QR Decomposition:

\[ \mathbf{A} = \mathbf{Q} \cdot \mathbf{R} \]

Let me show you an example with a random matrix \(\mathbf{A} \in \mathbb{R}^{3 \times 2}\):

A = rand(3, 2)
3×2 Matrix{Float64}:
 0.768448  0.395453
 0.940515  0.313244
 0.673959  0.662555

Now let's factor A using LinearAlgebra's qr() function:

using LinearAlgebra: qr, I
Q, R = qr(A)
LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}}
Q factor:
3×3 LinearAlgebra.QRCompactWYQ{Float64, Matrix{Float64}}:
 -0.553241  -0.0582294  -0.830984
 -0.67712   -0.549615    0.489317
 -0.485214   0.833386    0.264641
R factor:
2×2 Matrix{Float64}:
 -1.38899  -0.752366
  0.0       0.356973

Notice that qr() produced a factorization object from which we destructure the two matrices Q and R. Q is a 3x3 orthogonal matrix and R is a 2x2 upper triangular matrix, so \(\mathbf{Q}^T = \mathbf{Q}^{-1}\) (the transpose equals the inverse):

Matrix(Q') ≈ Matrix(Q^-1)
true

Also note that \(\mathbf{Q}^T \cdot \mathbf{Q} = \mathbf{I}\) (identity matrix):

Q' * Q ≈ I(3)
true
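
As a sanity check (my own verification, not part of the original run, though it holds by construction), we can also multiply Q and R back together and recover A:

# Matrix(Q) materializes the "thin" 3x2 factor stored in the compact QR object,
# so this should reproduce A up to floating-point error
Matrix(Q) * R ≈ A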

This is nice. But what can we do with QR decomposition? It can speed up Turing's sampling by a huge factor while also decorrelating the columns of \(\mathbf{X}\), i.e. the independent variables. The orthogonal nature of QR decomposition alters the posterior's topology and makes it easier for HMC or other MCMC samplers to explore it. Let's see how fast we can get with QR decomposition. First, let's go back to the kidiq example in 6. Bayesian Linear Regression:

using Turing
using Statistics: mean, std
using Random: seed!
seed!(123)

@model linreg(X, y; predictors=size(X, 2)) = begin
    #priors
    α ~ Normal(mean(y), 2.5 * std(y))
    β ~ filldist(TDist(3), predictors)
    σ ~ Exponential(1)

    #likelihood
    y ~ MvNormal(α .+ X * β, σ)
end;

using DataFrames, CSV, HTTP

url = "https://raw.githubusercontent.com/storopoli/Bayesian-Julia/master/datasets/kidiq.csv"
kidiq = CSV.read(HTTP.get(url).body, DataFrame)
X = Matrix(select(kidiq, Not(:kid_score)))
y = kidiq[:, :kid_score]
model = linreg(X, y)
chain = sample(model, NUTS(), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 34.03 seconds
Compute duration  = 61.52 seconds
parameters        = α, β[1], β[2], β[3], σ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

           α   21.3539    8.5767     0.0959    0.1630   2991.2341    1.0014       48.6190
        β[1]    2.0203    1.8166     0.0203    0.0325   3507.2274    1.0001       57.0058
        β[2]    0.5813    0.0579     0.0006    0.0009   4755.0340    1.0009       77.2875
        β[3]    0.2521    0.3059     0.0034    0.0053   3560.7804    1.0009       57.8763
           σ   17.8837    0.5851     0.0065    0.0070   6269.2254    1.0000      101.8989

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    4.1451   15.6520   21.4701   26.9811   38.0069
        β[1]   -0.6023    0.7092    1.6688    2.9935    6.3542
        β[2]    0.4691    0.5431    0.5805    0.6202    0.6967
        β[3]   -0.3391    0.0408    0.2503    0.4560    0.8466
           σ   16.7732   17.4781   17.8677   18.2658   19.0710

See the wall duration in Turing's chain: for me it took around 34 seconds.

Now let us incorporate the QR decomposition into the linear regression model. Here I will use the "thin" instead of the "fat" QR decomposition (for \(\mathbf{X} \in \mathbb{R}^{n \times k}\), the thin version keeps only the first \(k\) columns of \(\mathbf{Q}\)), and additionally scale the \(\mathbf{Q}\) and \(\mathbf{R}\) matrices by a factor of \(\sqrt{n-1}\), where \(n\) is the number of rows of \(\mathbf{X}\). In practice the thin QR decomposition is the one to prefer, since it is numerically more stable. Mathematically, the scaled thin QR decomposition is:

\[ \begin{aligned} \mathbf{X} &= \mathbf{Q}^* \mathbf{R}^* \\ \mathbf{Q}^* &= \mathbf{Q} \cdot \sqrt{n - 1} \\ \mathbf{R}^* &= \frac{1}{\sqrt{n - 1}} \cdot \mathbf{R} \\ \boldsymbol{\mu} &= \alpha + \mathbf{X} \cdot \boldsymbol{\beta} \\ &= \alpha + \mathbf{Q}^* \cdot \mathbf{R}^* \cdot \boldsymbol{\beta} \\ &= \alpha + \mathbf{Q}^* \cdot (\mathbf{R}^* \cdot \boldsymbol{\beta}) \\ &= \alpha + \mathbf{Q}^* \cdot \widetilde{\boldsymbol{\beta}} \end{aligned} \]

Then we can recover the original \(\boldsymbol{\beta}\) with:

\[ \boldsymbol{\beta} = \mathbf{R}^{*-1} \cdot \widetilde{\boldsymbol{\beta}} \]

In Turing, a model with QR decomposition is the same linreg but with a different X matrix supplied, since it is purely a data transformation. First, we decompose our model data X into Q and R:

Q, R = qr(X)
Q_ast = Matrix(Q) * sqrt(size(X, 1) - 1)
R_ast = R / sqrt(size(X, 1) - 1);
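
Before plugging Q_ast into the model, we can sanity-check the transformation (these checks are mine, not part of the original run, but they hold by construction): Q_ast times R_ast gives back X, the columns of Q_ast are orthogonal, and any coefficient vector yields the same linear predictor either way:

# Q_ast * R_ast recovers X (the √(n - 1) scaling factors cancel out)
Q_ast * R_ast ≈ X

# the columns of Q_ast are orthogonal: Q_astᵀ Q_ast = (n - 1) I
Q_ast' * Q_ast ≈ (size(X, 1) - 1) * I(size(X, 2))

# a hypothetical coefficient vector gives the same linear predictor under both parametrizations
β_test = randn(size(X, 2))
X * β_test ≈ Q_ast * (R_ast * β_test)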

Then, we instantiate a model with Q_ast instead of X and sample as you would normally:

model_qr = linreg(Q_ast, y)
chain_qr = sample(model_qr, NUTS(1_000, 0.65), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×17×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 9.43 seconds
Compute duration  = 18.61 seconds
parameters        = α, β[1], β[2], β[3], σ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol    Float64   Float64    Float64   Float64     Float64   Float64       Float64

           α    33.3373    8.3099     0.0929    0.1926   1755.9144    1.0021       94.3330
        β[1]   -49.6145    7.4389     0.0832    0.1705   1762.0871    1.0020       94.6646
        β[2]    21.8728    3.7826     0.0423    0.0858   1812.6498    1.0019       97.3810
        β[3]     0.3164    0.9437     0.0106    0.0205   2205.2610    1.0016      118.4732
           σ    17.8661    0.6047     0.0068    0.0088   5027.7919    1.0002      270.1081

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

           α    17.7029    27.7884    33.3142    38.5484    50.0087
        β[1]   -63.6119   -54.6095   -49.6650   -44.9534   -34.7229
        β[2]    14.2781    19.5114    21.8633    24.4282    29.1062
        β[3]    -1.4035    -0.2897     0.2787     0.8569     2.3080
           σ    16.7390    17.4417    17.8480    18.2647    19.1166

See the wall duration in Turing's chain_qr: for me it took around 9 seconds. Much faster than the regular linreg. Now we have to reconstruct our \(\boldsymbol{\beta}\)s:

betas = mapslices(x -> R_ast^-1 * x, chain_qr[:, namesingroup(chain_qr, :β), :].value.data, dims=[2])
chain_qr_reconstructed = hcat(Chains(betas, ["real_β[$i]" for i in 1:size(Q_ast, 2)]), chain_qr)
ArgumentError: chain ranges differ

The hcat fails because chain_qr keeps its post-warmup iteration range (1001:1:3000, as shown above), while the freshly constructed Chains defaults to iterations 1:2000, so MCMCChains refuses to concatenate them.
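
One way to work around this, a sketch on my part assuming MCMCChains' setrange function (not shown in the original), is to give the reconstructed chain the same iteration range before concatenating:

using MCMCChains: setrange

# align the iteration range of the reconstructed βs with chain_qr's (1001:1:3000, shown above)
chain_betas = setrange(Chains(betas, ["real_β[$i]" for i in 1:size(Q_ast, 2)]), 1_001:1:3_000)
chain_qr_reconstructed = hcat(chain_betas, chain_qr)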

Non-Centered Parametrization

Now let us explore Non-Centered Parametrization (NCP). This is useful when the posterior's topology is very difficult to explore, because it has regions where the HMC sampler has to drastically change the number of leapfrog steps \(L\) and the step size \(\epsilon\). I showed one of the most infamous cases of this in 5. Markov Chain Monte Carlo (MCMC): Neal's Funnel (Neal, 2003):

using StatsPlots, Distributions, LaTeXStrings
funnel_y = rand(Normal(0, 3), 10_000)
funnel_x = rand(Normal(), 10_000) .* exp.(funnel_y / 2)

scatter((funnel_x, funnel_y),
        label=false, mc=:steelblue, ma=0.3,
        xlabel=L"X", ylabel=L"Y",
        xlims=(-100, 100))

Neal's Funnel

Here we see that in the upper part of the funnel HMC can take few steps \(L\) and be more liberal with the step size \(\epsilon\). But the opposite holds in the lower part of the funnel: it needs way more steps \(L\) and a more conservative (smaller) \(\epsilon\).

To see the devil's funnel (as it is known in some Bayesian circles) in action, let's code it in Turing and then sample:

@model funnel() = begin
    y ~ Normal(0, 3)
    x ~ Normal(0, exp(y / 2))
end

chain_funnel = sample(funnel(), NUTS(), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×14×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 6.67 seconds
Compute duration  = 12.08 seconds
parameters        = y, x
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

           y    1.0247    2.3578     0.0264    0.1430   157.9554    1.0366       13.0779
           x    0.9257   11.2385     0.1256    0.5332   427.3839    1.0081       35.3853

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           y   -2.6127   -0.7645    0.7627    2.5280    6.3949
           x   -9.1671   -0.8393    0.0131    0.9765   15.2710

Wow, take a look at those rhat values... That sucks: y's rhat is way above 1.01, even with 4 parallel chains of 2,000 iterations each!

How do we deal with that? We reparametrize! Note that we can express a normal distribution in the following manner:

\[ \text{Normal}(\mu, \sigma) = \text{Standard Normal} \cdot \sigma + \mu \]

where the standard normal is the normal with mean \(\mu = 0\) and standard deviation \(\sigma = 1\). This is why it is called Non-Centered Parametrization: we "decouple" the parameters by sampling standard normals and reconstruct the original parameters afterwards.
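
As a quick numerical illustration (my own sketch, not part of the original), scaling and shifting standard normal draws reproduces draws from \(\text{Normal}(\mu, \sigma)\):

using Distributions
using Statistics: mean, std

μ, σ = 10.0, 2.0
z = rand(Normal(), 100_000)  # standard normal draws
x = μ .+ σ .* z              # scaled and shifted
mean(x), std(x)              # ≈ (10.0, 2.0), just like rand(Normal(μ, σ), 100_000)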

@model ncp_funnel() = begin
    x̃ ~ Normal()
    ỹ ~ Normal()
    y = 3.0 * ỹ         # implies y ~ Normal(0, 3)
    x = exp(y / 2) * x̃  # implies x ~ Normal(0, exp(y / 2))
end

chain_ncp_funnel = sample(ncp_funnel(), NUTS(), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×14×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 5.71 seconds
Compute duration  = 10.41 seconds
parameters        = x̃, ỹ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64

           x̃    0.0159    0.9903     0.0111    0.0116   8183.3629    1.0000      785.8795
           ỹ    0.0092    1.0112     0.0113    0.0108   7764.3380    0.9997      745.6389

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x̃   -1.9347   -0.6360    0.0118    0.6726    1.9777
           ỹ   -2.0186   -0.6794    0.0014    0.6948    1.9848

Much better now: all rhat are well below 1.01 (or above 0.99).
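
If we want the funnel's original x and y back, we can reconstruct them from the sampled auxiliary parameters. Here is a sketch of mine (assuming the x̃ and ỹ names used in ncp_funnel above):

# extract the auxiliary draws (iterations × chains, flattened to a vector)
x̃_draws = vec(chain_ncp_funnel[:x̃])
ỹ_draws = vec(chain_ncp_funnel[:ỹ])

# apply the deterministic transforms from the model
y_draws = 3.0 .* ỹ_draws                  # implies y ~ Normal(0, 3)
x_draws = exp.(y_draws ./ 2) .* x̃_draws   # implies x ~ Normal(0, exp(y / 2))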

How would we implement this in a real-world model in Turing? Let's go back to the cheese random-intercept model in 10. Multilevel Models (a.k.a. Hierarchical Models). Here was the approach that we took, also known as Centered Parametrization (CP):

@model varying_intercept(X, idx, y; n_gr=length(unique(idx)), predictors=size(X, 2)) = begin
    #priors
    α ~ Normal(mean(y), 2.5 * std(y))       # population-level intercept
    β ~ filldist(Normal(0, 2), predictors)  # population-level coefficients
    σ ~ Exponential(1 / std(y))             # residual SD
    #prior for variance of random intercepts
    #usually requires thoughtful specification
    τ ~ truncated(Cauchy(0, 2), 0, Inf)     # group-level SDs intercepts
    αⱼ ~ filldist(Normal(0, τ), n_gr)       # CP group-level intercepts

    #likelihood
    ŷ = α .+ X * β .+ αⱼ[idx]
    y ~ MvNormal(ŷ, σ)
end;

To perform a Non-Centered Parametrization (NCP) in this model we do the following:

@model varying_intercept_ncp(X, idx, y; n_gr=length(unique(idx)), predictors=size(X, 2)) = begin
    #priors
    α ~ Normal(mean(y), 2.5 * std(y))       # population-level intercept
    β ~ filldist(Normal(0, 2), predictors)  # population-level coefficients
    σ ~ Exponential(1 / std(y))             # residual SD

    #prior for variance of random intercepts
    #usually requires thoughtful specification
    τ ~ truncated(Cauchy(0, 2), 0, Inf)    # group-level SDs intercepts
    zⱼ ~ filldist(Normal(0, 1), n_gr)      # NCP group-level intercepts

    #likelihood
    ŷ = α .+ X * β .+ zⱼ[idx] .* τ
    y ~ MvNormal(ŷ, σ)
end;

Here we are using an NCP with the zⱼs following a standard normal, and we reconstruct the group-level intercepts by multiplying the zⱼs by τ. Since the original αⱼs had a prior centered on 0 with standard deviation τ, multiplying by τ is all we need to get the αⱼs back.

Now let's see how NCP compares to the CP. First, let's redo our CP hierarchical model:

url = "https://raw.githubusercontent.com/storopoli/Bayesian-Julia/master/datasets/cheese.csv"
cheese = CSV.read(HTTP.get(url).body, DataFrame)

for c in unique(cheese[:, :cheese])
    cheese[:, "cheese_$c"] = ifelse.(cheese[:, :cheese] .== c, 1, 0)
end

cheese[:, :background_int] = map(cheese[:, :background]) do b
    b == "rural" ? 1 :
    b == "urban" ? 2 : missing
end

X = Matrix(select(cheese, Between(:cheese_A, :cheese_D)));
y = cheese[:, :y];
idx = cheese[:, :background_int];

model_cp = varying_intercept(X, idx, y)
chain_cp = sample(model_cp, NUTS(), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×21×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 23.4 seconds
Compute duration  = 44.33 seconds
parameters        = α, β[1], β[2], β[3], β[4], σ, τ, αⱼ[1], αⱼ[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean       std   naive_se      mcse         ess      rhat   ess_per_sec
      Symbol    Float64   Float64    Float64   Float64     Float64   Float64       Float64

           α    70.4886    5.7041     0.0638    0.1727   1010.4968    1.0012       22.7933
        β[1]     3.2192    1.2523     0.0140    0.0197   3987.0485    0.9999       89.9341
        β[2]   -11.5989    1.2583     0.0141    0.0214   3875.0906    1.0000       87.4087
        β[3]     7.1717    1.2649     0.0141    0.0192   4023.9893    1.0002       90.7674
        β[4]     1.1961    1.2577     0.0141    0.0203   3894.9953    1.0005       87.8577
           σ     5.9951    0.2704     0.0030    0.0038   5738.7255    1.0006      129.4459
           τ     6.6607    8.3460     0.0933    0.2483   1121.0806    1.0012       25.2877
       αⱼ[1]    -3.2320    5.5907     0.0625    0.1732    987.6553    1.0014       22.2781
       αⱼ[2]     3.9360    5.6192     0.0628    0.1734    987.3585    1.0014       22.2714

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%     97.5%
      Symbol    Float64    Float64    Float64    Float64   Float64

           α    58.4933    68.2771    70.7586    73.0395   80.6218
        β[1]     0.7105     2.3874     3.2102     4.0803    5.6169
        β[2]   -14.0694   -12.4370   -11.6061   -10.7619   -9.1051
        β[3]     4.7020     6.3027     7.1717     8.0374    9.6486
        β[4]    -1.2477     0.3423     1.1967     2.0355    3.6902
           σ     5.4904     5.8076     5.9877     6.1762    6.5408
           τ     1.9269     3.2850     4.6911     7.2780   24.3742
       αⱼ[1]   -13.0467    -5.6029    -3.4369    -1.2280    8.6039
       αⱼ[2]    -5.8915     1.5470     3.6580     5.9007   16.0866

Now let's do the NCP hierarchical model:

model_ncp = varying_intercept_ncp(X, idx, y)
chain_ncp = sample(model_ncp, NUTS(), MCMCThreads(), 2_000, 4)
Chains MCMC chain (2000×21×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 24.34 seconds
Compute duration  = 43.9 seconds
parameters        = α, β[1], β[2], β[3], β[4], σ, τ, zⱼ[1], zⱼ[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol    Float64   Float64    Float64   Float64    Float64   Float64       Float64

           α    69.1483    6.4983     0.0727    0.5631    30.9753    1.3408        0.7056
        β[1]     3.0064    1.3008     0.0145    0.0625   124.6907    1.0534        2.8405
        β[2]   -11.6189    1.2382     0.0138    0.0445   278.6275    1.0025        6.3473
        β[3]     7.0715    1.1941     0.0134    0.0355   786.2910    1.0087       17.9122
        β[4]     1.1397    1.1957     0.0134    0.0395   523.0683    1.0048       11.9158
           σ     6.0787    0.3225     0.0036    0.0206    52.0977    1.1707        1.1868
           τ     7.3936    6.0720     0.0679    0.5776    25.0122    1.5117        0.5698
       zⱼ[1]    -0.6853    0.8774     0.0098    0.0540    61.8678    1.1394        1.4094
       zⱼ[2]     0.8395    0.7707     0.0086    0.0272   620.4978    1.0057       14.1353

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%     97.5%
      Symbol    Float64    Float64    Float64    Float64   Float64

           α    54.3944    67.0281    70.2877    72.8053   80.2180
        β[1]     0.7056     2.0629     2.9488     3.9223    5.6000
        β[2]   -13.9968   -12.4436   -11.6303   -10.7272   -9.1887
        β[3]     4.6696     6.3018     7.0517     7.8326    9.5307
        β[4]    -1.2361     0.3443     1.1388     1.9241    3.5450
           σ     5.5173     5.8449     6.0484     6.3054    6.7493
           τ     1.9870     3.4090     5.0187     8.5101   23.2911
       zⱼ[1]    -2.3688    -1.3098    -0.7027    -0.0165    0.7557
       zⱼ[2]    -0.7225     0.3784     0.8412     1.2879    2.4046

Notice that some models are better off with a standard Centered Parametrization (as is our cheese case here), while others are better off with a Non-Centered Parametrization. But now you know how to apply both parametrizations in Turing. Before we conclude, we need to recover our original αⱼs. We can do this by multiplying the zⱼs by τ:

τ = summarystats(chain_ncp)[:τ, :mean]
αⱼ = mapslices(x -> x * τ, chain_ncp[:, namesingroup(chain_ncp, :zⱼ), :].value.data, dims=[2])
chain_ncp_reconstructed = hcat(Chains(αⱼ, ["αⱼ[$i]" for i in 1:length(unique(idx))]), chain_ncp)
ArgumentError: chain ranges differ

As before, the hcat fails because chain_ncp keeps its post-warmup iteration range (1001:1:3000) while the newly constructed Chains defaults to 1:2000.
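
The same workaround as before applies (again a sketch assuming MCMCChains' setrange, not shown in the original):

using MCMCChains: setrange

# align the iteration range of the reconstructed αⱼs with chain_ncp's (1001:1:3000, shown above)
chain_αⱼ = setrange(Chains(αⱼ, ["αⱼ[$i]" for i in 1:length(unique(idx))]), 1_001:1:3_000)
chain_ncp_reconstructed = hcat(chain_αⱼ, chain_ncp)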

References

Neal, Radford M. (2003). Slice Sampling. The Annals of Statistics, 31(3), 705–741. Retrieved from https://www.jstor.org/stable/3448413