Installation
Requires Python 3.11–3.13.
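If you want a script to fail fast on an unsupported interpreter before it tries to install anything, a check like the one below works. It is a minimal sketch. The helper name `python_supported` is hypothetical and not part of jNO, and the bounds come from the supported range stated above.

```python
import sys

# Supported range stated above: Python 3.11 through 3.13 (inclusive).
MIN_VERSION = (3, 11)
MAX_VERSION = (3, 13)


def python_supported(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) interpreter version is in the supported range."""
    major_minor = tuple(version_info[:2])
    return MIN_VERSION <= major_minor <= MAX_VERSION


if __name__ == "__main__":
    current = ".".join(map(str, sys.version_info[:3]))
    status = "supported" if python_supported() else "NOT supported"
    print(f"Python {current} is {status} (need 3.11-3.13)")
```

Only the major and minor components are compared, so every patch release of 3.11, 3.12 and 3.13 passes.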
PyPI
GPU support is included by default: jax-neural-operators (jNO) depends on
jax[cuda]>=0.10.1,<0.11, so a standard pip install already pulls a
CUDA-capable JAX wheel. The pin is tight on purpose. Each jNO release
tracks a single JAX minor version, so that JAX API changes cannot silently
break it.
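To see whether an existing environment already satisfies this pin, you can compare the installed `jax` distribution against the range. The sketch below does this with the standard library only. The helpers `parse_release` and `jax_version_in_pin` are hypothetical, not jNO functions. The check compares release numbers only and ignores pre-release and local suffixes, so it is coarser than pip's own PEP 440 resolution. `packaging.specifiers.SpecifierSet` would be the precise tool if that package is available.

```python
import re
from importlib import metadata

# Pin taken from the text above: jax[cuda]>=0.10.1,<0.11
LOWER = (0, 10, 1)  # inclusive lower bound
UPPER = (0, 11, 0)  # exclusive upper bound


def parse_release(version: str) -> tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '0.10.1' or '0.10.2rc1'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    # A missing patch component (e.g. '0.10') is treated as 0.
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def jax_version_in_pin(version: str) -> bool:
    """Return True if the release part of `version` lies in [0.10.1, 0.11)."""
    return LOWER <= parse_release(version) < UPPER


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        ok = jax_version_in_pin(installed)
        print(f"jax {installed}: {'within' if ok else 'outside'} the pinned range")
```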
If you need a specific CUDA version, install the matching JAX CUDA extra first, then install jNO:
# CUDA 12 example
pip install --upgrade "jax[cuda12]>=0.10.1,<0.11"
pip install "jax-neural-operators[fem]"
To pin a different JAX version locally for an experiment, see the JAX install matrix.
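After installing, you can confirm that JAX actually picked up the GPU rather than falling back to the CPU. The sketch below uses the public `jax.default_backend()` and `jax.devices()` calls. The wrapper `describe_jax_backend` is a hypothetical convenience, not part of jNO. It is written so that it also runs when JAX is absent.

```python
def describe_jax_backend() -> str:
    """Report which backend JAX selected, without failing if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "jax not installed"
    # default_backend() returns a platform name such as "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If this reports `backend=cpu` on a machine with an NVIDIA GPU, the CUDA-enabled wheel did not find a usable GPU. Check the driver and the CUDA extra you installed.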
Clone + Pixi
Use this route for development or to run the examples from source. It requires pixi.
Common tasks:
pixi run fmt # format with ruff
pixi run lint # lint and auto-fix
pixi run test # run the test suite
Docker
CPU:
GPU (requires NVIDIA drivers and NVIDIA Container Toolkit):
Build locally: