Skip to content

Model Examples

This page provides concise constructor and forward-pass examples for the Equinox-facing Foundax API.

Core Model Examples

These are direct architecture constructors from foundax.nn; the examples below call them through the top-level foundax namespace (imported as fx).

import jax
import jax.numpy as jnp
import foundax as fx

# MLP: construct the network, then call it on an all-ones array.
mlp = fx.mlp(
    in_features=2,
    output_dim=1,
    hidden_dims=64,
    num_layers=3,
)
mlp_input_shape = (16, 2)  # last axis has size 2, equal to in_features
x = jnp.ones(mlp_input_shape)
y = mlp(x)

# FNO2D: same pattern, with an all-ones 64 x 64 array and a trailing axis of 1.
fno = fx.fno2d(
    in_features=1,
    hidden_channels=32,
    n_modes=16,
)
grid_shape = (64, 64, 1)
grid = jnp.ones(grid_shape)
out = fno(grid)

# UNet2D: note that `out` is rebound here, replacing the FNO2D result above.
unet = fx.unet2d(
    in_channels=1,
    out_channels=1,
)
img_shape = (128, 128, 1)
img = jnp.ones(img_shape)
out = unet(img)

Foundation Wrapper Examples

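Each larger wrapper family is exposed as its own namespace on the package (for example, fx.poseidon or fx.morph), and a model is built by calling a constructor inside that namespace.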

Poseidon

import foundax as fx
import jax.numpy as jnp

model = fx.poseidon.T()

# All-ones input plus a one-element time array; both are passed by keyword.
# NOTE(review): axis meaning of the input is not stated here — confirm
# against the Poseidon wrapper docstring.
input_shape = (1, 128, 128, 4)
x = jnp.ones(input_shape)
t = jnp.array([0.5])

# The prediction is read from the `.output` attribute of the call's result.
result = model(pixel_values=x, time=t)
out = result.output

MORPH

import foundax as fx
import jax.numpy as jnp

# Example shape: (batch, time, fields, channels, depth, height, width)
input_shape = (1, 2, 1, 1, 16, 16, 16)

model = fx.morph.S()
x = jnp.ones(input_shape)
out = model(x)

MPP

import foundax as fx
import jax.numpy as jnp

model = fx.mpp.Ti(n_states=3)

# (time, batch, channels, height, width)
x_shape = (2, 1, 3, 64, 64)
x = jnp.ones(x_shape)

# Integer side inputs, passed positionally after x.
state_labels = jnp.array([0, 1, 2])
bcs_shape = (1, 2)
bcs = jnp.zeros(bcs_shape, dtype=jnp.int32)

out = model(x, state_labels, bcs)

Walrus

import foundax as fx
import jax.numpy as jnp

model = fx.walrus.base()

# (batch, time, height, width, channels)
x_shape = (1, 2, 64, 64, 4)
x = jnp.ones(x_shape)

# Side inputs, passed positionally after x. Here bcs is a plain nested
# Python list rather than a jnp array.
state_labels = jnp.array([0, 1, 2, 3])
bcs = [
    [0, 0],
    [0, 0],
]

out = model(x, state_labels, bcs)

BCAT

import foundax as fx
import jax.numpy as jnp

model = fx.bcat.base()

# Length of the time axis shared by x and t.
n_steps = 6

# (batch, t_in + t_out, 128, 128, channels)
x = jnp.ones((1, n_steps, 128, 128, 2))

# n_steps evenly spaced values in [0, 1], reshaped to (1, n_steps, 1).
t_values = jnp.linspace(0.0, 1.0, n_steps)
t = t_values.reshape(1, n_steps, 1)

out = model(x, t, input_len=4)

PDEformer-2

import foundax as fx
import jax
import jax.numpy as jnp

# Sizes for the minimal synthetic graph-like inputs.
n_graph, n_node, n_points = 1, 6, 16
degree_shape = (n_graph, n_node)        # shared by in_degree / out_degree
pair_shape = (n_graph, n_node, n_node)  # shared by attn_bias / spatial_pos

key = jax.random.PRNGKey(0)
model = fx.pdeformer2.small()

# Node inputs: all-ones placeholders.
node_type = jnp.ones((n_graph, n_node, 1), dtype=jnp.int32)
node_scalar = jnp.ones((n_graph, 4, 1))
node_function = jnp.ones((n_graph, 2, 128 * 128, 5))

# Graph-structure inputs: all-zeros placeholders.
in_degree = jnp.zeros(degree_shape, dtype=jnp.int32)
out_degree = jnp.zeros(degree_shape, dtype=jnp.int32)
attn_bias = jnp.zeros(pair_shape)
spatial_pos = jnp.zeros(pair_shape, dtype=jnp.int32)

# Uniform random values of shape (n_graph, n_points, 4).
coordinate = jax.random.uniform(key, (n_graph, n_points, 4))

# The model takes these eight arrays positionally, in exactly this order.
graph_inputs = (
    node_type,
    node_scalar,
    node_function,
    in_degree,
    out_degree,
    attn_bias,
    spatial_pos,
    coordinate,
)
out = model(*graph_inputs)

DPOT

import foundax as fx
import jax.numpy as jnp

model = fx.dpot.Ti()

# (batch, height, width, time, channels)
x_shape = (1, 128, 128, 3, 2)
x = jnp.ones(x_shape)

# The call's result is unpacked into two values.
outputs = model(x)
pred, cls = outputs

PROSE

import foundax as fx
import jax.numpy as jnp

# fd_1to1 returns a (model, variables) pair; variables is handed back to
# model.apply as its first argument below.
model, variables = fx.prose.fd_1to1()

x_shape = (1, 2, 128, 128, 2)
time_shape = (1, 2, 1)  # shared by t_in and t_out

x = jnp.ones(x_shape)
t_in = jnp.zeros(time_shape)
t_out = jnp.ones(time_shape)

pred = model.apply(
    variables,
    x,
    t_in,
    t_out,
    deterministic=True,
)

Notes

  • Input signatures differ between families; check the constructor docstring in the wrapper module when integrating a model.
  • For production training code, pair these constructors with your optimizer and training loop directly in JAX/Equinox.
  • For guidance on choosing a model family, see the Core Models and Foundation Models documentation.