The Future of PyMC3, or: Theano is Dead, Long Live Theano

© 2020, The PyMC Development Team

PyMC + Theano + JAX

In this post we’d like to make a major announcement about where PyMC is headed, how we got here, and our reasons for taking this direction.

TL;DR: PyMC3 on Theano with the new JAX backend is the future, PyMC4 based on TensorFlow Probability will not be developed further.

In 2017, the original authors of Theano announced that they would stop development of their excellent library. This left PyMC3, which relies on Theano as its computational backend, in a difficult position and prompted us to start work on PyMC4 which is based on TensorFlow instead. Through this process, we learned that building an interactive probabilistic programming library in TF was not as easy as we thought (more on that below).

In parallel to this, in an effort to extend the life of PyMC3, we took over maintenance of Theano from the Mila team, hosted under Theano-PyMC. Working with the Theano code base, we realized that everything we needed was already present. Moreover, we saw that we could extend the code base in promising ways, such as by adding support for new execution backends like JAX.

Here is the idea: Theano builds up a static computational graph of operations (“Ops”) to perform in sequence. This graph structure is very useful for many reasons: you can apply optimizations by fusing computations or replacing certain operations with numerically more stable alternatives. Critically, you can then take that graph and compile it to different execution backends. By default, Theano supports two execution backends (i.e. implementations of its Ops): Python and C.
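To make the static-graph idea concrete, here is a minimal toy sketch (not Theano’s actual API): operations are first recorded as graph nodes, a rewrite pass can then transform the graph, and only afterwards is anything executed. The `stabilize` rewrite below, which replaces log(1 + x) with the numerically more stable log1p(x), is a hypothetical stand-in for the kind of stabilization pass Theano performs.

```python
class Op:
    """A graph node: an operation name plus its inputs (which may be
    other Ops, variables, or constants). Nothing is computed yet."""
    def __init__(self, name, *inputs):
        self.name, self.inputs = name, inputs

    def __repr__(self):
        return f"{self.name}({', '.join(map(repr, self.inputs))})"

def log(x):
    return Op("log", x)

def add(x, y):
    return Op("add", x, y)

def stabilize(node):
    """Example graph rewrite: replace log(1 + x) with log1p(x),
    analogous in spirit to Theano's stabilization optimizations."""
    if not isinstance(node, Op):
        return node
    node = Op(node.name, *(stabilize(i) for i in node.inputs))
    if node.name == "log" and node.inputs:
        inner = node.inputs[0]
        if isinstance(inner, Op) and inner.name == "add" and inner.inputs[0] == 1:
            return Op("log1p", inner.inputs[1])
    return node

graph = log(add(1, "x"))      # build the graph; no computation happens
print(stabilize(graph))       # log1p('x')
```

Because the whole graph exists before execution, a rewrite like this can be applied globally, and the final, optimized graph can then be handed to whichever backend actually runs it.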

The Python backend is understandably slow, as it simply runs your graph using mostly NumPy functions chained together. For speed, Theano therefore relies on its C backend. After graph transformation and simplification, the resulting Ops are compiled into their appropriate C analogues, the resulting C source files are compiled into a shared library, and that library is then called from Python.

While this is quite fast, maintaining the C backend is quite a burden. More importantly, however, it cuts Theano off from all the amazing developments in compiler technology (e.g. XLA) and processor architecture (e.g. TPUs), as we would have to hand-write C code for those too.

The solution to this problem turned out to be relatively straightforward: compile the Theano graph to other modern tensor computation libraries. We just need to provide a JAX implementation for each Theano Op. Since JAX’s API is nearly identical to that of NumPy/SciPy, this turned out to be surprisingly simple, and we had a working prototype within a few days.
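The core of such a backend can be sketched as a dispatch table from Op names to backend functions plus a recursive graph walker. This is a hedged illustration, not Theano’s real machinery: stdlib `math`/`operator` functions stand in here so the example is self-contained, but in the actual JAX backend each Op would map to its `jax.numpy`/`jax.scipy` counterpart in exactly this spirit.

```python
import math
import operator

# Hypothetical dispatch table: Op name -> backend implementation.
# Swapping these values for jax.numpy / jax.scipy functions would
# retarget the same graph to JAX, which is what made the prototype quick.
DISPATCH = {
    "add": operator.add,
    "mul": operator.mul,
    "exp": math.exp,
    "log1p": math.log1p,
}

def evaluate(node, env):
    """Recursively evaluate a graph of (op_name, *inputs) tuples.
    Strings are variable references looked up in `env`; anything
    else is treated as a literal constant."""
    if isinstance(node, tuple):
        fn = DISPATCH[node[0]]
        return fn(*(evaluate(arg, env) for arg in node[1:]))
    if isinstance(node, str):
        return env[node]
    return node

# softplus(x) = log1p(exp(x)); at x = 0 this equals log(2).
value = evaluate(("log1p", ("exp", "x")), {"x": 0.0})
print(value)  # 0.6931...
```

The key design point is that the graph itself never changes: only the dispatch table decides which library ultimately executes each Op.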

Currently, most PyMC3 models already work with the current master branch of Theano-PyMC using our NUTS and SMC samplers. In our limited experiments on small models, the C-backend is still a bit faster than the JAX one, but we anticipate further improvements in performance. To take full advantage of JAX, we need to convert the sampling functions into JAX-jittable functions as well.

JAX-based Samplers

The result: the sampler and model are together fully compiled into a unified JAX graph that can be executed on CPU, GPU, or TPU. The speed in these first experiments is incredible and totally blows our Python-based samplers out of the water.
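What “jittable” means in practice is that the sampler must be written as a pure function of its inputs, with all state (including the random number generator) passed in explicitly. The toy random-walk Metropolis step below illustrates that functional shape; it is not PyMC3’s actual NUTS sampler, and stdlib `random`/`math` stand in for `jax.random`/`jax.numpy` (real JAX code would also replace the Python `if` with `jnp.where` or `lax.cond`).

```python
import math
import random

def metropolis_step(x, logp, step_size, rng):
    """One random-walk Metropolis step as a pure function of its inputs.
    In JAX, `rng` would be an explicit PRNG key and the branch would be
    expressed with jnp.where so the whole step could be jit-compiled."""
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    accept_logprob = logp(proposal) - logp(x)
    if math.log(rng.random()) < accept_logprob:
        return proposal
    return x

# Sample from a standard normal, whose log-density up to a constant
# is -x**2 / 2.
rng = random.Random(0)
x, draws = 0.0, []
for _ in range(5000):
    x = metropolis_step(x, lambda v: -v * v / 2.0, 1.0, rng)
    draws.append(x)

print(sum(draws) / len(draws))  # roughly 0 for a standard normal
```

Because every step is a pure function, a JAX `jit`/`scan` over such steps can fuse the model’s log-density and the sampler’s transition logic into one compiled graph, which is where the large speed-ups come from.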

The Future

The coolest part is that you, as a user, won’t have to change anything in your existing PyMC3 model code to run your models on a modern backend, on modern hardware, and with JAX-ified samplers, and get amazing speed-ups for free.

In addition, with PyTorch and TF being focused on dynamic graphs, there is currently no other good static graph library in Python. Static graphs, however, have many advantages over dynamic graphs. We thus believe that Theano has a bright future ahead of it as a mature, powerful library with an accessible graph representation that can be modified in all kinds of interesting ways and executed on various modern backends.

Whither PyMC4?

We would like to express our gratitude to the users and developers who joined us during our exploration of PyMC4, especially to all the GSoC students who contributed features and bug fixes to the libraries and explored what could be done with a functional modeling approach. We would also like to thank Rif A. Saurous and the TensorFlow Probability team, who sponsored two developer summits for us that led to many fruitful discussions. We believe these efforts will not be lost; they have given us insight into building a better PPL, and we look forward to incorporating these ideas into future versions of PyMC3.

We want your help

If you are looking for professional help with Bayesian modeling, we recently launched a PyMC3 consultancy; get in touch at thomas.wiecki@pymc-labs.io.
