The Future of PyMC3, or: Theano is Dead, Long Live Theano

PyMC + Theano + JAX

JAX-based Samplers

This is where things become really interesting. We first compile a PyMC3 model to JAX using the new JAX linker in Theano. We can then take the resulting JAX graph (at this point no Theano- or PyMC3-specific code remains, just a JAX function that computes the model's log-probability, or logp) and pass it to existing JAX implementations of other MCMC samplers, such as those in TensorFlow Probability (TFP) and NumPyro.
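To make the idea concrete, here is a minimal sketch in plain JAX of "a logp function handed to a JAX sampler". It assumes the compiled model is simply a function that maps parameter values to a log-probability. `make_logp` is a hypothetical, hand-written stand-in for what the Theano JAX linker would emit for a toy model (mu ~ Normal(0, 10), y_i ~ Normal(mu, 1)), and `random_walk_metropolis` is a deliberately simple random-walk Metropolis sampler swapped in for the NUTS/HMC implementations in TFP and NumPyro. Neither function is PyMC3, Theano, TFP, or NumPyro API.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the JAX logp function produced by compiling a model:
# unnormalized log-density of mu ~ Normal(0, 10), y_i ~ Normal(mu, 1).
def make_logp(y):
    def logp(mu):
        prior = -0.5 * (mu / 10.0) ** 2          # Normal(0, 10) prior, constants dropped
        likelihood = -0.5 * jnp.sum((y - mu) ** 2)  # Normal(mu, 1) likelihood
        return prior + likelihood
    return logp


def random_walk_metropolis(logp, init, key, num_samples, step_size=0.5):
    """Minimal random-walk Metropolis sampler; it only needs a JAX-traceable logp."""

    def step(state, key):
        position, current_logp = state
        key_prop, key_accept = jax.random.split(key)
        # Gaussian random-walk proposal around the current position
        proposal = position + step_size * jax.random.normal(key_prop, position.shape)
        proposal_logp = logp(proposal)
        # Metropolis acceptance test in log space
        log_u = jnp.log(jax.random.uniform(key_accept))
        accept = log_u < proposal_logp - current_logp
        new_position = jnp.where(accept, proposal, position)
        new_logp = jnp.where(accept, proposal_logp, current_logp)
        return (new_position, new_logp), new_position

    # One PRNG key per iteration; lax.scan compiles the loop body once
    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(step, (init, logp(init)), keys)
    return samples


# Usage: draw samples for a small synthetic data set
y = jnp.array([1.8, 2.2, 2.0, 1.9, 2.1])
init = jnp.asarray(0.0, dtype=jnp.float32)
samples = random_walk_metropolis(make_logp(y), init, jax.random.PRNGKey(0), num_samples=5000)
print(samples[1000:].mean())  # posterior-mean estimate after burn-in; should be near mean(y)
```

Because the sampler only sees a JAX function, it does not care where that function came from. The same `logp` could in principle be handed to a NumPyro or TFP sampler instead; NumPyro's NUTS, for example, accepts a `potential_fn`, which is the negative log-probability.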

The Future

With the ability to compile Theano graphs to JAX and the availability of JAX-based MCMC samplers, we are on the cusp of a major transformation of PyMC3. Without any changes to the PyMC3 code base, we can switch our backend to JAX and use external JAX-based samplers for fast sampling of models ranging from small to very large.

Whither PyMC4?

PyMC4, which is based on TensorFlow, will not be developed further. It was a very interesting and worthwhile experiment from which we learned a lot, but the main obstacle was TensorFlow’s eager mode, along with a variety of technical issues that we could not resolve ourselves. In probabilistic programming, having a static graph of the global state that you can compile and modify is a great strength, as we explained above, and Theano is the perfect library for this.

We want your help

This is a really exciting time for PyMC3 and Theano. If you want to have an impact, this is the perfect time to get involved. You can check out the low-hanging fruit on the Theano and PyMC3 repos. We look forward to your pull requests.

PyMC Developers

Probabilistic Programming in Python