Skip to content

Walrus

Synced from: FhG-IISB/jax_walrus

Walrus-JAX

Note: This package is designed to be used with jNO.

Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

A JAX/Flax translation of the Walrus PDE foundation model, maintaining exact 1-to-1 weight compatibility with the original PyTorch implementation for pretrained checkpoint conversion.

Overview

Walrus is a 1.29 billion parameter foundation model for partial differential equations (PDEs), trained on The Well — a large-scale collection of PDE simulation datasets. This repository provides a pure JAX/Flax reimplementation of the full model architecture.

Architecture

The model follows an isotropic encoder-processor-decoder design:

Input (T, B, C, H, W, D)
┌──────────────────┐
│  SpaceBag Encoder │  Variable-stride 3D conv with sparse channel embedding
│  (embed_2/embed_3)│  Two conv layers: input→352→1408, RMSGroupNorm + SiLU
└────────┬─────────┘
         │  (T, B, 1408, H', W', D')
┌──────────────────┐
│  40 Processor     │  Each block:
│  Blocks           │    1. AxialTimeAttention (temporal, 16 heads, T5 rel-pos bias)
│  (SpaceTimeSplit) │    2. FullAttention (spatial, 16 heads, 3D RoPE, SwiGLU)
└────────┬─────────┘
┌──────────────────┐
│  Decoder          │  Transposed conv with periodic BC handling
│  (debed_2/debed_3)│  Two conv layers: 1408→352→output, channel selection
└────────┬─────────┘
Output (1, B, C_out, H, W, D)

Key Model Parameters (Pretrained Config)

Parameter Value
Hidden dimension 1408
Intermediate dimension 352
Processor blocks 40
Attention heads 16
Head dimension 88
Groups (normalization) 16
Total states 67
Base kernel size ((8,4), (8,4), (8,4))
Causal in time Yes
Temporal bias T5-style relative
Spatial RoPE Lucidrains axial 3D
Activation SiLU
Total parameters 1.29B

Installation

# Using uv (recommended)
uv venv
uv pip install -e .

# For GPU support (CUDA 12)
uv pip install -e ".[gpu]"

# For weight conversion from PyTorch
uv pip install -e ".[convert]"

# For running equivalence tests
uv pip install -e ".[test]"

Weight Conversion

Convert a pretrained PyTorch checkpoint to JAX msgpack format:

python scripts/convert_pretrained.py --input walrus.pt --output jax_walrus.msgpack

The script: 1. Loads the PyTorch checkpoint (ckpt['app']['model'] format) 2. Maps all 857 parameters to the Flax parameter tree 3. Validates that every PyTorch key is mapped 4. Saves as msgpack with roundtrip verification

Weight Mapping Rules

PyTorch Flax Transformation
nn.Linear.weight .kernel Transposed (.T)
nn.LayerNorm.weight .scale As-is
nn.Embedding.weight .embedding As-is
RMSGroupNorm.weight .weight As-is
Conv weights proj{1,2}_weight As-is (no transpose)
embed.{d}.* embed_{d}.* ModuleDict → named params
blocks.{i}.* blocks_{i}.* ModuleList → named params

Usage

Loading Converted Weights

import jax.numpy as jnp
from flax.serialization import from_bytes
from jax_walrus import IsotropicModel

# Create model with pretrained config
model = IsotropicModel(
    hidden_dim=1408,
    intermediate_dim=352,
    n_states=67,
    processor_blocks=40,
    num_heads=16,
    groups=16,
    causal_in_time=True,
    bias_type="rel",
    base_kernel_size=((8, 4), (8, 4), (8, 4)),
    use_spacebag=True,
    use_silu=True,
    include_d=(2, 3),
    encoder_groups=16,
)

# Load converted weights
with open("jax_walrus.msgpack", "rb") as f:
    params = from_bytes(target=None, encoded_bytes=f.read())

# Run inference
# x: (T, B, C, H, W, D) — input PDE state
# state_labels: which output channels to predict
# bcs: boundary conditions per spatial dim
# field_indices: which input channels are present (for SpaceBag)
output = model.apply(
    params,
    x,
    state_labels=jnp.array([0, 1, 2]),
    bcs=[[2, 2], [2, 2], [2, 2]],  # periodic BCs
    stride1=(4, 4, 4),
    stride2=(4, 4, 4),
    field_indices=jnp.array([0, 1, 2, 3, 4, 5]),
    dim_key=3,  # 3D data
)

Using Individual Components

from jax_walrus.encoder import SpaceBagAdaptiveDVstrideEncoder
from jax_walrus.decoder import AdaptiveDVstrideDecoder
from jax_walrus.processor import SpaceTimeSplitBlock
from jax_walrus.spatial_attention import FullAttention
from jax_walrus.temporal_attention import AxialTimeAttention
from jax_walrus.normalization import RMSGroupNorm

Project Structure

jax_walrus/
├── jax_walrus/              # Core library
│   ├── __init__.py          # Exports IsotropicModel
│   ├── model.py             # IsotropicModel (top-level)
│   ├── encoder.py           # AdaptiveDVstride + SpaceBag encoders
│   ├── decoder.py           # AdaptiveDVstride decoder with periodic BCs
│   ├── processor.py         # SpaceTimeSplitBlock
│   ├── spatial_attention.py # FullAttention (SwiGLU, RoPE, QK-norm)
│   ├── temporal_attention.py# AxialTimeAttention (rel-pos bias, causal)
│   ├── normalization.py     # RMSGroupNorm
│   ├── rope.py              # Rotary embeddings (lucidrains + simple)
│   └── convert_weights.py   # PyTorch → Flax param mapping
├── scripts/
│   └── convert_pretrained.py# CLI weight conversion script
├── tests/
│   └── test_equivalence.py  # Component-level PT vs JAX tests
├── pyproject.toml
├── .gitignore
└── README.md

Module Details

Encoder (encoder.py)

Two variants of the variable-stride 3D convolution encoder:

  • AdaptiveDVstrideEncoder — Plain two-layer conv encoder (input→inner→output)
  • SpaceBagAdaptiveDVstrideEncoder — Sparse channel selection via field_indices with magnitude-preserving scaling

Both handle singleton spatial dimensions by summing the kernel over those axes. Uses _conv3d for manual 3D convolution matching PyTorch's F.conv3d.

Decoder (decoder.py)

AdaptiveDVstrideDecoder — Transposed conv decoder with: - Adaptive stride handling for singleton dims - Periodic boundary condition support via circular padding (jnp.pad(mode='wrap')) - Output channel selection via state_labels

Uses _conv_transpose3d with explicit kernel flipping to match PyTorch's F.conv_transpose3d.

Spatial Attention (spatial_attention.py)

FullAttention — Spatial attention block with: - Fused FF + Q + K + V single-linear projection - SwiGLU feedforward network - Axial 3D Rotary Position Embeddings (lucidrains-style) - QK-norm via LayerNorm - Attention over flattened H×W×D spatial tokens

Temporal Attention (temporal_attention.py)

AxialTimeAttention — Applied independently at each spatial location: - 1×1×1 conv input/output heads - T5-style relative position bias (or rotary, configurable) - Optional causal masking for autoregressive prediction - QK-norm

Processor (processor.py)

SpaceTimeSplitBlock — Composes temporal → spatial → channel (identity) mixing.

Normalization (normalization.py)

RMSGroupNorm — RMS normalization per group (no mean subtraction), with learned per-channel scale. Operates in channels-first layout (B, C, *spatial).

RoPE (rope.py)

  • LRRotaryEmbedding — Lucidrains-style with axial ND frequency grids
  • SimpleRotaryEmbedding — Standard sinusoidal for temporal attention
  • RelativePositionBias — T5-style bucketed relative position bias

Equivalence Testing

The equivalence tests verify that each JAX component produces identical outputs to the PyTorch original given the same weights and inputs.

Running Tests

# Set path to original walrus source
export WALRUS_ROOT=/path/to/walrus

# Run component tests (requires PyTorch + walrus source)
python tests/test_equivalence.py

Test Results (with Random Weights)

Component Max Diff Status
RMSGroupNorm 2.62e-06 PASS
Encoder 7.95e-04 PASS
Decoder 1.67e-04 PASS
FullAttention 3.58e-07 PASS
AxialTimeAttention 0.00e+00 PASS

Pretrained Weight Comparison

With the full 1.29B parameter pretrained checkpoint:

Component Max Diff Notes
Encoder 7.34e-03 Acceptable — large weight magnitudes
Single Block 3.42e+00 Drift from float32 precision in large dims
Decoder 1.96e-04 Excellent match

The block-level differences arise from floating-point accumulation across the large hidden dimension (1408) and are consistent with expected float32 numerical differences between PyTorch and JAX. All component-level outputs match within tolerance.

Implementation Notes

Key Differences from PyTorch

  1. Conv transposed: JAX's lax.conv_general_dilated doesn't have a direct transposed conv mode. We implement it via zero-insertion upsampling + spatially flipped kernel + regular convolution.

  2. Channels-first layout: We keep the PyTorch (B, C, H, W, D) layout throughout for weight compatibility, transposing to/from JAX's preferred channels-last for lax.conv_general_dilated.

  3. 2D/3D variants: Both embed_2/embed_3 use the same class with spatial_dims=3. The 2D variant simply has a singleton spatial dimension that is handled via kernel averaging.

  4. SpaceBag scaling: The scaling uses weight[:, :-2] (not :-extra_dims) to match the PyTorch implementation exactly.

  5. encoder_dummy: An unused parameter kept in the Flax model for 1-to-1 weight mapping completeness.

License

MIT — see LICENSE.