This is where things get really interesting. We first compile a PyMC3 model to JAX using the new JAX linker in Theano. At that point no Theano- or PyMC3-specific code remains; what we have is just a JAX function that computes the model's log-probability (logp). We can then pass this function to existing JAX-based MCMC samplers from TensorFlow Probability (TFP) and NumPyro.
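Because the compiled logp is just a JAX function of the model's parameters, any JAX sampler that accepts such a function can consume it. The sketch below illustrates that hand-off under stated assumptions: the hand-written `logp` for a toy normal model (with made-up data) stands in for the function Theano's JAX linker would emit, and a simple random-walk Metropolis loop stands in for the TFP and NumPyro samplers. It is not their API and not PyMC3's implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical observed data for a toy model: y ~ Normal(mu, 1), mu ~ Normal(0, 10).
data = jnp.array([0.8, 1.2, 1.1, 0.9, 1.4])


def logp(mu):
    """Unnormalized log-density of mu (a stand-in for a compiled model logp)."""
    prior = -0.5 * (mu / 10.0) ** 2
    likelihood = -0.5 * jnp.sum((data - mu) ** 2)
    return prior + likelihood


def rw_metropolis(logp_fn, init, key, num_samples, step_size=0.5):
    """Random-walk Metropolis, used here only as a stand-in for TFP/NumPyro samplers."""

    def step(carry, key):
        x, lp = carry
        key_prop, key_accept = jax.random.split(key)
        # Propose a Gaussian perturbation of the current state.
        proposal = x + step_size * jax.random.normal(key_prop)
        lp_prop = logp_fn(proposal)
        # Accept with probability min(1, exp(lp_prop - lp)).
        accept = jnp.log(jax.random.uniform(key_accept)) < lp_prop - lp
        x_new = jnp.where(accept, proposal, x)
        lp_new = jnp.where(accept, lp_prop, lp)
        return (x_new, lp_new), x_new

    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(step, (init, logp_fn(init)), keys)
    return samples


# jit-compile the whole sampling loop; the logp function and sample count are static.
sample = jax.jit(rw_metropolis, static_argnums=(0, 3))
samples = sample(logp, jnp.zeros(()), jax.random.PRNGKey(0), 5000)

# Discard the first 1000 draws as warm-up before summarizing.
posterior_mean = samples[1000:].mean()
print(posterior_mean)
```

Since the entire loop is traced and compiled by JAX, nothing Theano- or PyMC3-specific is involved at sampling time. Using a production sampler instead would mean handing the same function to it; NumPyro's `NUTS` kernel, for example, accepts a negated log-density through its `potential_fn` argument.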

With the ability to compile Theano graphs to JAX and the availability of JAX-based MCMC samplers, we are on the cusp of a major transformation of PyMC3. Without any changes to the PyMC3 code base, we can switch our backend to JAX and use external JAX-based samplers for lightning-fast sampling of models ranging from small to huge.

PyMC4, which is based on TensorFlow, will not be developed further. It was a very interesting and worthwhile experiment that taught us a lot, but the main obstacles were TensorFlow's eager mode and a variety of technical issues that we could not resolve ourselves. In probabilistic programming, having a static graph of the global state that you can compile and modify is a great strength, as we explained above, and Theano is the perfect library for this.

This is a really exciting time for PyMC3 and Theano. If you want to have an impact, this is the perfect time to get involved. Check out the low-hanging fruit in the Theano and PyMC3 repositories; we look forward to your pull requests.



