© 2020, The PyMC Development Team
In this post, we’d like to announce where PyMC is headed, how we got here, and why we chose this direction.
TL;DR: PyMC3 on Theano with the new JAX backend is the future; PyMC4, which is based on TensorFlow Probability, will not be developed further.
In 2017, the original authors of Theano announced that they would stop development of their excellent library. This left PyMC3, which relies on Theano as its computational backend, in a difficult position and prompted us to start work on PyMC4, which is based on TensorFlow instead. Through this process, we learned that building an interactive probabilistic programming library in TF was not as easy as we had thought (more on that below).
In parallel, in an effort to extend the life of PyMC3, we took over maintenance of Theano from the Mila team; the project is now hosted under the name Theano-PyMC. Working with the Theano code base, we realized that everything we needed was already present. Moreover, we saw that we could extend the code base in promising ways, such as by adding support for new execution backends like JAX.
Here is the idea: Theano builds up a static computational graph of operations (“Ops”) to perform in sequence. This graph structure is very useful for many reasons: you can optimize it by fusing computations or by replacing certain operations with numerically more stable alternatives. Critically, you can then take that graph and compile it to different execution backends. By default, Theano supports two execution backends (i.e., implementations of Ops): Python and C.
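To make the graph-rewriting idea concrete, here is a minimal sketch in plain Python. It is not Theano’s API: the `Var`, `Const`, and `Apply` classes and the `rewrite_log1p` and `evaluate` functions are hypothetical. The sketch builds a static graph for log(1 + x), rewrites it into the numerically more stable log1p(x) before anything is executed, and then evaluates both versions.

```python
import math
from dataclasses import dataclass

# Hypothetical node types for a toy static graph (not Theano's classes).
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Apply:
    op: str        # operation name, e.g. "add" or "log"
    inputs: tuple  # input nodes

def rewrite_log1p(node):
    """Recursively replace log(1 + x) with the more stable log1p(x)."""
    if not isinstance(node, Apply):
        return node
    inputs = tuple(rewrite_log1p(i) for i in node.inputs)
    if node.op == "log" and isinstance(inputs[0], Apply) and inputs[0].op == "add":
        a, b = inputs[0].inputs
        if a == Const(1.0):
            return Apply("log1p", (b,))
        if b == Const(1.0):
            return Apply("log1p", (a,))
    return Apply(node.op, inputs)

# Python implementations for each op name.
IMPLS = {"add": lambda a, b: a + b, "log": math.log, "log1p": math.log1p}

def evaluate(node, env):
    """Walk the graph and compute its value for the given variable bindings."""
    if isinstance(node, Var):
        return env[node.name]
    if isinstance(node, Const):
        return node.value
    return IMPLS[node.op](*(evaluate(i, env) for i in node.inputs))

x = Var("x")
graph = Apply("log", (Apply("add", (Const(1.0), x)),))  # log(1 + x)
optimized = rewrite_log1p(graph)                         # log1p(x)
print(evaluate(graph, {"x": 1e-20}))      # 0.0, because 1 + 1e-20 rounds to 1.0
print(evaluate(optimized, {"x": 1e-20}))  # approximately 1e-20
```

Because the whole graph exists before evaluation, a rewrite like this can be applied once, up front, without the user touching their model code.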
The Python backend is understandably slow, as it just runs your graph using mostly NumPy functions chained together. For speed, Theano therefore relies on its C backend (mostly written against the CPython and NumPy C APIs). After the graph has been transformed and simplified, each Op is translated into its C counterpart, the resulting C source files are compiled into a shared library, and Python then calls that library.
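The C backend follows a generate-then-compile pattern. As a rough analogy only, the sketch below generates Python source (rather than C) from a toy graph and compiles it once with the built-in `compile()` and `exec()`, standing in for Theano’s generation of C source and its compilation into a shared library. The tuple-based graph format, `TEMPLATES`, `emit`, and `compile_graph` are all hypothetical.

```python
import math

# Hypothetical toy graph: nodes are (op, *inputs); leaves are variable names.
# This graph represents exp(x) * y + x.
GRAPH = ("add", ("mul", ("exp", "x"), "y"), "x")

# Per-Op source templates, analogous to each Op's C implementation.
TEMPLATES = {"add": "({} + {})", "mul": "({} * {})", "exp": "math.exp({})"}

def emit(node):
    """Translate the graph into a single source-code expression."""
    if isinstance(node, str):
        return node
    op, *inputs = node
    return TEMPLATES[op].format(*(emit(i) for i in inputs))

def compile_graph(graph, arg_names):
    """Generate the source of a function and compile it once, up front."""
    src = f"def f({', '.join(arg_names)}):\n    return {emit(graph)}\n"
    namespace = {"math": math}
    exec(compile(src, "<generated>", "exec"), namespace)
    return src, namespace["f"]

src, f = compile_graph(GRAPH, ["x", "y"])
print(src)
print(f(0.0, 2.0))  # exp(0) * 2 + 0 = 2.0
```

The translation and compilation cost is paid once; afterwards, calling `f` runs the generated code directly instead of walking the graph on every call.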
While this is quite fast, maintaining the C backend is a considerable burden. More importantly, it cuts Theano off from all the amazing developments in compiler technology (e.g., XLA) and processor architecture (e.g., TPUs), because we would have to hand-write C code for those as well.
The solution to this problem turned out to be relatively straightforward: compile the Theano graph to other modern tensor computation libraries. For JAX, we just need to provide a JAX implementation for each Theano Op. Since JAX’s API is nearly identical to that of NumPy/SciPy, this turned out to be surprisingly simple, and we had a working prototype within a few days.
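The sketch below illustrates why adding a backend mostly amounts to writing one implementation per Op. It is a hypothetical toy, not Theano-PyMC’s actual linker: `Exp`, `Add`, `scalar_impl`, `list_impl`, and `link` are invented names, `functools.singledispatch` is assumed as the dispatch mechanism, and a list-based backend stands in for JAX so the example runs with the standard library alone.

```python
import math
from functools import singledispatch

# Hypothetical Op classes, standing in for Theano Ops.
class Exp: pass
class Add: pass

# Hypothetical graph: (op, input_names, output_name) in topological order.
GRAPH = [(Exp(), ("x",), "e"), (Add(), ("e", "y"), "out")]

# Backend 1: scalar Python floats.
@singledispatch
def scalar_impl(op):
    raise NotImplementedError(f"No scalar implementation for {type(op).__name__}")

@scalar_impl.register
def _(op: Exp):
    return math.exp

@scalar_impl.register
def _(op: Add):
    return lambda a, b: a + b

# Backend 2: element-wise over Python lists (a stand-in for an array library).
@singledispatch
def list_impl(op):
    raise NotImplementedError(f"No list implementation for {type(op).__name__}")

@list_impl.register
def _(op: Exp):
    return lambda xs: [math.exp(v) for v in xs]

@list_impl.register
def _(op: Add):
    return lambda xs, ys: [a + b for a, b in zip(xs, ys)]

def link(graph, impl):
    """Translate every Op with the chosen backend, then chain the results."""
    steps = [(impl(op), ins, out) for op, ins, out in graph]
    def run(**inputs):
        env = dict(inputs)
        for fn, ins, out in steps:
            env[out] = fn(*(env[name] for name in ins))
        return env[graph[-1][2]]
    return run

f_scalar = link(GRAPH, scalar_impl)
f_list = link(GRAPH, list_impl)
print(f_scalar(x=0.0, y=1.0))               # 2.0
print(f_list(x=[0.0, 0.0], y=[1.0, 2.0]))  # [2.0, 3.0]
```

The graph itself never changes; only the per-Op translation does. A backend that lacks an implementation for some Op fails loudly with `NotImplementedError` instead of silently computing something else.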
Most PyMC3 models already work with the current master branch of Theano-PyMC using our NUTS and SMC samplers. In our limited experiments on small models, the C backend is still a bit faster than the JAX one, but we anticipate further performance improvements. To take full advantage of JAX, we need to convert the sampling functions into JAX-jittable functions as well.
This is where things become really interesting. We first compile a PyMC3 model to JAX using the new JAX linker in Theano. We can then take the resulting JAX graph (at this point no Theano- or PyMC3-specific code remains, just a JAX function that computes the model’s log-probability, or logp) and pass it to existing JAX implementations of other MCMC samplers, such as those in TFP and NumPyro.
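The key point is that, once compiled, the model is just a function from parameter values to logp, so any sampler that accepts such a callable can use it. The sketch below assumes a hypothetical one-parameter standard-normal model and, for brevity, swaps in a simple random-walk Metropolis sampler instead of the TFP and NumPyro samplers mentioned above. Everything here is plain Python; no JAX, Theano, or PyMC3 is involved.

```python
import math
import random

def logp(theta):
    # Hypothetical model: standard normal log-density, up to an additive constant.
    return -0.5 * theta * theta

def random_walk_metropolis(logp_fn, init, n_draws, step=1.0, seed=0):
    """Sample using only a logp callable; knows nothing about how logp was built."""
    rng = random.Random(seed)
    current, current_lp = init, logp_fn(init)
    draws = []
    for _ in range(n_draws):
        proposal = current + rng.gauss(0.0, step)  # symmetric Gaussian proposal
        proposal_lp = logp_fn(proposal)
        # Metropolis rule: accept with probability min(1, exp(proposal_lp - current_lp)).
        if rng.random() < math.exp(min(0.0, proposal_lp - current_lp)):
            current, current_lp = proposal, proposal_lp
        draws.append(current)
    return draws

draws = random_walk_metropolis(logp, init=0.0, n_draws=5000)
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
print(f"mean={mean:.3f}, var={var:.3f}")  # should be close to 0 and 1
```

In the workflow described in this post, both the logp function and the sampler would be JAX code, so they can be compiled together; this sketch only illustrates the narrow interface between model and sampler.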
The result: the sampler and model are together fully compiled into a unified JAX graph that can be executed on CPU, GPU, or TPU. The speed in these first experiments is incredible and totally blows our Python-based samplers out of the water.
With the ability to compile Theano graphs to JAX and the availability of JAX-based MCMC samplers, we are at the cusp of a major transformation of PyMC3. Without any changes to the PyMC3 code base, we can switch our backend to JAX and use external JAX-based samplers for lightning-fast sampling of small-to-huge models.
The coolest part is that you, as a user, won’t have to change anything in your existing PyMC3 model code to run it on a modern backend and modern hardware with JAX-ified samplers, getting amazing speed-ups for free.
In addition, with PyTorch and TF focused on dynamic graphs, there is currently no other good static graph library in Python. Static graphs, however, have many advantages over dynamic graphs, such as the whole-graph optimizations and backend compilation described above. We therefore believe that Theano has a bright future ahead of it as a mature, powerful library with an accessible graph representation that can be modified in all kinds of interesting ways and executed on various modern backends.
PyMC4, which is based on TensorFlow, will not be developed further. It was a very interesting and worthwhile experiment that let us learn a lot, but the main obstacle was TensorFlow’s eager mode, along with a variety of technical issues that we could not resolve ourselves. In probabilistic programming, having a static graph of the global state which you can compile and modify is a great strength, as we explained above; Theano is the perfect library for this.
We would like to express our gratitude to the users and developers who supported us during our exploration of PyMC4, especially all the GSoC students who contributed features and bug fixes to the libraries and explored what could be done with a functional modeling approach. We would also like to thank Rif A. Saurous and the TensorFlow Probability team, who sponsored two developer summits for us, with many fruitful discussions. We believe these efforts will not be lost; they have given us insight into building a better PPL, and we look forward to incorporating these ideas into future versions of PyMC3.
We want your help
This is a really exciting time for PyMC3 and Theano. If you want to have an impact, this is the perfect time to get involved. You can check out the low-hanging fruit on the Theano and PyMC3 repos. We look forward to your pull requests.
If you are looking for professional help with Bayesian modeling, we recently launched a PyMC3 consultancy; get in touch at email@example.com.