Morph
Synced from: FhG-IISB/jax_morph
jax_morph
Note: This package is designed to be used with jNO.
Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.
JAX/Flax translation of the MORPH PDE foundation model, maintaining exact 1-to-1 weight compatibility with the original PyTorch implementation for pretrained checkpoint conversion.
Overview
MORPH is a family of PDE foundation models from Los Alamos National Laboratory, trained on multi-physics simulation data. This repository provides a pure JAX/Flax reimplementation of the full model architecture (ViT3DRegression).
Architecture
Input (B, t, F, C, D, H, W)
│
▼
┌───────────────────────┐
│ HybridPatchEmbedding │ ConvOperator → Patchify → Dense → FieldCrossAttention
│ (patch_embedding) │ Conv3D feature extraction + learned cross-attention
└──────────┬────────────┘
│ (B, t, n_patches, dim)
▼
┌───────────────────────┐
│ Positional Encoding │ Learned embedding table, interpolated to (t, n_patches)
│ (pos_encoding) │ Ti/S/M: 1D linear | L: 2D bilinear with antialias
└──────────┬────────────┘
│
▼
┌───────────────────────┐
│ N × EncoderBlock │ Each block:
│ (transformer_blocks) │ 1. LayerNorm → AxialAttention (t, d, h, w axes)
│ │ 2. LayerNorm → MLP (with LoRA support)
└──────────┬────────────┘
│
▼
┌───────────────────────┐
│ SimpleDecoder │ LayerNorm → Dense → Unpatchify
│ (decoder) │
└──────────┬────────────┘
│
▼
Output (B, F, C, D, H, W)
Model Variants
| Variant | Dim | Depth | Heads | MLP Dim | max_ar | Parameters |
|---|---|---|---|---|---|---|
| Ti | 256 | 4 | 4 | 1,024 | 1 | 9.9M |
| S | 512 | 4 | 8 | 2,048 | 1 | 32.8M |
| M | 768 | 8 | 12 | 3,072 | 1 | 125.6M |
| L | 1024 | 16 | 16 | 4,096 | 16 | 483.3M |
All variants use conv_filter=8, patch_size=8, heads_xa=32, max_patches=4096.
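As a quick sanity check on the table, the hyperparameters can be written down as plain data and the per-head width derived from them. This is only an illustrative sketch: the dictionary layout and the `derived` helper are hypothetical and are not the structure used in `configs.py`.

```python
# Values copied from the Model Variants table; the dict layout is illustrative only.
VARIANTS = {
    "Ti": dict(dim=256, depth=4, heads=4, mlp_dim=1024, max_ar=1),
    "S": dict(dim=512, depth=4, heads=8, mlp_dim=2048, max_ar=1),
    "M": dict(dim=768, depth=8, heads=12, mlp_dim=3072, max_ar=1),
    "L": dict(dim=1024, depth=16, heads=16, mlp_dim=4096, max_ar=16),
}
# Settings shared by all variants, as stated below the table.
SHARED = dict(conv_filter=8, patch_size=8, heads_xa=32, max_patches=4096)


def derived(name):
    """Return quantities implied by the table (hypothetical helper)."""
    cfg = VARIANTS[name]
    return {
        "head_dim": cfg["dim"] // cfg["heads"],
        "mlp_ratio": cfg["mlp_dim"] // cfg["dim"],
    }


for name in VARIANTS:
    print(name, derived(name))
```

Every variant works out to a head width of 64 and an MLP expansion factor of 4.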
Reference
Rautela et al., "MORPH: PDE Foundation Models with Arbitrary Data Modality" (2025)
- Paper: https://arxiv.org/abs/2509.21670
- Weights: https://huggingface.co/mahindrautela/MORPH
- Code: https://github.com/lanl/MORPH
Installation
# Using uv (recommended)
uv venv
uv pip install -e .
# For GPU support (CUDA 12)
uv pip install -e ".[gpu]"
# For weight conversion from PyTorch
uv pip install -e ".[convert]"
Weight Conversion
Convert a pretrained PyTorch checkpoint to JAX msgpack format with scripts/convert.py.
The script:
1. Loads the PyTorch checkpoint (model_state_dict format)
2. Initialises the JAX model and maps all parameters
3. Validates shapes and reports parameter counts
4. Saves as msgpack with roundtrip verification
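The shape-validation step can be pictured with a small NumPy sketch. It assumes parameters are stored as nested dictionaries of arrays; `flatten` and `validate_shapes` are hypothetical names and do not reflect the actual code in `convert_weights.py`.

```python
import numpy as np


def flatten(tree, prefix=()):
    """Yield ("a/b/c", leaf) pairs from a nested dict of arrays."""
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield "/".join(path), value


def validate_shapes(reference, converted):
    """Compare two parameter trees; return (list of problems, parameter count)."""
    ref, new = dict(flatten(reference)), dict(flatten(converted))
    problems = []
    for path in sorted(ref.keys() | new.keys()):
        if path not in new:
            problems.append(f"missing: {path}")
        elif path not in ref:
            problems.append(f"unexpected: {path}")
        elif np.shape(ref[path]) != np.shape(new[path]):
            problems.append(
                f"shape mismatch at {path}: "
                f"{np.shape(ref[path])} vs {np.shape(new[path])}"
            )
    n_params = sum(int(np.size(v)) for v in new.values())
    return problems, n_params


# Toy trees standing in for the initialised and the converted parameters.
reference = {"params": {"dense": {"kernel": np.zeros((4, 8)), "bias": np.zeros((8,))}}}
converted = {"params": {"dense": {"kernel": np.ones((4, 8)), "bias": np.ones((8,))}}}
problems, n_params = validate_shapes(reference, converted)
print(problems, n_params)
```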
Weight Mapping Rules
| PyTorch | Flax | Transformation |
|---|---|---|
| nn.Conv3d.weight (O,I,D,H,W) | nn.Conv.kernel (D,H,W,I,O) | Transpose (2,3,4,1,0) |
| nn.Linear.weight (O,I) | nn.Dense.kernel (I,O) | Transpose |
| nn.Linear.bias | nn.Dense.bias | As-is |
| nn.LayerNorm.weight | nn.LayerNorm.scale | As-is |
| nn.LayerNorm.bias | nn.LayerNorm.bias | As-is |
| nn.MultiheadAttention.in_proj_weight (3E,E) | Q/K/V kernel (E,H,d) | Split + transpose + reshape |
| nn.MultiheadAttention.in_proj_bias (3E,) | Q/K/V bias (H,d) | Split + reshape |
| nn.MultiheadAttention.out_proj.* | out_proj.kernel/bias | Transpose |
| nn.Parameter | param | As-is |
| LoRALinear.A (R,I) | A (I,R) | Transpose |
| LoRALinear.B (O,R) | B (R,O) | Transpose |
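A few of these rules can be sketched with plain NumPy. The function names below are hypothetical, and the real conversion in `convert_weights.py` may be organised differently. The final check confirms that the split, transposed and reshaped Q kernel reproduces the PyTorch-style projection `W @ x + b`.

```python
import numpy as np


def conv3d_weight_to_flax(w):
    """(O, I, D, H, W) -> (D, H, W, I, O)."""
    return np.transpose(w, (2, 3, 4, 1, 0))


def linear_weight_to_flax(w):
    """(O, I) -> (I, O)."""
    return w.T


def mha_in_proj_to_flax(in_proj_weight, in_proj_bias, num_heads):
    """Split a fused (3E, E) projection into Q/K/V kernels (E, H, d) and biases (H, d)."""
    embed_dim = in_proj_weight.shape[1]
    head_dim = embed_dim // num_heads
    kernels = [
        w.T.reshape(embed_dim, num_heads, head_dim)
        for w in np.split(in_proj_weight, 3, axis=0)
    ]
    biases = [b.reshape(num_heads, head_dim) for b in np.split(in_proj_bias, 3)]
    return kernels, biases


rng = np.random.default_rng(0)
E, H = 8, 2
w_in = rng.normal(size=(3 * E, E))
b_in = rng.normal(size=(3 * E,))
(kq, kk, kv), (bq, bk, bv) = mha_in_proj_to_flax(w_in, b_in, H)

x = rng.normal(size=(E,))
q_torch = w_in[:E] @ x + b_in[:E]            # PyTorch convention: W @ x + b
q_flax = np.einsum("i,ihd->hd", x, kq) + bq  # contraction over the input axis
print(np.allclose(q_torch.reshape(H, E // H), q_flax))
```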
Usage
Using Convenience Constructors
import jax
import jax.numpy as jnp
from jax_morph import morph_Ti, load_pytorch_state_dict, convert_pytorch_to_jax_params
# Create model
model = morph_Ti()
# Initialise
rng = jax.random.PRNGKey(0)
dummy = jnp.zeros((1, 1, 1, 1, 8, 8, 8))
params = model.init(rng, dummy, deterministic=True)
# Convert and load PyTorch weights
pt_sd = load_pytorch_state_dict("morph-Ti-FM-max_ar1_ep225.pth")
params = convert_pytorch_to_jax_params(pt_sd, params)
# Run inference
enc, z, pred = model.apply(params, dummy, deterministic=True)
Loading Converted Weights
import jax.numpy as jnp
from flax.serialization import from_bytes
from jax_morph import morph_Ti
model = morph_Ti()
# Load pre-converted msgpack weights
with open("morph-Ti.msgpack", "rb") as f:
    params = from_bytes(target=None, encoded_bytes=f.read())
# Inference: x is (B, t, F, C, D, H, W); zeros stand in for real data here
x = jnp.zeros((1, 1, 1, 1, 8, 8, 8))
enc, z, pred = model.apply(params, x, deterministic=True)
Using Individual Components
from jax_morph.patch_embedding import HybridPatchEmbedding3D
from jax_morph.encoder_block import EncoderBlock
from jax_morph.axial_attention import AxialAttention3DSpaceTime
from jax_morph.cross_attention import FieldCrossAttention
from jax_morph.decoder import SimpleDecoder
from jax_morph.positional_encoding import PositionalEncodingSLinTSlice
Explicit Model Construction
from jax_morph import ViT3DRegression
model = ViT3DRegression(
    patch_size=8, dim=1024, depth=16, heads=16, heads_xa=32,
    mlp_dim=4096, max_components=3, conv_filter=8,
    max_ar=16, max_patches=4096, max_fields=3,
    dropout=0.0, emb_dropout=0.0, model_size="L",
)
Equivalence Testing
Verify that the JAX model with converted weights produces the same output as the PyTorch model:
# Requires cloned MORPH repo (https://github.com/lanl/MORPH)
export MORPH_ROOT=/path/to/MORPH
python scripts/compare.py --model-size Ti
The script downloads the checkpoint from HuggingFace, runs both models on identical random input, and compares outputs.
Results
| Variant | Max Abs Diff | Mean Abs Diff | Parameters Matched | Status |
|---|---|---|---|---|
| Ti | 4.96e-05 | 5.90e-06 | 9,932,920 | PASS |
| S | 1.22e-04 | 2.06e-05 | 32,849,512 | PASS |
| M | 9.77e-04 | 1.00e-04 | 125,574,248 | PASS |
| L | 9.16e-05 | 1.15e-05 | 483,293,400 | PASS |
All models pass with max absolute difference < 1e-3.
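The two reported metrics and the pass criterion are simple to state in code. The sketch below assumes the comparison is a plain element-wise absolute difference over the output arrays; `compare_outputs` is a hypothetical helper, not the function used by `scripts/compare.py`.

```python
import numpy as np


def compare_outputs(reference, candidate, tol=1e-3):
    """Return max/mean absolute difference and a PASS/FAIL status."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    abs_diff = np.abs(reference - candidate)
    max_abs = float(abs_diff.max())
    return {
        "max_abs_diff": max_abs,
        "mean_abs_diff": float(abs_diff.mean()),
        "status": "PASS" if max_abs < tol else "FAIL",
    }


# Toy arrays standing in for the PyTorch and JAX predictions.
ref = np.array([0.0, 1.0, 2.0])
out = ref + np.array([0.0, 5e-5, -1e-5])
result = compare_outputs(ref, out)
print(result)
```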
Project Structure
jax_morph/
├── jax_morph/ # Core library
│ ├── __init__.py # Public API exports
│ ├── configs.py # Model configs + convenience constructors
│ ├── model.py # ViT3DRegression (top-level model)
│ ├── patch_embedding.py # HybridPatchEmbedding3D
│ ├── conv_operator.py # ConvOperator (3D conv feature extractor)
│ ├── patchify.py # 3D patchification utility
│ ├── cross_attention.py # FieldCrossAttention (multi-field aggregation)
│ ├── positional_encoding.py # Learned positional encodings (linear / bilinear)
│ ├── encoder_block.py # EncoderBlock (LayerNorm + attention + MLP)
│ ├── axial_attention.py # AxialAttention3DSpaceTime (t, d, h, w axes)
│ ├── attention.py # SDPA, LoRALinear, LoRAMHA
│ ├── decoder.py # SimpleDecoder (LayerNorm + Dense + unpatchify)
│ └── convert_weights.py # PyTorch → Flax parameter mapping
├── scripts/
│ ├── convert.py # CLI: convert .pth → .msgpack
│ └── compare.py # CLI: PyTorch vs JAX equivalence test
├── pyproject.toml
├── LICENSE
├── .gitignore
└── README.md
Module Details
Model (model.py)
ViT3DRegression — Top-level model composing patch embedding, positional encoding, N transformer blocks, and a decoder. Handles input reshaping (B, t, F, C, D, H, W) → patch tokens → output prediction. Returns (encoder_output, transformer_output, prediction).
Patch Embedding (patch_embedding.py)
HybridPatchEmbedding3D — Conv3D feature extraction via ConvOperator, 3D patchification, linear projection, and learned field cross-attention via FieldCrossAttention. Aggregates multiple input fields into a single token sequence.
Conv Operator (conv_operator.py)
ConvOperator — 1×1×1 input projection followed by a stack of 3×3×3 convolutions with channel doubling and LeakyReLU activations. Uses channels-last layout.
Cross Attention (cross_attention.py)
FieldCrossAttention — Learned query tokens attend to patch tokens from multiple fields via multi-head attention. Uses DenseGeneral for Q/K/V projections matching PyTorch's nn.MultiheadAttention.
Positional Encoding (positional_encoding.py)
Two variants of learned positional embeddings:
- PositionalEncodingSLinTSlice: Time-slicing + 1D linear interpolation over patches. Used for Ti/S/M variants (max_ar ≤ 1).
- PositionalEncodingSTBilinear: 2D bilinear interpolation with antialiasing over (time, patches). Used for L variant (max_ar > 1).
Custom interpolation functions replicate PyTorch's F.interpolate(align_corners=False, antialias=True) coordinate mapping exactly.
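To make the coordinate convention concrete, here is a minimal NumPy sketch of 1D linear interpolation with the half-pixel mapping `src = (i + 0.5) * in_size / out_size - 0.5`, which is the mapping PyTorch applies for `align_corners=False`. It covers only the non-antialiased linear case; `interp_linear_half_pixel` is a hypothetical name and not the function shipped in `positional_encoding.py`.

```python
import numpy as np


def interp_linear_half_pixel(values, out_size):
    """Resize along axis 0 using half-pixel (align_corners=False) coordinates."""
    values = np.asarray(values, dtype=np.float64)
    in_size = values.shape[0]
    i = np.arange(out_size)
    # Map output index i to a (fractional) source coordinate.
    src = (i + 0.5) * in_size / out_size - 0.5
    src = np.maximum(src, 0.0)              # negative coordinates clamp to the first sample
    lo = np.floor(src).astype(int)
    hi = np.minimum(lo + 1, in_size - 1)    # the right neighbour clamps at the last sample
    frac = (src - lo).reshape(-1, *([1] * (values.ndim - 1)))
    return (1.0 - frac) * values[lo] + frac * values[hi]


print(interp_linear_half_pixel([0.0, 1.0], 4))
```

Upsampling `[0, 1]` to four points places the source coordinates at -0.25, 0.25, 0.75 and 1.25, so the two edge outputs clamp to the end values.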
Encoder Block (encoder_block.py)
EncoderBlock — Pre-norm transformer block: LayerNorm → AxialAttention → residual → LayerNorm → MLP → residual. Uses exact GELU (approximate=False) and LayerNorm epsilon=1e-5 to match PyTorch defaults.
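The difference between the exact and the tanh-approximated GELU is small but nonzero, which matters when matching outputs to tight tolerances. The standard-library sketch below writes out both textbook formulas. It does not call Flax or PyTorch, so it illustrates the size of the gap without testing either library.

```python
import math


def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for x in (-2.0, -0.5, 0.5, 1.0, 3.0):
    diff = abs(gelu_exact(x) - gelu_tanh(x))
    print(f"x={x:+.1f} exact={gelu_exact(x):.6f} tanh={gelu_tanh(x):.6f} diff={diff:.2e}")
```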
Axial Attention (axial_attention.py)
AxialAttention3DSpaceTime — Sequential attention over four axes: time (t), depth (d), height (h), width (w). Each axis uses LoRAMHA (multi-head attention with optional LoRA adapters).
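The idea of attending along one axis at a time can be sketched in NumPy. This is a deliberately simplified version: single-head, no learned Q/K/V projections, no LoRA, and a residual added after every axis (an assumption). The token layout `(B, t, d, h, w, dim)` and the function names are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attend_along_axis(x, axis):
    """Self-attention along one axis of x with shape (B, t, d, h, w, dim)."""
    # Move the chosen axis next to the feature axis; all other axes act as batch.
    moved = np.moveaxis(x, axis, -2)                        # (..., L, dim)
    scores = moved @ np.swapaxes(moved, -1, -2)             # (..., L, L)
    out = softmax(scores / np.sqrt(x.shape[-1])) @ moved    # (..., L, dim)
    return np.moveaxis(out, -2, axis)


def axial_attention(x):
    """Apply attention sequentially over the t, d, h and w axes."""
    for axis in (1, 2, 3, 4):
        x = x + attend_along_axis(x, axis)  # per-axis residual is an assumption
    return x


rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 1, 2, 3, 4, 5))
print(axial_attention(tokens).shape)
```

Each pass costs attention over a single axis length rather than over all `t * d * h * w` tokens at once, which is the usual motivation for axial attention.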
Attention (attention.py)
- scaled_dot_product_attention: Manual SDPA (compatible with any JAX version).
- LoRALinear: Dense layer with optional low-rank adaptation (A, B matrices).
- LoRAMHA: Multi-head attention with separate LoRA-enabled Q/K/V/O projections.
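A dense layer with a low-rank update can be sketched in a few lines of NumPy, using the Flax-side shapes from the weight mapping table (`A` is `(I, R)`, `B` is `(R, O)`). The `lora_linear` function is hypothetical. It omits any scaling factor, since the README does not say whether one is applied.

```python
import numpy as np


def lora_linear(x, kernel, bias, A=None, B=None):
    """y = x @ kernel + bias, plus the low-rank term (x @ A) @ B when adapters are given."""
    y = x @ kernel + bias
    if A is not None and B is not None:
        y = y + (x @ A) @ B
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 6))          # (batch, I)
kernel = rng.normal(size=(6, 4))     # (I, O)
bias = np.zeros(4)
A = rng.normal(size=(6, 2))          # (I, R)
B = np.zeros((2, 4))                 # (R, O), zero so the adapter adds nothing

base = lora_linear(x, kernel, bias)
adapted = lora_linear(x, kernel, bias, A, B)
print(np.allclose(base, adapted))
```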
Decoder (decoder.py)
SimpleDecoder: LayerNorm → Dense projection → reshape back to volumetric output. It is applied only to the last timestep.
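The reshape back to a volume is the inverse of patchification. The NumPy sketch below shows one possible round trip for a single field with shape `(B, C, D, H, W)`. The ordering of values inside each token is an assumption and may differ from what `patchify.py` and `decoder.py` actually use.

```python
import numpy as np


def patchify(vol, p):
    """(B, C, D, H, W) -> (B, n_patches, C * p**3), with non-overlapping p^3 patches."""
    B, C, D, H, W = vol.shape
    x = vol.reshape(B, C, D // p, p, H // p, p, W // p, p)
    x = x.transpose(0, 2, 4, 6, 1, 3, 5, 7)      # (B, nd, nh, nw, C, p, p, p)
    return x.reshape(B, (D // p) * (H // p) * (W // p), C * p ** 3)


def unpatchify(tokens, p, C, D, H, W):
    """Inverse of patchify: (B, n_patches, C * p**3) -> (B, C, D, H, W)."""
    B = tokens.shape[0]
    x = tokens.reshape(B, D // p, H // p, W // p, C, p, p, p)
    x = x.transpose(0, 4, 1, 5, 2, 6, 3, 7)      # (B, C, nd, p, nh, p, nw, p)
    return x.reshape(B, C, D, H, W)


vol = np.arange(2 * 3 * 8 * 8 * 16, dtype=np.float64).reshape(2, 3, 8, 8, 16)
tokens = patchify(vol, 8)
restored = unpatchify(tokens, 8, 3, 8, 8, 16)
print(tokens.shape, np.array_equal(vol, restored))
```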
Convert Weights (convert_weights.py)
- load_pytorch_state_dict: Loads MORPH .pth checkpoints (handles model_state_dict wrapper and module. prefix from DataParallel).
- convert_pytorch_to_jax_params: Full parameter conversion with shape validation.
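The two checkpoint quirks mentioned above, the `model_state_dict` wrapper and the `module.` prefix added by DataParallel, can be handled in a few lines of plain Python. `unwrap_state_dict` is a hypothetical helper for illustration and is not the implementation of `load_pytorch_state_dict`.

```python
def unwrap_state_dict(checkpoint):
    """Return a flat {name: tensor} dict without wrapper or 'module.' prefixes."""
    # Some checkpoints nest the weights under "model_state_dict"; others are flat.
    state = checkpoint.get("model_state_dict", checkpoint)
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


# Toy checkpoint with placeholder values instead of tensors.
ckpt = {"model_state_dict": {"module.decoder.weight": 1, "module.decoder.bias": 2}}
print(unwrap_state_dict(ckpt))
```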
Implementation Notes
Key Differences from PyTorch
- Channels-last convolutions: JAX Conv3D uses a (D, H, W, C_in, C_out) kernel layout vs PyTorch's (C_out, C_in, D, H, W).
- LayerNorm epsilon: Flax defaults to 1e-6; PyTorch defaults to 1e-5. We explicitly set epsilon=1e-5 everywhere.
- GELU activation: Flax's nn.gelu defaults to the approximate version; PyTorch uses exact. We use nn.gelu(x, approximate=False).
- Positional encoding interpolation: jax.image.resize uses different coordinate conventions than PyTorch's F.interpolate(align_corners=False). Custom interpolation functions replicate the half-pixel coordinate mapping (i + 0.5) * in_size / out_size - 0.5.
- Antialias bilinear interpolation: For the L model, PyTorch's antialias=True zeros out-of-bounds kernel weights before renormalisation. Our implementation matches this exactly via separable 1D passes with a widened triangle kernel.
- LoRA dormant parameters: When lora_rank=0, LoRA A/B matrices are initialised but never contribute to the output. This matches PyTorch's dormant-LoRA pattern for future fine-tuning.
- attn_t always instantiated: Even when t=1, the temporal attention module is created (but its residual is not added) to ensure consistent parameter trees for weight loading.
License
BSD-3-Clause — see LICENSE.