Symbolic Simplification

January 25, 2021

Upcoming features in Soss.jl include static model simplification. After a one-time compilation cost, evaluating the posterior log-density for many models becomes "constant cost", independent of the number of observations. Bayesian analysis for such models can easily scale to big data.

The symbolic representation of the posterior log-density can also be useful for pedagogical purposes.

First, a Model

Let's start with something simple, say a linear regression with 10,000 observations.

using Soss

N = 10000
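# used inside the model below, but captured from the enclosing scope rather than passed as an argument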

m = @model x, λ begin
    σ ~ Exponential(λ)
    β ~ Normal(0,1) 
    y ~ For(1:N) do j
        Normal(x[j] * β, σ)
    end
    return y
end

Soss models are generative, so we can use one to generate data. This is useful to help make sure the model assumptions are reasonable (fake data should look similar to real data), and it's also handy for little examples like this.

To generate data from m, we need to have values for the input arguments x and λ:

x = randn(N)
λ = 1.0

Given this, we could just call rand(m(x=x,λ=λ)). But this would "forget" the values of σ and β. To instead keep track of these, we can do

trace = simulate(m(x=x, λ=λ)).trace
y = trace.y

As a result, we now have

trace.σ = 0.3974449483739133
trace.β = 0.5236103489251422
mean(y) = -0.009302852896095896
std(y) = 0.6630927798796098

Evaluating the Log-Density (The Punchline)

In Bayesian modeling, we're usually in a situation of having observed y, and wanting to infer what we can about latent parameters like σ and β. So let's construct the posterior distribution. We could write this as post = m(x=x, λ=λ) | (y=y,), or using the shorthand

post = m(; x, λ) | (; y)

To evaluate this, say on the trace from the original sample, we just call

logdensity(post, trace)

which in this case gives

4123.1269713242

Let's see how fast it is.

using BenchmarkTools
slow = @benchmark $logdensity($post, $trace)
BenchmarkTools.Trial: 
  memory estimate:  1.07 MiB
  allocs estimate:  60032
  --------------
  minimum time:     2.879 ms (0.00% GC)
  median time:      3.002 ms (0.00% GC)
  mean time:        3.246 ms (4.65% GC)
  maximum time:     18.201 ms (81.89% GC)
  --------------
  samples:          1540
  evals/sample:     1

Ok, clearly some things can be improved; for example, we should be able to do this without allocating.

But rather than dwell too much on that, let's jump to the punchline:

ℓ = symlogdensity(post)
f = codegen(post; ℓ=ℓ)
fast = @benchmark $f($(;x,λ), $(;y), $trace)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     84.730 ns (0.00% GC)
  median time:      85.937 ns (0.00% GC)
  mean time:        86.678 ns (0.00% GC)
  maximum time:     139.092 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     963

Note the units; we started out in milliseconds, but now we're in nanoseconds. A quick minimum(slow).time / minimum(fast).time shows we've gotten a 33980× speedup.

How it Works

ℓ is a symbolic expression from SymbolicUtils.jl. To build it, we first generate a trace, as we did above. Now we have all of the types, so we can evaluate the log-density again, this time replacing the values with symbolic variables of the appropriate type.

At this point, we have a symbolic expression, and we also have a dictionary of fixed values from the arguments and observed values. So we expand the expression and then walk the result, looking for subexpressions that evaluate to scalars with no free variables. When we find one, we use the dictionary to evaluate it, and replace it with the resulting constant.
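
Here's a rough sketch of that folding idea using SymbolicUtils directly. To be clear, this is just an illustration of the flavor, not Soss's actual implementation, and the variables and dictionary values are made up for the example.

using SymbolicUtils

@syms β::Real x1::Real x2::Real y1::Real y2::Real

# A toy two-observation "log-density", already expanded
toy = -0.5*(y1^2 - 2*x1*y1*β + x1^2*β^2) - 0.5*(y2^2 - 2*x2*y2*β + x2^2*β^2)

# Fixed values coming from the model arguments and the observed data
vals = Dict(x1 => 1.2, x2 => -0.7, y1 => 0.5, y2 => 0.3)

# After substitution, β is the only free variable left, so every purely
# numeric subexpression can be folded down to a constant
simplify(SymbolicUtils.substitute(toy, vals))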

Next, we do common subexpression elimination, which just makes sure we don't evaluate the same thing twice.

Soss.cse(ℓ)
11-element Vector{Pair{Symbol, SymbolicUtils.Symbolic}}:
 Symbol("##1212") => -σ
 Symbol("##1213") => log(σ)
 Symbol("##1214") => -10000var"##1213"
 Symbol("##1215") => β^2
 Symbol("##1216") => -0.5var"##1215"
 Symbol("##1217") => -10516.391601146239β
 Symbol("##1218") => 9926.088553874832var"##1215"
 Symbol("##1219") => 4397.346085970019 + var"##1217" + var"##1218"
 Symbol("##1220") => σ^-2
 Symbol("##1221") => -0.5var"##1219"*var"##1220"
 Symbol("##1222") => var"##1212" + var"##1214" + var"##1216" + var"##1221"

It's a very short step to get from that to a Julia Expr that does what we want:

sourceCodegen(post)
quote
    var"##1275" = (*)(-1.0, σ)
    var"##1276" = (log)(σ)
    var"##1277" = (*)(-10000, var"##1276")
    var"##1278" = (^)(β, 2)
    var"##1279" = (*)(-0.5, var"##1278")
    var"##1280" = (*)(-10516.391601146239, β)
    var"##1281" = (*)(9926.088553874832, var"##1278")
    var"##1282" = (+)(4397.346085970019, var"##1280", var"##1281")
    var"##1283" = (^)(σ, -2)
    var"##1284" = (*)(-0.5, var"##1282", var"##1283")
    var"##1285" = (+)(var"##1275", var"##1277", var"##1279", var"##1284")
end

Finally we add some code to tell us where to get the variables, and use GeneralizedGenerated.jl to make it callable:

codegen(post; ℓ=ℓ)
function = (_args, _data, _pars;) -> begin
    begin
        β = (Main.FD_SANDBOX_8265385389467078856).getproperty(_pars, :β)
        σ = (Main.FD_SANDBOX_8265385389467078856).getproperty(_pars, :σ)
        y = (Main.FD_SANDBOX_8265385389467078856).getproperty(_data, :y)
        λ = (Main.FD_SANDBOX_8265385389467078856).getproperty(_args, :λ)
        x = (Main.FD_SANDBOX_8265385389467078856).getproperty(_args, :x)
        var"##1374" = (*)(-1.0, σ)
        var"##1375" = (log)(σ)
        var"##1376" = (*)(-10000, var"##1375")
        var"##1377" = (^)(β, 2)
        var"##1378" = (*)(-0.5, var"##1377")
        var"##1379" = (*)(-10516.391601146239, β)
        var"##1380" = (*)(9926.088553874832, var"##1377")
        var"##1381" = (+)(4397.346085970019, var"##1379", var"##1380")
        var"##1382" = (^)(σ, -2)
        var"##1383" = (*)(-0.5, var"##1381", var"##1382")
        var"##1384" = (+)(var"##1374", var"##1376", var"##1378", var"##1383")
    end
end

Accelerating Inference

After calling symlogdensity as above, you can manipulate the result however you like using SymbolicUtils.jl. When you're happy with the result, it's easy to use this in inference. Here it is with Tamas Papp's DynamicHMC.jl:

postsample = dynamicHMC(post; ℓ=ℓ)

On my machine, this takes 2.255 seconds. Not bad for MCMC on 10,000 observations.
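
As a sketch of what "manipulate the result" might look like before handing things off to inference, here's a hypothetical extra pass over ℓ; the simplify call just stands in for whatever SymbolicUtils rewrite you want to apply.

using SymbolicUtils

ℓ2 = SymbolicUtils.simplify(ℓ)         # any symbolic rewrite you like
f2 = codegen(post; ℓ=ℓ2)               # compile the modified expression
postsample2 = dynamicHMC(post; ℓ=ℓ2)   # or hand it straight to inference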

Limiting Inlining

In the above example, we took every opportunity for constant folding, as long as the result is a scalar. In some cases, that might be too much. For example, we expect to be able to use this approach to also accelerate variational inference, in which case we ought to avoid recompiling every time we change the variational parameters.

To account for this, we have a noinline switch that allows specification of variables to leave alone. For example,

ℓλ = symlogdensity(post; noinline=(:λ,))

which results in

log(λ) - 10000log(σ) - 0.5(β^2) - λ*σ - 0.5(4397.346085970023 + (9926.08855387485(β^2)) - (10516.391601146232β))*(σ^-2)

When Does This Work?

Cases where symlogdensity "works" in the sense of "doesn't break" are growing quickly; I expect that with some modest effort we can get to a point where every model gives some result that's at least as efficient as the direct approach.

The great speedups we're seeing in this example come in large part from the fact that the normal distribution (for the observations) is an exponential family. This means the sufficient statistics have a fixed dimensionality, independent of the number of observations. The most obvious applicability I see is for generalized linear models.
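
To make that concrete: expanding the Gaussian log-likelihood (and dropping additive constants) shows the data entering only through a few fixed sums,

$$
\sum_{j=1}^{N} \log \mathcal{N}(y_j \mid x_j \beta, \sigma)
= -N \log \sigma - \frac{1}{2\sigma^2} \left( \sum_j y_j^2 - 2\beta \sum_j x_j y_j + \beta^2 \sum_j x_j^2 \right) + \text{const}.
$$

Those three sums are computed once during simplification; up to rounding, they are the numeric constants that appeared in the common subexpression output above, which is why the compiled log-density has the same cost no matter how large N gets.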

There's still some possibility of getting big speedups outside of exponential families, by rewriting distributions to use exponential families as building blocks. For example, Student's T distribution can be written as a scale mixture of normals, with the variance of each component drawn from an inverse gamma distribution. In that case we'd expect to be able to "sum away" the normal components, leaving an expression in terms of inverse gammas.
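
For reference, the standard scale-mixture representation (an integral rather than a literal sum, with s playing the role of the variance) is

$$
t_\nu(y) = \int_0^\infty \mathcal{N}\left(y \mid 0, \sqrt{s}\right) \, \mathrm{InvGamma}\!\left(s \mid \tfrac{\nu}{2}, \tfrac{\nu}{2}\right) \, ds.
$$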

HMC Timing

On Twitter, Colin Carroll asked a great question about how much of this evaluation speedup carries over to end-to-end HMC sampling. Here's a quick comparison:

[Franklin re-runs the code, so the timings will vary a bit]

before = @elapsed dynamicHMC(post)
after = @elapsed dynamicHMC(post; ℓ=ℓ)

speedup = before / after

@show before, after, speedup
(before, after, speedup) = (26.963195093, 2.257612865, 11.943232389845546)

The result is not nearly as dramatic as we saw for evaluation, but it's still substantial. Like all the best questions, the answer to this one raises plenty more questions, which we'll look into another time.

Related Work

Hakaru, a Haskell-based probabilistic programming language, also has a strong focus on symbolic simplification. Hakaru is excellent work and is much more ambitious in the available transformations, but for our purposes we find the advantages of the Julia language and ecosystem too great to step away from.

Avi Bryant's Scala-based Rainier system is very similar in its goals and the available rewrites, and is more mature than this work. There are likely differences in both extensibility and performance, though that's not yet clear. We'll need to consider this in future work.

Final Thoughts

This work is still in relatively early stages, but I think there's huge potential. If you think this can be helpful for your work, please get in touch!