vijax
A simple JAX library to easily run VI
vijax is a flexible and modular library for variational inference (VI) that aims to make VI algorithms more accessible. It uses JAX, NumPyro, TensorFlow Probability, and other related libraries under the hood. I built this library as part of my AISTATS 2025 paper. Please check out the repository for full details. Read below for a brief overview.
Design philosophy
The library is designed to be accessible to a wide range of users and applications, enabling them to perform VI without getting tied down by specific abstractions. There are three main components of vijax:
- Model: Represents a probability model $p(z,x)$, where $z$ are latent variables and $x$ are observed data.
- VarDist: Represents a variational distribution $q_w(z)$.
- Recipe: Provides a high-level interface for running pre-defined VI algorithms.
By adhering to these abstractions, users can easily plug in their own optimization routines, models, and variational distributions, while still benefiting from the core features and utilities provided by the vijax library.
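To make the roles of the three components concrete, here is a minimal sketch of how a model $p(z,x)$, a variational distribution $q_w(z)$, and a routine that combines them fit together. It is not the vijax API: the names GaussianMeanModel, GaussianVarDist, and estimate_elbo are hypothetical, and the code uses plain Python with only the standard library (rather than JAX) so that it runs without any dependencies. It estimates the evidence lower bound (ELBO), $\mathbb{E}_{q_w(z)}[\log p(z,x) - \log q_w(z)]$, by Monte Carlo on a toy conjugate Gaussian model.

```python
import math
import random


def normal_logpdf(v, mean, std):
    """Log density of a univariate normal N(mean, std^2) at v."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((v - mean) / std) ** 2


class GaussianMeanModel:
    """Hypothetical toy model: z ~ N(0, 1), x_i | z ~ N(z, 1)."""

    def __init__(self, data):
        self.data = list(data)

    def log_joint(self, z):
        # log p(z, x) = log p(z) + sum_i log p(x_i | z)
        log_prior = normal_logpdf(z, 0.0, 1.0)
        log_lik = sum(normal_logpdf(x, z, 1.0) for x in self.data)
        return log_prior + log_lik


class GaussianVarDist:
    """Hypothetical variational family q_w(z) = N(mean, std^2), w = (mean, std)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def sample(self, rng):
        # Reparameterized draw: z = mean + std * eps, with eps ~ N(0, 1)
        return self.mean + self.std * rng.gauss(0.0, 1.0)

    def log_prob(self, z):
        return normal_logpdf(z, self.mean, self.std)


def estimate_elbo(model, vardist, num_samples, seed=0):
    """Monte Carlo estimate of E_q[log p(z, x) - log q(z)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = vardist.sample(rng)
        total += model.log_joint(z) - vardist.log_prob(z)
    return total / num_samples


data = [0.5, 1.5, 1.0]
model = GaussianMeanModel(data)
n = len(data)

# For this conjugate model the exact posterior is N(sum(x) / (n + 1), 1 / (n + 1)).
exact_posterior = GaussianVarDist(sum(data) / (n + 1), math.sqrt(1.0 / (n + 1)))
# A deliberately poor choice of q: the prior itself.
prior_as_q = GaussianVarDist(0.0, 1.0)

elbo_exact = estimate_elbo(model, exact_posterior, 1000)
elbo_prior = estimate_elbo(model, prior_as_q, 1000)
print("ELBO with q = exact posterior:", elbo_exact)
print("ELBO with q = prior:", elbo_prior)
```

When q equals the exact posterior, every term $\log p(z,x) - \log q(z)$ equals the log evidence $\log p(x)$, so the estimate has no Monte Carlo noise; any other q gives a lower ELBO in expectation. A VI algorithm searches over the parameters $w$ to push the ELBO up toward that bound.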
For more information, see the repository.
Installation
pip install vijax
Citing
If you use vijax, please consider citing:
@inproceedings{agrawal2024disentangling,
title={Disentangling impact of capacity, objective, batchsize, estimators, and step-size on flow VI},
author={Agrawal, Abhinav and Domke, Justin},
booktitle={AISTATS},
year={2025},
}