Symbolic Simplification
January 25, 2021
Upcoming features in Soss.jl include static model simplification. After a one-time compilation cost, posterior log-densities for many models become "constant cost", independent of the number of observations. Bayesian analysis for such models can easily scale to big data.
The symbolic representation of the posterior log-density can also be useful for pedagogical purposes.
First, a Model
Let's start with something simple, say a linear regression with 10,000 observations.
using Soss

N = 10000

m = @model x, λ begin
    σ ~ Exponential(λ)
    β ~ Normal(0, 1)
    y ~ For(1:N) do j
        Normal(x[j] * β, σ)
    end
    return y
end
Soss models are generative, so we can use one to generate data. This is useful to help make sure the model assumptions are reasonable (fake data should look similar to real data), and it's also handy for little examples like this.
To generate data from m, we need values for the input arguments x and λ:
x = randn(N)
λ = 1.0
Given this, we could just call rand(m(x=x, λ=λ)). But this would "forget" the values of σ and β. To keep track of these instead, we can do
trace = simulate(m(x=x, λ=λ)).trace
y = trace.y
As a result, we now have
trace.σ = 0.3974449483739133
trace.β = 0.5236103489251422
mean(y) = -0.009302852896095896
std(y) = 0.6630927798796098
Evaluating the Log-Density (The Punchline)
In Bayesian modeling, we're usually in a situation of having observed y and wanting to infer what we can about latent parameters like σ and β. So let's construct the posterior distribution. We could write this as post = m(x=x, λ=λ) | (y=y,), or using the shorthand
post = m(; x, λ) | (; y)
To evaluate this, say on the trace from the original sample, we just call
logdensity(post, trace)
which in this case gives
4123.1269713242
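As a quick cross-check, we can write out the same joint density by hand with Distributions.jl. Note that Soss drops additive constants (like the normals' -log(2π)/2 terms), so this hand-rolled version differs from the value above by a fixed offset:

using Distributions

# hand-rolled posterior log-density; differs from Soss's value only by
# the additive constants Soss drops
manual(σ, β) = logpdf(Exponential(λ), σ) +
               logpdf(Normal(0, 1), β) +
               sum(logpdf.(Normal.(x .* β, σ), y))

manual(trace.σ, trace.β)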
Let's see how fast it is.
using BenchmarkTools
slow = @benchmark $logdensity($post, $trace)
BenchmarkTools.Trial:
memory estimate: 1.07 MiB
allocs estimate: 60032
--------------
minimum time: 2.879 ms (0.00% GC)
median time: 3.002 ms (0.00% GC)
mean time: 3.246 ms (4.65% GC)
maximum time: 18.201 ms (81.89% GC)
--------------
samples: 1540
evals/sample: 1
OK, clearly some things can be improved; for example, we should be able to do this without allocation. But rather than dwell too much on that, let's jump to the punchline:
ℓ = symlogdensity(post)
f = codegen(post; ℓ=ℓ)
fast = @benchmark $f($(;x,λ), $(;y), $trace)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 84.730 ns (0.00% GC)
median time: 85.937 ns (0.00% GC)
mean time: 86.678 ns (0.00% GC)
maximum time: 139.092 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 963
Note the units; we started out in milliseconds, but now we're in nanoseconds. A quick minimum(slow).time / minimum(fast).time shows we've gotten a 33980× speedup.
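As a sanity check, calling the generated function directly reproduces the value we got from logdensity (up to floating-point rounding):

f((x=x, λ=λ), (y=y,), trace)   # ≈ 4123.1269713242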
How it Works
ℓ is a symbolic expression from SymbolicUtils.jl. To build it, we first generate a trace, as we did above. Now we have all of the types, so we can evaluate the log-density again, this time replacing the values with symbolic variables of the appropriate type.
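Here's the idea in miniature (a sketch, not Soss's internals): the same code that computes a number on floats builds an expression when handed symbolic variables.

using SymbolicUtils: @syms

@syms β::Real σ::Real

# a normal log-density without its additive constant, as Soss computes it
normal_logpdf(y, μ, σ) = -log(σ) - (y - μ)^2 / (2σ^2)

normal_logpdf(2.0, β, σ)   # a symbolic expression in β and σ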
At this point, we have a symbolic expression, and we also have a dictionary of fixed values from the arguments and observed values. So we expand the expression and then walk the result, looking for subexpressions that evaluate to scalars with no free variables. When we find one, we use the dictionary to rewrite it.
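That walk might look roughly like this (hypothetical helper names, not Soss's actual code):

using SymbolicUtils
using SymbolicUtils: Sym, istree, operation, arguments, substitute

# collect the free (symbolic) variables of an expression
freevars(x) = x isa Sym ? Set(Any[x]) :
    istree(x) ? union(Set(Any[]), freevars.(arguments(x))...) : Set(Any[])

# fold any subexpression whose free variables are all fixed by `vals`
function fold(expr, vals::Dict)
    vars = freevars(expr)
    if !isempty(vars) && all(v -> haskey(vals, v), vars)
        return substitute(expr, vals)   # evaluates to a constant
    end
    istree(expr) || return expr
    operation(expr)(fold.(arguments(expr), Ref(vals))...)
end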
Next, we do common subexpression elimination, which just makes sure we don't evaluate the same thing twice.
Soss.cse(ℓ)
11-element Vector{Pair{Symbol, SymbolicUtils.Symbolic}}:
Symbol("##1212") => -σ
Symbol("##1213") => log(σ)
Symbol("##1214") => -10000var"##1213"
Symbol("##1215") => β^2
Symbol("##1216") => -0.5var"##1215"
Symbol("##1217") => -10516.391601146239β
Symbol("##1218") => 9926.088553874832var"##1215"
Symbol("##1219") => 4397.346085970019 + var"##1217" + var"##1218"
Symbol("##1220") => σ^-2
Symbol("##1221") => -0.5var"##1219"*var"##1220"
Symbol("##1222") => var"##1212" + var"##1214" + var"##1216" + var"##1221"
It's a very short step to get from that to a Julia Expr that does what we want:
sourceCodegen(post)
quote
var"##1275" = (*)(-1.0, σ)
var"##1276" = (log)(σ)
var"##1277" = (*)(-10000, var"##1276")
var"##1278" = (^)(β, 2)
var"##1279" = (*)(-0.5, var"##1278")
var"##1280" = (*)(-10516.391601146239, β)
var"##1281" = (*)(9926.088553874832, var"##1278")
var"##1282" = (+)(4397.346085970019, var"##1280", var"##1281")
var"##1283" = (^)(σ, -2)
var"##1284" = (*)(-0.5, var"##1282", var"##1283")
var"##1285" = (+)(var"##1275", var"##1277", var"##1279", var"##1284")
end
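That step is essentially one list comprehension. Assuming SymbolicUtils's toexpr (which lowers a symbolic term to a Julia expression), a sketch looks like:

using SymbolicUtils

# turn the Symbol => Symbolic pairs from CSE into a block of assignments
cse_to_expr(pairs) =
    Expr(:block, [:($name = $(SymbolicUtils.Code.toexpr(rhs))) for (name, rhs) in pairs]...)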
Finally, we add some code to tell us where to get the variables, and use GeneralizedGenerated.jl to make it callable:
codegen(post; ℓ=ℓ)
function = (_args, _data, _pars;) -> begin
begin
β = (Main.FD_SANDBOX_8265385389467078856).getproperty(_pars, :β)
σ = (Main.FD_SANDBOX_8265385389467078856).getproperty(_pars, :σ)
y = (Main.FD_SANDBOX_8265385389467078856).getproperty(_data, :y)
λ = (Main.FD_SANDBOX_8265385389467078856).getproperty(_args, :λ)
x = (Main.FD_SANDBOX_8265385389467078856).getproperty(_args, :x)
var"##1374" = (*)(-1.0, σ)
var"##1375" = (log)(σ)
var"##1376" = (*)(-10000, var"##1375")
var"##1377" = (^)(β, 2)
var"##1378" = (*)(-0.5, var"##1377")
var"##1379" = (*)(-10516.391601146239, β)
var"##1380" = (*)(9926.088553874832, var"##1377")
var"##1381" = (+)(4397.346085970019, var"##1379", var"##1380")
var"##1382" = (^)(σ, -2)
var"##1383" = (*)(-0.5, var"##1381", var"##1382")
var"##1384" = (+)(var"##1374", var"##1376", var"##1378", var"##1383")
end
end
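GeneralizedGenerated's mk_function is what makes the Expr callable without world-age problems. Roughly (a simplified sketch of the pattern, with a stand-in body):

using GeneralizedGenerated

ex = :((_args, _data, _pars) -> begin
    σ = _pars.σ
    β = _pars.β
    -σ - 10000log(σ) - 0.5β^2   # stand-in for the generated block above
end)

f_sketch = mk_function(ex)
f_sketch((x=x, λ=λ), (y=y,), (σ=trace.σ, β=trace.β))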
Accelerating Inference
After calling symlogdensity as above, you can manipulate the result however you like using SymbolicUtils.jl. When you're happy with the result, it's easy to use it in inference. Here it is with Tamas Papp's DynamicHMC.jl:
postsample = dynamicHMC(post; ℓ=ℓ)
On my machine, this takes 2.255 seconds. Not bad for MCMC on 10,000 observations.
Limiting Inlining
In the above example, we took every opportunity for constant folding, as long as the result was a scalar. In some cases, that might be too much. For example, we expect this approach to also accelerate variational inference, in which case we ought to avoid recompiling every time we change the variational parameters.
To account for this, we have a noinline switch that allows specification of variables to leave alone. For example,
ℓλ = symlogdensity(post; noinline=(:λ,))
which results in
log(λ) - 10000log(σ) - 0.5(β^2) - λ*σ - 0.5(4397.346085970023 + (9926.08855387485(β^2)) - (10516.391601146232β))*(σ^-2)
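Since λ stays symbolic, a new value can be substituted cheaply with SymbolicUtils, with no recompilation. Something like this (a sketch, assuming the λ inside ℓλ is a Real-typed Sym):

using SymbolicUtils: Sym, substitute

# Syms compare by name and type, so a fresh Sym{Real}(:λ) matches the λ in ℓλ
ℓ2 = substitute(ℓλ, Dict(Sym{Real}(:λ) => 2.0))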
When Does This Work?
Cases where symlogdensity "works" (in the sense of "doesn't break") are growing quickly; I expect that with some modest effort we can get to a point where every model gives some result that's at least as efficient as the direct approach.
The great speedups we're seeing in this example come in large part from the fact that the normal distribution (for the observations) is an exponential family. This means its sufficient statistics have a fixed dimensionality, independent of the number of observations. The most obvious applicability I see is for generalized linear models.
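To see why this matters, here's the collapse in miniature (plain Distributions.jl, separate from Soss): the summed normal log-density depends on the data only through fixed-size sufficient statistics.

using Distributions

yy = randn(1000)
μ, s = 0.3, 1.2

direct = sum(logpdf.(Normal(μ, s), yy))

# the same value from three numbers, no matter how long yy is
n, s1, s2 = length(yy), sum(yy), sum(abs2, yy)
collapsed = -n*log(s) - n*log(2π)/2 - (s2 - 2μ*s1 + n*μ^2) / (2s^2)

direct ≈ collapsed   # true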
There's still some possibility of getting big speedups outside of exponential families, by rewriting distributions to use exponential families as building blocks. For example, Student's t distribution can be written as a scale mixture of normals, with the variances drawn from an inverse gamma distribution. In this case we'd expect to be able to "sum away" the normal components, so that what we're left with is expressed in terms of inverse gammas.
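That identity is easy to check by simulation (again with Distributions.jl, not Soss):

using Distributions, Statistics

# if σ² ~ InverseGamma(ν/2, ν/2) and y | σ ~ Normal(0, σ), then y ~ TDist(ν)
ν = 5.0
draws = [rand(Normal(0, sqrt(rand(InverseGamma(ν/2, ν/2))))) for _ in 1:10^6]
quantile(draws, 0.9), quantile(TDist(ν), 0.9)   # should agree closely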
HMC Timing
On Twitter, Colin Carroll had a great question:
You mention the timing for DynamicHMC is 2.309 seconds -- do you know what it is without the symbolic simplification? I would expect the log probability does not dominate here, but maybe this also speeds up gradient evaluations?
— Colin Carroll (@colindcarroll) January 26, 2021
[Franklin re-runs the code, so the timings will vary a bit]
before = @elapsed dynamicHMC(post)
after = @elapsed dynamicHMC(post; ℓ=ℓ)
speedup = before / after
@show before, after, speedup
(before, after, speedup) = (26.963195093, 2.257612865, 11.943232389845546)
The result is not nearly as dramatic as we saw for evaluation, but it's still substantial. Like all the best questions, the answer to this one raises plenty more questions, which we'll look into another time.
Related Research
Hakaru, a Haskell-based probabilistic programming language, also has a strong focus on symbolic simplification. Hakaru is excellent work and is much more ambitious in the available transformations, but for our purposes we find the advantages of the Julia language and ecosystem too great to step away from.
Avi Bryant's Scala-based Rainier system is very similar in its goals and the available rewrites, and is more mature than this work. There are likely differences in both extensibility and performance, though that's not yet clear. We'll need to consider this in future work.
Final Thoughts
This work is still in relatively early stages, but I think there's huge potential. If you think this could be helpful for your work, please get in touch!