Neural Network Architectures
All neural network architectures in jNO are accessible via jno.numpy.nn (importable as jnn.nn). Every factory method returns a Model wrapper that integrates seamlessly with the jNO training pipeline — supporting .optimizer(), .freeze(), .lora(), .mask(), and .initialize().
Architecture Overview
| Method | Type | Input | Best For |
|---|---|---|---|
nn.mlp |
MLP | pointwise coords | baselines, inverse problems |
nn.fno1d/2d/3d |
FNO | grid (structured) | periodic BCs, smooth solutions |
nn.cno2d |
CNO | grid (structured) | resolution-invariant operators |
nn.unet1d/2d/3d |
U-Net | grid (structured) | multi-scale, encoder-decoder |
nn.mgno1d/2d |
MgNO | grid | multi-resolution operator learning |
nn.geofno/geofno2d/3d |
GeoFNO | irregular grid | non-rectangular domains |
nn.pcno |
PCNO | point cloud | unstructured / scattered data |
nn.deeponet |
DeepONet | branch (function) + trunk (coords) | small data, operator learning |
nn.pointnet |
PointNet | point cloud | permutation-invariant processing |
nn.transformer |
Transformer | sequence | general-purpose attention |
nn.pit |
PiT | coords + features | distance-based attention |
nn.scot |
ScOT | coords + features | scalable operator transformer |
nn.gnot/cgptno/moegnot |
GNOT | mixed | general neural operator |
nn.poseidonT/B/L |
Poseidon | 128×128 grid | pretrained foundation model |
nn.wrap |
any Equinox/Flax | user-defined | custom architectures |
Common Model API
Every model returned by a factory method supports:
model.optimizer(opt_fn, lr=schedule) # attach optimizer (required before solve)
model.freeze() # freeze all parameters
model.mask(param_mask) # freeze/unfreeze individual parameters
model.lora(rank=4, alpha=1.0) # LoRA fine-tuning
model.initialize(path_or_pytree) # load pretrained weights
model.dont_show() # suppress model summary during compile
Methods return self for chaining:
MLP (Multi-Layer Perceptron)
A fully-connected feedforward network for point-wise mappings.
u_net = jnn.nn.mlp(
in_features=2, # number of input features (last axis)
output_dim=1, # number of output features
activation=jnp.tanh, # hidden activation: relu, gelu, selu, sin, ...
hidden_dims=64, # int or list of ints per layer
num_layers=3, # number of hidden layers
output_activation=None, # optional output activation
use_bias=True,
dropout_rate=0.0,
layer_norm=False,
batch_norm=False,
key=jax.random.PRNGKey(0),
)
Usage:
u = u_net(x, y) # inputs are concatenated along the feature axis
u = u_net(x, y, k) # additional parameters are also concatenated
Fourier Neural Operator (FNO)
FNO learns operators in Fourier space. Particularly effective for problems with smooth solutions or periodic boundary conditions.
FNO 1D
model = jnn.nn.fno1d(
in_features=1,
hidden_channels=64,
n_modes=16, # number of Fourier modes to retain
d_vars=1, # output channels
n_layers=4,
n_steps=1, # output time steps (autoregressive rollout)
activation=jax.nn.gelu,
linear_conv=True, # False for periodic BCs
key=key,
)
FNO 2D
model = jnn.nn.fno2d(
in_features=1,
hidden_channels=32,
n_modes=12,
d_vars=1,
n_layers=4,
d_model=(64, 64), # spatial resolution (for positional encoding)
use_positions=False, # prepend coordinate grid to input
key=key,
)
FNO 3D
model = jnn.nn.fno3d(
in_features=1,
hidden_channels=24,
n_modes=8,
d_vars=1,
d_model=(32, 32, 32),
key=key,
)
Continuous Neural Operator (CNO 2D)
U-Net style architecture with bicubic-interpolation continuous activations for resolution-invariant learning.
model = jnn.nn.cno2d(
in_dim=1,
out_dim=1,
size=64, # spatial size (must be divisible by 2^N_layers)
N_layers=3,
N_res=4, # residual blocks per encoder level
N_res_neck=4, # residual blocks in bottleneck
channel_multiplier=16,
use_bn=True,
key=key,
)
U-Net
Encoder-decoder with skip connections. Suitable for general PDE operator learning on regular grids.
model = jnn.nn.unet1d(in_dim=1, out_dim=1, hidden_channels=32, n_levels=3, key=key)
model = jnn.nn.unet2d(in_dim=1, out_dim=1, hidden_channels=32, n_levels=3, key=key)
model = jnn.nn.unet3d(in_dim=1, out_dim=1, hidden_channels=16, n_levels=3, key=key)
Multigrid Neural Operator (MgNO)
Inspired by algebraic multigrid solvers; efficient for problems with multi-scale features.
model = jnn.nn.mgno1d(in_dim=1, out_dim=1, key=key)
model = jnn.nn.mgno2d(in_dim=1, out_dim=1, key=key)
Geometry-Informed FNO (GeoFNO)
FNO adapted for irregular / non-rectangular domains via learned geometric mappings.
from jno.architectures.geofno import compute_Fourier_modes
model = jnn.nn.geofno2d(
nks=(16, 16), # Fourier modes per dimension
Ls=(1.0, 1.0), # domain physical extents
in_dim=3, # input channels
out_dim=1,
key=key,
)
Point Cloud Neural Operator (PCNO)
Handles unstructured / scattered point clouds using graph convolutions.
DeepONet
Separate branch network (encodes input function) and trunk network (encodes query coordinates). Especially effective in the small-data regime.
model = jnn.nn.deeponet(
n_sensors=10, # number of sensor points for branch input
sensor_channels=1, # channels per sensor measurement
coord_dim=2, # spatial dimension for trunk network
basis_functions=64, # latent space dimension
hidden_dim=64,
n_layers=3,
key=key,
)
# Usage: u(sensor_data, spatial_coords)
u = model(t, jnn.concat([x, y]))
PointNet
Permutation-invariant processing of unordered point clouds.
Transformer
Standard multi-head self-attention transformer.
Position-Induced Transformer (PiT)
Attention mechanism where weights depend on pairwise distances between query points. Well-suited for irregular meshes.
Scalable Operator Transformer (ScOT)
Swin-Transformer-based architecture; efficient attention for large resolution inputs.
GNOT (General Neural Operator Transformer)
model = jnn.nn.gnot(...)
model = jnn.nn.cgptno(...) # Conditional GNOT
model = jnn.nn.moegnot(...) # Mixture-of-Experts variant
Poseidon Foundation Models
Pretrained ScOT models trained on a broad distribution of fluid dynamics PDEs. Suitable for fine-tuning with minimal data.
model = jnn.nn.poseidonT(key=key) # Tiny (~4 M params)
model = jnn.nn.poseidonB(key=key) # Base (~25 M params)
model = jnn.nn.poseidonL(key=key) # Large (~80 M params)
Expected input: structured 128×128 grids (use jno.domain.poseidon()).
Wrapping Custom Models
Equinox Module
import equinox as eqx
import jno.numpy as jnn
class MyNet(eqx.Module):
fc: eqx.nn.Linear
def __init__(self, *, key):
self.fc = eqx.nn.Linear(2, 1, key=key)
def __call__(self, x, y):
return self.fc(jnp.concat([x, y], axis=-1))
my_net = jnn.nn.wrap(MyNet(key=jax.random.PRNGKey(0)))
u = my_net(x, y) # returns Placeholder for use in constraints
Flax Linen Module
from flax import linen as nn
class FlaxMLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(64)(x)
x = nn.tanh(x)
return nn.Dense(1)(x)
my_net = jnn.nn.flaxwrap(FlaxMLP(), input=(dummy_x,), key=key)
Flax NNX Module (auto-detected)
from flax import nnx
class NNXModel(nnx.Module):
def __init__(self, rngs):
self.dense = nnx.Linear(2, 1, rngs=rngs)
def __call__(self, x):
return self.dense(x)
my_net = jnn.nn.wrap(NNXModel(nnx.Rngs(0))) # auto-detected as NNX
Architecture Search Wrapper
When using the hyperparameter tuner, pass a class (not instance) plus an ArchSpace:
a_space = jnn.tune.space()
a_space.unique("act", [jnp.tanh, jax.nn.gelu, jnp.sin], category="architecture")
a_space.unique("hid", [32, 64, 128], category="architecture")
a_space.unique("dep", [2, 3], category="architecture")
class TunableMLP(eqx.Module):
def __init__(self, arch: jnn.tune.Arch, *, key):
hidden = arch("hid")
depth = arch("dep")
self.act = arch("act")
# build layers...
u = jnn.nn.wrap(TunableMLP, space=a_space)(x, y)
Trainable Parameters (Inverse Problems)
To learn scalar or tensor PDE coefficients during training: