MPP
Synced from: FhG-IISB/jax_mpp
jax_mpp
Note: This package is designed to be used with jNO.
Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.
JAX/Flax translation of the Multiple Physics Pretraining (MPP) AViT model, maintaining exact 1-to-1 weight compatibility with the original PyTorch implementation.
Overview
MPP is a pretraining strategy that jointly normalizes and embeds multiple sets of physical dynamics into a single space for prediction. The model uses an Axial Vision Transformer (AViT) architecture with factored space-time attention.
Input (T, B, C, H, W)
│
▼
┌─── Instance Norm ──────┐
│ normalize per sample │
│ across time + space │
└────────┬───────────────┘
▼
┌─── SubsampledLinear ───┐ Sparse channel projection
│ state vocab → embed/4 │ (selects active state vars)
└────────┬───────────────┘
▼
┌─── hMLP_stem ──────────┐ Hierarchical conv embedding
│ Conv4s4 → Conv2s2 → │ (3 stages, stride 16 total)
│ Conv2s2 each w/ RMS-IN │
└────────┬───────────────┘
▼
┌─── N × SpaceTimeBlock ─┐
│ ┌────────────────────┐ │
│ │ Temporal Attention │ │ Full attention over T axis
│ │ (InstanceNorm, QKV, │ │ with relative position bias
│ │ RPB, LayerScale) │ │
│ └────────┬───────────┘ │
│ ▼ │
│ ┌────────────────────┐ │
│ │ Axial Spatial Attn │ │ X-axis + Y-axis attention
│ │ (RMSInstanceNorm, │ │ averaged, with RPB, MLP,
│ │ QKV, RPB, MLP, │ │ layer scale, drop path
│ │ LayerScale) │ │
│ └────────┬───────────┘ │
└───────────┼─────────────┘
▼
┌─── hMLP_output ────────┐ Hierarchical conv de-embedding
│ ConvT2s2 → ConvT2s2 → │ (3 stages, stride 16 total)
│ ConvT4s4 (subsampled) │
└────────┬───────────────┘
▼
┌─── De-normalise ───────┐
│ x * std + mean │
└────────┬───────────────┘
▼
Output (B, C, H, W) [last time step]
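The normalise and de-normalise stages at the two ends of the diagram can be sketched in a few lines of JAX. This is a minimal illustration rather than the package's code: the helper names `instance_norm` and `denormalise`, the epsilon value, and the choice to keep separate statistics per channel are assumptions read off the diagram.

```python
import jax
import jax.numpy as jnp

EPS = 1e-7  # assumed small constant to avoid division by zero


def instance_norm(x):
    """Normalise each (sample, channel) pair over time and space.

    x has shape (T, B, C, H, W); statistics are reduced over T, H and W.
    """
    mean = x.mean(axis=(0, 3, 4), keepdims=True)
    std = x.std(axis=(0, 3, 4), keepdims=True) + EPS
    return (x - mean) / std, mean, std


def denormalise(y, mean, std):
    """Undo the normalisation for a one-step prediction of shape (B, C, H, W)."""
    # mean and std have shape (1, B, C, 1, 1); drop the leading time axis.
    return y * std[0] + mean[0]


x = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 3, 16, 16)) * 5.0 + 2.0
x_norm, mean, std = instance_norm(x)

# Stand-in for the network: return the last normalised frame unchanged.
y = denormalise(x_norm[-1], mean, std)
print(y.shape)
```

Because the stand-in network is the identity on the last frame, de-normalising should recover `x[-1]` up to floating-point error, which makes the round trip easy to check.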
Model Variants
| Variant | embed_dim | heads | blocks | Params (approx.) |
|---|---|---|---|---|
| Ti | 192 | 3 | 12 | ~5.5 M |
| S | 384 | 6 | 12 | ~21 M |
| B | 768 | 12 | 12 | ~83 M |
| L | 1024 | 16 | 24 | ~300 M |
All variants use patch_size=(16, 16), n_states=12, bias_type="rel".
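Two quantities follow directly from the table: the per-head width (`embed_dim / heads`) and the size of the token grid produced by a stride-16 patch embedding. The sketch below recomputes them in plain Python. `VariantConfig` and `VARIANTS` are hypothetical names used only for illustration and are not the contents of `configs.py`; the divisibility check is likewise an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VariantConfig:  # hypothetical container, for illustration only
    embed_dim: int
    num_heads: int
    processor_blocks: int
    patch_size: tuple = (16, 16)

    @property
    def head_dim(self) -> int:
        """Width of each attention head."""
        return self.embed_dim // self.num_heads

    def token_grid(self, height: int, width: int) -> tuple:
        """Number of patches along each spatial axis."""
        ph, pw = self.patch_size
        if height % ph or width % pw:
            raise ValueError("input size must be a multiple of the patch size")
        return height // ph, width // pw


# Values copied from the variant table above.
VARIANTS = {
    "Ti": VariantConfig(192, 3, 12),
    "S": VariantConfig(384, 6, 12),
    "B": VariantConfig(768, 12, 12),
    "L": VariantConfig(1024, 16, 24),
}

for name, cfg in VARIANTS.items():
    print(name, cfg.head_dim, cfg.token_grid(128, 128))
```

Every variant ends up with a head width of 64, and a 128 × 128 input maps to an 8 × 8 grid of tokens.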
Reference
| | |
|---|---|
| Paper | Multiple Physics Pretraining for Physical Surrogate Models (NeurIPS 2024) |
| Weights | Google Drive |
| Original code | PolymathicAI/multiple_physics_pretraining |
Installation
uv venv && source .venv/bin/activate
uv pip install -e .
# With GPU support:
uv pip install -e ".[gpu]"
# For weight conversion from PyTorch:
uv pip install -e ".[convert]"
Usage
Quick Start
import jax
import jax.numpy as jnp
from jax_mpp import avit_B
# Create model
model = avit_B(n_states=12)
# Dummy inputs
rng = jax.random.PRNGKey(0)
T, B, C, H, W = 4, 2, 3, 128, 128
x = jnp.ones((T, B, C, H, W))
labels = jnp.array([0, 1, 2])
bcs = jnp.zeros((B, 2), dtype=jnp.int32)
# Initialize parameters
params = model.init(
    {"params": rng, "drop_path": rng},
    x, labels, bcs, deterministic=True,
)
# Forward pass
y = model.apply(params, x, labels, bcs, deterministic=True)
print(y.shape) # (2, 3, 128, 128)
Loading Pretrained Weights
from jax_mpp import avit_B, load_pytorch_state_dict, convert_pytorch_to_jax_params
# Load and convert PyTorch checkpoint
pt_state_dict = load_pytorch_state_dict("path/to/checkpoint.tar")
jax_params = convert_pytorch_to_jax_params(pt_state_dict)
# Create model and run (x, labels, bcs as defined in Quick Start)
model = avit_B(n_states=12)
y = model.apply({"params": jax_params}, x, labels, bcs, deterministic=True)
Explicit Construction
from jax_mpp import AViT
model = AViT(
    patch_size=(16, 16),
    embed_dim=768,
    processor_blocks=12,
    n_states=12,
    num_heads=12,
    drop_path=0.1,
    bias_type="rel",
)
Individual Components
from jax_mpp import (
    AxialAttentionBlock,
    AttentionBlock,
    SpaceTimeBlock,
    SubsampledLinear,
    hMLP_stem,
    hMLP_output,
)
Weight Conversion
uv run python scripts/convert.py \
    --checkpoint path/to/ckpt.tar \
    --output weights.msgpack \
    --variant B
Key Mapping Rules
| PyTorch | Flax |
|---|---|
| `blocks.{i}.*` | `blocks_{i}.*` |
| `nn.Linear.weight` | `.kernel` (transposed) |
| `nn.Conv2d.weight` | `.kernel` (OIHW → HWIO) |
| `nn.LayerNorm.weight` | `.scale` |
| `nn.Embedding.weight` | `.embedding` |
| `embed.in_proj.{i}.*` | `embed.in_proj_{i}.*` |
| `debed.out_proj.{i}.*` | `debed.out_proj_{i}.*` |
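The index-renaming and layout rules in the table can be sketched as follows. This is an illustrative approximation, not the code in `convert_weights.py`: the helper names and the example key `blocks.3.spatial.mlp.fc1.weight` are hypothetical, and the `.weight` → `.kernel` / `.scale` / `.embedding` renames are left out for brevity.

```python
import re

import jax.numpy as jnp


def rename_key(pt_key: str) -> str:
    """Turn indexed sub-module names into Flax-style names (blocks.3 -> blocks_3)."""
    return re.sub(r"\b(blocks|in_proj|out_proj)\.(\d+)\b", r"\1_\2", pt_key)


def linear_to_flax(weight):
    """nn.Linear stores (out, in); a Flax Dense kernel is (in, out)."""
    return jnp.transpose(weight)


def conv2d_to_flax(weight):
    """nn.Conv2d stores OIHW; a Flax Conv kernel is HWIO."""
    return jnp.transpose(weight, (2, 3, 1, 0))


print(rename_key("blocks.3.spatial.mlp.fc1.weight"))  # hypothetical key
print(rename_key("embed.in_proj.0.weight"))
print(linear_to_flax(jnp.zeros((8, 4))).shape)
print(conv2d_to_flax(jnp.zeros((16, 3, 4, 4))).shape)
```

Keeping the renaming and the tensor transposes in separate helpers makes each rule easy to test on its own.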
Project Structure
jax_mpp/
├── jax_mpp/
│ ├── __init__.py # Public API + version
│ ├── avit.py # Main AViT model
│ ├── configs.py # Variant configs (Ti/S/B/L) & constructors
│ ├── convert_weights.py # PyTorch → Flax weight mapping
│ ├── mixed_modules.py # SpaceTimeBlock (temporal + spatial)
│ ├── shared_modules.py # Position biases, MLP
│ ├── spatial_modules.py # Patch embed/de-embed, axial attention
│ └── time_modules.py # Temporal attention
├── scripts/
│ ├── convert.py # CLI: PyTorch → msgpack conversion
│ └── compare.py # CLI: equivalence testing
├── pyproject.toml
├── README.md
└── LICENSE
Module Details
| Module | Description |
|---|---|
| `avit.py` | Top-level AViT: normalize → embed → process → de-embed → denormalize |
| `configs.py` | Variant definitions (Ti/S/B/L) and `avit_*()` constructors |
| `mixed_modules.py` | SpaceTimeBlock: sequential temporal + spatial attention |
| `spatial_modules.py` | hMLP_stem/hMLP_output (hierarchical conv), AxialAttentionBlock, SubsampledLinear, RMSInstanceNorm2d |
| `time_modules.py` | AttentionBlock: full attention over the time axis |
| `shared_modules.py` | RelativePositionBias, ContinuousPositionBias1D, MLP |
| `convert_weights.py` | PyTorch state_dict → Flax params conversion |
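To make the axial spatial attention concrete, here is a stripped-down sketch of the idea: attend along one spatial axis, attend along the other, and average the two results. It is not the `AxialAttentionBlock` implementation. It uses a single head and takes `q`, `k`, `v` as given, omitting the learned projections, normalisation, relative position bias, MLP and layer scale; the function names are illustrative.

```python
import jax
import jax.numpy as jnp


def attend(q, k, v):
    """Scaled dot-product attention over the second-to-last axis."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("...qd,...kd->...qk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)


def axial_attention(q, k, v):
    """Average of attention along W and attention along H.

    q, k, v have shape (B, H, W, D) in channels-last layout.
    """
    # Along W: every row (indexed by B and H) is a length-W sequence.
    out_w = attend(q, k, v)

    # Along H: swap H and W so columns become the sequences, then swap back.
    def swap(t):
        return jnp.swapaxes(t, 1, 2)

    out_h = swap(attend(swap(q), swap(k), swap(v)))
    return 0.5 * (out_w + out_h)


q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 8, 8, 16))
out = axial_attention(q, k, v)
print(out.shape)
```

Each axial pass costs on the order of H·W·(H + W) score evaluations instead of (H·W)² for full 2D attention, which is the usual motivation for factoring attention this way.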
Implementation Notes
- All modules use channels-last (`NHWC`) convention internally; the top-level `AViT` accepts the PyTorch-style `(T, B, C, H, W)` input for API compatibility and rearranges internally.
- `flax.linen` (functional API) is used throughout, consistent with the other JAX translation packages.
- Stochastic depth (`drop_path`) is controlled via the `deterministic` flag and the `"drop_path"` RNG key.
- The `SubsampledLinear` and `hMLP_output` modules perform dynamic weight indexing based on `state_labels`, replicating the original sparse projection mechanism.
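The dynamic weight indexing mentioned in the last note can be illustrated with a small sketch. It assumes a kernel stored in Flax `(in, out)` layout with one row per state variable, and it selects only the rows named by the labels. The function name `subsampled_linear` is hypothetical, and any rescaling the original module applies is omitted here.

```python
import jax
import jax.numpy as jnp


def subsampled_linear(x, kernel, bias, labels):
    """Project only the active state variables.

    x:      (..., len(labels)) channels-last input
    kernel: (n_states, dim_out) full projection matrix
    labels: integer indices of the state variables present in x
    """
    return x @ kernel[labels] + bias


n_states, dim_out = 12, 8
k_key, x_key = jax.random.split(jax.random.PRNGKey(0))
kernel = jax.random.normal(k_key, (n_states, dim_out))
bias = jnp.zeros((dim_out,))
labels = jnp.array([0, 1, 2])

x = jax.random.normal(x_key, (2, 16, 16, 3))
y = subsampled_linear(x, kernel, bias, labels)
print(y.shape)
```

Selecting rows this way is equivalent to zero-padding the input to all `n_states` channels and applying the full projection, but it avoids materialising the padded tensor.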
License
MIT