Walrus
Synced from: FhG-IISB/jax_walrus
Walrus-JAX
Note: This package is designed to be used with jNO.
Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.
A JAX/Flax translation of the Walrus PDE foundation model, maintaining exact 1-to-1 weight compatibility with the original PyTorch implementation for pretrained checkpoint conversion.
Overview
Walrus is a 1.29-billion-parameter foundation model for partial differential equations (PDEs), trained on The Well, a large-scale collection of PDE simulation datasets. This repository provides a pure JAX/Flax reimplementation of the full model architecture.
Architecture
The model follows an isotropic encoder-processor-decoder design:
Input (T, B, C, H, W, D)
         │
         ▼
┌──────────────────┐
│ SpaceBag Encoder │  Variable-stride 3D conv with sparse channel embedding
│ (embed_2/embed_3)│  Two conv layers: input→352→1408, RMSGroupNorm + SiLU
└────────┬─────────┘
         │ (T, B, 1408, H', W', D')
         ▼
┌──────────────────┐
│ 40 Processor     │  Each block:
│ Blocks           │  1. AxialTimeAttention (temporal, 16 heads, T5 rel-pos bias)
│ (SpaceTimeSplit) │  2. FullAttention (spatial, 16 heads, 3D RoPE, SwiGLU)
└────────┬─────────┘
         │
         ▼
┌──────────────────┐
│ Decoder          │  Transposed conv with periodic BC handling
│ (debed_2/debed_3)│  Two conv layers: 1408→352→output, channel selection
└────────┬─────────┘
         │
         ▼
Output (1, B, C_out, H, W, D)
Key Model Parameters (Pretrained Config)
| Parameter | Value |
|---|---|
| Hidden dimension | 1408 |
| Intermediate dimension | 352 |
| Processor blocks | 40 |
| Attention heads | 16 |
| Head dimension | 88 |
| Groups (normalization) | 16 |
| Total states | 67 |
| Base kernel size | ((8,4), (8,4), (8,4)) |
| Causal in time | Yes |
| Temporal bias | T5-style relative |
| Spatial RoPE | Lucidrains axial 3D |
| Activation | SiLU |
| Total parameters | 1.29B |
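The dimensions in the table are mutually consistent, which a few lines of arithmetic can confirm. This is only a sanity check on the numbers listed above; the variable names are illustrative, and treating the 16 normalization groups as a split of the hidden dimension is an assumption.

```python
# Values copied from the pretrained-config table.
hidden_dim = 1408
intermediate_dim = 352
num_heads = 16
groups = 16

# Per-head width used by the attention layers.
head_dim = hidden_dim // num_heads
# Channels normalized together, ASSUMING the groups split the hidden dimension.
channels_per_group = hidden_dim // groups
# Expansion factor between the two encoder conv layers (352 -> 1408).
expansion = hidden_dim // intermediate_dim

assert hidden_dim % num_heads == 0
assert hidden_dim % groups == 0
print(head_dim, channels_per_group, expansion)
```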
Installation
# Using uv (recommended)
uv venv
uv pip install -e .
# For GPU support (CUDA 12)
uv pip install -e ".[gpu]"
# For weight conversion from PyTorch
uv pip install -e ".[convert]"
# For running equivalence tests
uv pip install -e ".[test]"
Weight Conversion
Convert a pretrained PyTorch checkpoint to JAX msgpack format:
The script:
1. Loads the PyTorch checkpoint (ckpt['app']['model'] format)
2. Maps all 857 parameters to the Flax parameter tree
3. Validates that every PyTorch key is mapped
4. Saves as msgpack with roundtrip verification
Weight Mapping Rules
| PyTorch | Flax | Transformation |
|---|---|---|
| nn.Linear.weight | .kernel | Transposed (.T) |
| nn.LayerNorm.weight | .scale | As-is |
| nn.Embedding.weight | .embedding | As-is |
| RMSGroupNorm.weight | .weight | As-is |
| Conv weights | proj{1,2}_weight | As-is (no transpose) |
| embed.{d}.* | embed_{d}.* | ModuleDict → named params |
| blocks.{i}.* | blocks_{i}.* | ModuleList → named params |
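Two of the rules above can be illustrated in isolation: the Linear transpose and the ModuleDict/ModuleList key renaming. The sketch below uses NumPy only; `rename_key` and `convert_linear_weight` are hypothetical helper names and are not taken from convert_weights.py.

```python
import re
import numpy as np

def rename_key(key: str) -> str:
    """Turn an indexed container prefix into a named one, e.g. blocks.3.x -> blocks_3.x."""
    return re.sub(r"^(embed|blocks)\.(\d+)\.", r"\1_\2.", key)

def convert_linear_weight(w_torch: np.ndarray) -> np.ndarray:
    """nn.Linear stores (out_features, in_features); a Flax Dense kernel is (in, out)."""
    return np.asarray(w_torch).T

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8)).astype(np.float32)        # batch of 5 inputs
w_torch = rng.normal(size=(3, 8)).astype(np.float32)  # Linear(8 -> 3), no bias

y_torch_style = x @ w_torch.T                         # what nn.Linear computes
y_flax_style = x @ convert_linear_weight(w_torch)     # what Dense computes with .kernel

print(rename_key("blocks.12.attn.weight"))
print(np.allclose(y_torch_style, y_flax_style))
```

The transpose is needed only because the two frameworks store the same matrix with opposite axis order; the computed output is unchanged.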
Usage
Loading Converted Weights
import jax.numpy as jnp
from flax.serialization import from_bytes
from jax_walrus import IsotropicModel
# Create model with pretrained config
model = IsotropicModel(
    hidden_dim=1408,
    intermediate_dim=352,
    n_states=67,
    processor_blocks=40,
    num_heads=16,
    groups=16,
    causal_in_time=True,
    bias_type="rel",
    base_kernel_size=((8, 4), (8, 4), (8, 4)),
    use_spacebag=True,
    use_silu=True,
    include_d=(2, 3),
    encoder_groups=16,
)
# Load converted weights
with open("jax_walrus.msgpack", "rb") as f:
    params = from_bytes(target=None, encoded_bytes=f.read())
# Run inference
# x: (T, B, C, H, W, D) — input PDE state
# state_labels: which output channels to predict
# bcs: boundary conditions per spatial dim
# field_indices: which input channels are present (for SpaceBag)
output = model.apply(
    params,
    x,
    state_labels=jnp.array([0, 1, 2]),
    bcs=[[2, 2], [2, 2], [2, 2]],  # periodic BCs
    stride1=(4, 4, 4),
    stride2=(4, 4, 4),
    field_indices=jnp.array([0, 1, 2, 3, 4, 5]),
    dim_key=3,  # 3D data
)
Using Individual Components
from jax_walrus.encoder import SpaceBagAdaptiveDVstrideEncoder
from jax_walrus.decoder import AdaptiveDVstrideDecoder
from jax_walrus.processor import SpaceTimeSplitBlock
from jax_walrus.spatial_attention import FullAttention
from jax_walrus.temporal_attention import AxialTimeAttention
from jax_walrus.normalization import RMSGroupNorm
Project Structure
jax_walrus/
├── jax_walrus/ # Core library
│ ├── __init__.py # Exports IsotropicModel
│ ├── model.py # IsotropicModel (top-level)
│ ├── encoder.py # AdaptiveDVstride + SpaceBag encoders
│ ├── decoder.py # AdaptiveDVstride decoder with periodic BCs
│ ├── processor.py # SpaceTimeSplitBlock
│ ├── spatial_attention.py # FullAttention (SwiGLU, RoPE, QK-norm)
│ ├── temporal_attention.py# AxialTimeAttention (rel-pos bias, causal)
│ ├── normalization.py # RMSGroupNorm
│ ├── rope.py # Rotary embeddings (lucidrains + simple)
│ └── convert_weights.py # PyTorch → Flax param mapping
├── scripts/
│ └── convert_pretrained.py# CLI weight conversion script
├── tests/
│ └── test_equivalence.py # Component-level PT vs JAX tests
├── pyproject.toml
├── .gitignore
└── README.md
Module Details
Encoder (encoder.py)
Two variants of the variable-stride 3D convolution encoder:
- AdaptiveDVstrideEncoder: Plain two-layer conv encoder (input→inner→output)
- SpaceBagAdaptiveDVstrideEncoder: Sparse channel selection via field_indices with magnitude-preserving scaling
Both handle singleton spatial dimensions by summing the kernel over those axes. Both use _conv3d, a manual 3D convolution that matches PyTorch's F.conv3d.
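One way to read the kernel-summing rule: reducing the kernel over a singleton axis gives the same result as broadcasting the input along that axis and applying the full kernel. The NumPy sketch below checks this for a single kernel-sized patch with one input and one output channel. It illustrates the identity only; it is not the repository's `_conv3d`, and the shapes are toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.normal(size=(4, 4, 4))   # (kH, kW, kD), channels omitted for brevity
patch = rng.normal(size=(4, 4, 1))    # input patch whose D axis is a singleton

# Option A: sum the kernel over the singleton axis, then correlate.
reduced_kernel = kernel.sum(axis=2, keepdims=True)          # (4, 4, 1)
out_reduced = np.sum(patch * reduced_kernel)

# Option B: repeat the patch along D and correlate with the full kernel.
out_broadcast = np.sum(np.broadcast_to(patch, kernel.shape) * kernel)

print(np.isclose(out_reduced, out_broadcast))
```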
Decoder (decoder.py)
AdaptiveDVstrideDecoder — Transposed conv decoder with:
- Adaptive stride handling for singleton dims
- Periodic boundary condition support via circular padding (jnp.pad(mode='wrap'))
- Output channel selection via state_labels
Uses _conv_transpose3d with explicit kernel flipping to match PyTorch's F.conv_transpose3d.
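To make the circular padding concrete, here is a small NumPy sketch of `mode="wrap"` on a channels-first 2D field (`jnp.pad` accepts the same mode). The field shape and pad width are toy values chosen for illustration, not the decoder's actual padding sizes.

```python
import numpy as np

field = np.arange(12).reshape(1, 1, 3, 4)   # (B, C, H, W)
pad = 1

# Pad only the spatial axes; batch and channel axes get zero padding width.
padded = np.pad(
    field,
    ((0, 0), (0, 0), (pad, pad), (pad, pad)),
    mode="wrap",
)

print(padded.shape)
print(padded[0, 0])
```

With periodic padding, the row added above the first row is a copy of the last row, and likewise for columns, so a convolution near the boundary sees the field as if it repeated.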
Spatial Attention (spatial_attention.py)
FullAttention — Spatial attention block with:
- Fused FF + Q + K + V single-linear projection
- SwiGLU feedforward network
- Axial 3D Rotary Position Embeddings (lucidrains-style)
- QK-norm via LayerNorm
- Attention over flattened H×W×D spatial tokens
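The fused projection and SwiGLU feedforward listed above can be sketched as one matrix multiply whose output is split into feedforward, query, key, and value parts. In this NumPy sketch the split order, the gate/value ordering, and all sizes are assumptions made for illustration; the actual layout in spatial_attention.py may differ.

```python
import numpy as np

def silu(z):
    """SiLU activation: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))

dim, ff_inner = 8, 16                      # toy sizes, not the pretrained config
rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, dim))         # 5 flattened spatial tokens
w_fused = rng.normal(size=(dim, 2 * ff_inner + 3 * dim))

fused = tokens @ w_fused                   # one matmul instead of four

# ASSUMED split order: [feedforward (gate + value), q, k, v].
bounds = [2 * ff_inner, 2 * ff_inner + dim, 2 * ff_inner + 2 * dim]
ff, q, k, v = np.split(fused, bounds, axis=-1)

# SwiGLU: one half gates the other through SiLU (ordering assumed).
value, gate = np.split(ff, 2, axis=-1)
swiglu_out = silu(gate) * value            # (5, ff_inner)

print(q.shape, k.shape, v.shape, swiglu_out.shape)
```

Fusing the projections does not change the result; it is equivalent to slicing the fused weight matrix column-wise and running separate matmuls.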
Temporal Attention (temporal_attention.py)
AxialTimeAttention — Applied independently at each spatial location:
- 1×1×1 conv input/output heads
- T5-style relative position bias (or rotary, configurable)
- Optional causal masking for autoregressive prediction
- QK-norm
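The combination of an additive relative-position bias and causal masking can be sketched for a single head at one spatial location. This NumPy example looks the bias up directly by relative offset, which is a simplification of the T5-style bucketed scheme, and it omits the QK-norm and the 1×1×1 conv heads. `causal_attention` is a hypothetical name.

```python
import numpy as np

def causal_attention(q, k, v, bias):
    """q, k, v: (T, d); bias: (T, T) additive bias. Returns (output, weights)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d) + bias
    # Mask entries above the diagonal so step t only attends to steps <= t.
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

T, d = 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))

# Toy relative bias: one learned scalar per offset i - j (no bucketing).
table = rng.normal(size=2 * T - 1)
offsets = np.arange(T)[:, None] - np.arange(T)[None, :]
bias = table[offsets + T - 1]

out, weights = causal_attention(q, k, v, bias)
print(np.round(weights, 3))
```

Because future positions receive zero weight, changing the last time step leaves the outputs of all earlier steps unchanged, which is the property autoregressive prediction relies on.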
Processor (processor.py)
SpaceTimeSplitBlock — Composes temporal → spatial → channel (identity) mixing.
Normalization (normalization.py)
RMSGroupNorm — RMS normalization per group (no mean subtraction), with learned per-channel scale. Operates in channels-first layout (B, C, *spatial).
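A minimal NumPy sketch of RMS normalization per group follows. It assumes the RMS is taken over each group's channels and all spatial positions (as in GroupNorm, but without mean subtraction) and uses an assumed `eps`; the repository's RMSGroupNorm may differ in both respects.

```python
import numpy as np

def rms_group_norm(x, scale, groups, eps=1e-6):
    """x: (B, C, *spatial), scale: (C,). eps and the reduction axes are assumptions."""
    B, C = x.shape[:2]
    spatial = x.shape[2:]
    g = x.reshape(B, groups, C // groups, *spatial)
    reduce_axes = tuple(range(2, g.ndim))               # group channels + spatial dims
    rms = np.sqrt(np.mean(g ** 2, axis=reduce_axes, keepdims=True) + eps)
    y = (g / rms).reshape(x.shape)                      # no mean subtraction
    return y * scale.reshape(1, C, *([1] * len(spatial)))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 3, 3, 3))                    # channels-first 3D input
y = rms_group_norm(x, scale=np.ones(8), groups=4)
print(y.shape)
```

With a unit scale, each group's mean square is approximately 1 after normalization, while its mean is left wherever the input put it.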
RoPE (rope.py)
- LRRotaryEmbedding: Lucidrains-style with axial ND frequency grids
- SimpleRotaryEmbedding: Standard sinusoidal for temporal attention
- RelativePositionBias: T5-style bucketed relative position bias
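The property that makes rotary embeddings useful in attention is that the query-key dot product depends only on the relative position. The NumPy sketch below shows this for a 1D rotary embedding that rotates consecutive feature pairs; the pairing convention and the base of 10000 are assumptions, and the axial ND grids of LRRotaryEmbedding are not reproduced here.

```python
import numpy as np

def rotate(x, pos, base=10000.0):
    """Rotate consecutive feature pairs of x (..., d) by the angles pos * freq."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)     # one frequency per feature pair
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q = rng.normal(size=8)
k = rng.normal(size=8)

# Same relative offset (2) at two different absolute positions.
score_a = rotate(q, 3) @ rotate(k, 1)
score_b = rotate(q, 7) @ rotate(k, 5)
print(np.isclose(score_a, score_b))
```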
Equivalence Testing
The equivalence tests verify that each JAX component matches the PyTorch original to within floating-point tolerance, given the same weights and inputs.
Running Tests
# Set path to original walrus source
export WALRUS_ROOT=/path/to/walrus
# Run component tests (requires PyTorch + walrus source)
python tests/test_equivalence.py
Test Results (with Random Weights)
| Component | Max Diff | Status |
|---|---|---|
| RMSGroupNorm | 2.62e-06 | PASS |
| Encoder | 7.95e-04 | PASS |
| Decoder | 1.67e-04 | PASS |
| FullAttention | 3.58e-07 | PASS |
| AxialTimeAttention | 0.00e+00 | PASS |
Pretrained Weight Comparison
With the full 1.29B parameter pretrained checkpoint:
| Component | Max Diff | Notes |
|---|---|---|
| Encoder | 7.34e-03 | Acceptable — large weight magnitudes |
| Single Block | 3.42e+00 | Drift from float32 precision in large dims |
| Decoder | 1.96e-04 | Excellent match |
The block-level differences arise from floating-point accumulation across the large hidden dimension (1408) and are consistent with expected float32 numerical differences between PyTorch and JAX. All component-level outputs match within tolerance.
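Float32 accumulation is order-dependent, which is why two correct implementations of the same reduction can disagree in the low-order bits. The NumPy sketch below sums the same 1408 products in two orders and compares both against a float64 reference. It illustrates the mechanism only, with random data, and does not by itself reproduce or account for the magnitudes in the table above.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=1408).astype(np.float32)   # length matches the hidden dimension
b = rng.normal(size=1408).astype(np.float32)

# Reference: accumulate in float64.
exact = float(np.dot(a.astype(np.float64), b.astype(np.float64)))

products = a * b                               # float32 products
forward = np.float32(0.0)
for p in products:                             # left-to-right float32 accumulation
    forward += p
backward = np.float32(0.0)
for p in products[::-1]:                       # right-to-left float32 accumulation
    backward += p

print(abs(float(forward) - exact), abs(float(backward) - exact))
print(abs(float(forward) - float(backward)))
```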
Implementation Notes
Key Differences from PyTorch
- Conv transposed: JAX's lax.conv_general_dilated doesn't have a direct transposed conv mode. We implement it via zero-insertion upsampling + spatially flipped kernel + regular convolution.
- Channels-first layout: We keep the PyTorch (B, C, H, W, D) layout throughout for weight compatibility, transposing to/from JAX's preferred channels-last for lax.conv_general_dilated.
- 2D/3D variants: Both embed_2/embed_3 use the same class with spatial_dims=3. The 2D variant simply has a singleton spatial dimension that is handled via kernel averaging.
- SpaceBag scaling: The scaling uses weight[:, :-2] (not :-extra_dims) to match the PyTorch implementation exactly.
- encoder_dummy: An unused parameter kept in the Flax model for 1-to-1 weight mapping completeness.
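The transposed-convolution construction in the first note (zero-insertion upsampling, flipped kernel, regular convolution) can be checked in one dimension. This NumPy sketch compares it against a scatter-style reference that follows the usual definition of a transposed convolution, with a single channel and no padding. Both function names are hypothetical, and this is a 1D illustration rather than the repository's `_conv_transpose3d`.

```python
import numpy as np

def conv_transpose1d_scatter(x, w, stride):
    """Reference: each input sample scatters a scaled copy of the kernel."""
    out = np.zeros((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        out[i * stride : i * stride + len(w)] += xi * w
    return out

def conv_transpose1d_via_conv(x, w, stride):
    """Zero-insertion upsampling, then correlation with the flipped kernel."""
    up = np.zeros((len(x) - 1) * stride + 1)
    up[::stride] = x                      # insert stride - 1 zeros between samples
    up = np.pad(up, len(w) - 1)           # zero-pad both ends by kernel_size - 1
    return np.correlate(up, w[::-1], mode="valid")

rng = np.random.default_rng(0)
x = rng.normal(size=5)
w = rng.normal(size=4)

ref = conv_transpose1d_scatter(x, w, stride=4)
alt = conv_transpose1d_via_conv(x, w, stride=4)
print(ref.shape, np.allclose(ref, alt))
```

The kernel flip matters because a regular convolution in deep-learning frameworks is a cross-correlation; without the flip, the upsampled result would use the kernel in reverse.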
License
MIT — see LICENSE.