HyCo: Hybrid-Cooperative PINN
This tutorial implements the Hybrid-Cooperative (HyCo) learning framework from Liverani, Steynberg & Zuazua (2025) using jno.fn.stop_gradient.
The Idea
A standard PINN must balance two potentially conflicting objectives: satisfying the PDE and fitting observations. HyCo instead trains two specialised networks in parallel:
| Model | Objective |
|---|---|
| u_phy (physical model) | Enforce the PDE residual |
| u_syn (synthetic model) | Fit the sparse, noisy observations |
The two models are kept in sync through a mutual interaction loss evaluated at the interior collocation points. jno.fn.stop_gradient is the key ingredient: in each interaction term, only the model being pulled receives gradients, while the other model's prediction is treated as a fixed target, so that model's weights are left untouched by the term.
Loss Decomposition
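\[
\mathcal{L}_{\text{phy}} = \mathcal{L}_{\text{PDE}} + \beta\,\mathcal{L}_{\text{int}}^{\text{phy}}, \qquad
\mathcal{L}_{\text{syn}} = \alpha\,\mathcal{L}_{\text{data}} + \beta\,\mathcal{L}_{\text{int}}^{\text{syn}},
\]
where
\[
\begin{aligned}
\mathcal{L}_{\text{PDE}} &= \frac{1}{N}\sum_{i=1}^{N} \bigl|\mathcal{R}[u_{\text{phy}}](x_i)\bigr|^2, &
\mathcal{L}_{\text{data}} &= \frac{1}{M}\sum_{j=1}^{M} \bigl|u_{\text{syn}}(x^{s}_j) - \tilde u_j\bigr|^2, \\
\mathcal{L}_{\text{int}}^{\text{phy}} &= \frac{1}{N}\sum_{i=1}^{N} \bigl|u_{\text{phy}}(x_i) - \operatorname{sg}\bigl(u_{\text{syn}}(x_i)\bigr)\bigr|^2, &
\mathcal{L}_{\text{int}}^{\text{syn}} &= \frac{1}{N}\sum_{i=1}^{N} \bigl|u_{\text{syn}}(x_i) - \operatorname{sg}\bigl(u_{\text{phy}}(x_i)\bigr)\bigr|^2,
\end{aligned}
\]
with collocation points \(x_i\), sensor locations \(x^{s}_j\) with noisy readings \(\tilde u_j\), and PDE residual \(\mathcal{R}[u]\) (for the Poisson problem below, \(\mathcal{R}[u] = u'' + \pi^2 \sin(\pi x)\)).
Here \(\operatorname{sg}(\cdot)\) denotes stop_gradient. All four terms live in the same jno.core call; stop-gradient does the work of keeping the gradient paths separate.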
Problem
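1D Poisson on \([0, 1]\):
\[
-u''(x) = \pi^2 \sin(\pi x), \quad x \in (0, 1), \qquad u(0) = u(1) = 0,
\]
with exact solution \(u(x) = \sin(\pi x)\).
Observations: 7 evenly spaced sensors on \([0.1, 0.9]\) with additive Gaussian noise (\(\sigma = 0.05\)).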
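A quick sanity check that the exact solution satisfies both the equation and the boundary conditions. This sketch uses nested jax.grad for the second derivative; the 51-point grid is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

def u_exact(x):
    # Exact solution of -u'' = pi^2 sin(pi x) on [0, 1] with u(0) = u(1) = 0
    return jnp.sin(jnp.pi * x)

# Second derivative by nesting jax.grad (u_exact maps a scalar to a scalar),
# vectorised over points with jax.vmap
u_xx = jax.vmap(jax.grad(jax.grad(u_exact)))

xs = jnp.linspace(0.0, 1.0, 51)
residual = u_xx(xs) + jnp.pi**2 * jnp.sin(jnp.pi * xs)  # same residual as L_pde

print("max |residual|:", float(jnp.max(jnp.abs(residual))))
print("u(0), u(1):", float(u_exact(0.0)), float(u_exact(1.0)))  # zero up to float32 rounding
```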
Setup
Domains
A standard line domain provides the collocation points. Sensor coordinates are registered on the same domain object by passing the array as sample= to domain.variable:
import numpy as np
import jno
rng = np.random.default_rng(0)
# Dense collocation grid: PDE residual + interaction
domain = jno.domain.line(mesh_size=0.02)
x, _ = domain.variable("interior")
# Sparse noisy sensors, registered on the same domain
x_sen = np.linspace(0.1, 0.9, 7).reshape(-1, 1)
u_sen = np.sin(np.pi * x_sen) + rng.normal(0, 0.05, x_sen.shape)
(x_s,) = domain.variable("sensors", sample=x_sen, split=True, point_data=True)
Networks
import foundax
import jax
import optax
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
u_phy_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k1))
u_syn_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k2))
for net in [u_phy_net, u_syn_net]:
    net.optimizer(optax.adam(1e-3))  # one Adam optimiser per network
u_phy = u_phy_net(x) * x * (1 - x)  # hard zero BCs: the factor vanishes at x = 0 and x = 1
u_syn = u_syn_net(x) * x * (1 - x)
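To see why the factor x(1 - x) enforces the boundary conditions regardless of training, here is a NumPy sketch. The random one-hidden-layer tanh network is a hypothetical stand-in for foundax.mlp, not its actual architecture; any finite network output works.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical stand-in for an MLP: one hidden tanh layer with random weights
W1, b1 = rng.normal(size=(1, 32)), rng.normal(size=32)
W2, b2 = rng.normal(size=(32, 1)), rng.normal(size=1)

def mlp(x):
    """x: (n, 1) -> (n, 1)."""
    return np.tanh(x @ W1 + b1) @ W2 + b2

def u_hat(x):
    # Hard-constrained ansatz: x(1 - x) is zero at x = 0 and x = 1,
    # so u(0) = u(1) = 0 holds for any weights
    return mlp(x) * x * (1 - x)

x_bc = np.array([[0.0], [1.0]])
x_mid = np.array([[0.5]])
print(u_hat(x_bc).ravel())   # exactly zero at both ends
print(u_hat(x_mid).ravel())  # generally nonzero in the interior
```

Because the constraint is built into the ansatz, no boundary loss term is needed, which is why the loss list contains only PDE, data and interaction terms.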
The Four Loss Terms
π = jno.np.pi
# Physical model: PDE residual u'' + π² sin(πx)
L_pde = (u_phy.dd(x) + π**2 * jno.np.sin(π * x)).mse
# Synthetic model: fit the noisy observations at the sensor points
u_syn_s = u_syn_net(x_s) * x_s * (1 - x_s)
u_obs = jno.np.array(u_sen)
L_data = (u_syn_s - u_obs).mse
# Mutual alignment: stop_gradient prevents cross-model gradient flow
L_int_phy = (u_phy - jno.fn.stop_gradient(u_syn)).mse
L_int_syn = (u_syn - jno.fn.stop_gradient(u_phy)).mse
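For readers who want to see the same four terms outside jno, here is a minimal pure-JAX sketch. It assumes each network is a plain function u(params, x) of a scalar x (a hypothetical one-hidden-layer tanh MLP, not foundax.mlp) and uses jax.lax.stop_gradient in place of jno.fn.stop_gradient. Since stop_gradient is the identity in the forward pass, the two interaction terms have exactly the same value; they differ only in which parameters their gradients reach.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, width=16):
    # Hypothetical one-hidden-layer tanh MLP, scalar in -> scalar out
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (width,)),
        "b1": jnp.zeros(width),
        "W2": jax.random.normal(k2, (width,)) / jnp.sqrt(width),
        "b2": jnp.zeros(()),
    }

def u(params, x):
    # Network output times the hard Dirichlet factor x(1 - x)
    h = jnp.tanh(params["W1"] * x + params["b1"])
    return (h @ params["W2"] + params["b2"]) * x * (1 - x)

u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)  # d2u/dx2 for a scalar x

def batch(f):
    # Map f(params, x) over a vector of points, sharing the parameters
    return jax.vmap(f, in_axes=(None, 0))

def hyco_terms(p_phy, p_syn, x_col, x_sen, u_obs, alpha=1.0, beta=1.0):
    sg = jax.lax.stop_gradient
    u_phy, u_syn = batch(u)(p_phy, x_col), batch(u)(p_syn, x_col)
    residual = batch(u_xx)(p_phy, x_col) + jnp.pi**2 * jnp.sin(jnp.pi * x_col)
    L_pde = jnp.mean(residual**2)
    L_data = jnp.mean((batch(u)(p_syn, x_sen) - u_obs) ** 2)
    L_int_phy = jnp.mean((u_phy - sg(u_syn)) ** 2)
    L_int_syn = jnp.mean((u_syn - sg(u_phy)) ** 2)
    # Same ordering as the jno.core list used in this tutorial
    return [L_pde, beta * L_int_phy, alpha * L_data, beta * L_int_syn]

k_phy, k_syn, k_noise = jax.random.split(jax.random.PRNGKey(0), 3)
p_phy, p_syn = init_mlp(k_phy), init_mlp(k_syn)
x_col = jnp.linspace(0.0, 1.0, 51)
x_sen = jnp.linspace(0.1, 0.9, 7)
u_obs = jnp.sin(jnp.pi * x_sen) + 0.05 * jax.random.normal(k_noise, (7,))

terms = hyco_terms(p_phy, p_syn, x_col, x_sen, u_obs)
print([float(t) for t in terms])
```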
Why stop_gradient?
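Without it, L_int_phy would backpropagate through both u_phy and u_syn: minimising it would also move u_syn_net toward u_phy, and because the squared difference is symmetric, L_int_phy and L_int_syn would collapse into the same penalty acting on both networks at once. The alternating schedule below relies on this separation, since a substep containing only the u_phy terms should update only u_phy_net. With stop_gradient: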
- Gradients of L_int_phy reach only u_phy_net: it is nudged toward u_syn's predictions.
- Gradients of L_int_syn reach only u_syn_net: it is nudged toward u_phy's predictions.
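The gradient routing can be checked directly with jax.grad. In this sketch, a hypothetical two-parameter model stands in for each network, and jax.lax.stop_gradient plays the role of jno.fn.stop_gradient.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 11)  # toy collocation grid

def u(params, x):
    # Hypothetical stand-in for a network: (a + b x) times the hard-BC factor x(1 - x)
    return (params["a"] + params["b"] * x) * x * (1 - x)

p_phy = {"a": jnp.array(1.0), "b": jnp.array(0.5)}
p_syn = {"a": jnp.array(-0.3), "b": jnp.array(2.0)}

def L_int_phy(p_phy, p_syn):
    # u_syn is a fixed target: no gradient flows back into p_syn
    return jnp.mean((u(p_phy, x) - jax.lax.stop_gradient(u(p_syn, x))) ** 2)

def L_int_phy_no_sg(p_phy, p_syn):
    # Same forward value, but gradients now reach both parameter sets
    return jnp.mean((u(p_phy, x) - u(p_syn, x)) ** 2)

g_phy, g_syn = jax.grad(L_int_phy, argnums=(0, 1))(p_phy, p_syn)
g_phy_ns, g_syn_ns = jax.grad(L_int_phy_no_sg, argnums=(0, 1))(p_phy, p_syn)

print("with sg:    dL/dp_syn =", {k: float(v) for k, v in g_syn.items()})  # all zeros
print("without sg: dL/dp_syn =", {k: float(v) for k, v in g_syn_ns.items()})
```

With stop_gradient, the gradient with respect to the synthetic parameters is exactly zero. Without it, the synthetic model receives the mirror image of the physical model's gradient (the two toy models share an architecture, so g_syn = -g_phy), which is the symmetric pull described above.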
Alternating updates via substeps
HyCo prescribes alternating updates: u_phy is updated first, then u_syn is updated using the freshly updated u_phy. Updating both simultaneously in a single optimizer step means each model sees only a stale snapshot of the other, which slows convergence and contaminates Adam momentum.
The substeps argument to solve() expresses the alternating schedule:
α, β = 1.0, 1.0
crux = jno.core(
[L_pde, β * L_int_phy, α * L_data, β * L_int_syn],
)
# Each outer epoch runs two gradient steps in sequence:
# substep 0: constraints [0, 1] → only u_phy_net updates
# substep 1: constraints [2, 3] → only u_syn_net updates (sees fresh u_phy)
crux.solve(3_000, substeps=[[0, 1], [2, 3]])
Each substep keeps its own optimizer state, so Adam momentum for u_phy accumulates only from substep 0's gradients and never decays from being held inactive during substep 1. The shared trainable dict carries parameter updates between substeps, so substep 1 sees the freshly written u_phy weights.
You can also run multiple gradient steps per substep before alternating:
crux.solve(1_500, substeps=[([0, 1], 2), ([2, 3], 2)])
# 1500 outer epochs × (2 phy + 2 syn) = 6000 effective gradient steps
The shorthand [0, 1] is equivalent to ([0, 1], 1). Within a single substep, the n repeated steps share the same optimizer state, so Adam momentum builds up continuously before the switch.
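To make the bookkeeping concrete, here is a pure-Python sketch that expands the substeps shorthand and counts gradient steps according to the rules stated above. It is not jno's implementation; normalize_substeps, schedule and effective_steps are hypothetical names chosen for illustration.

```python
def normalize_substeps(specs):
    """Hypothetical helper: expand the [0, 1] shorthand into ([0, 1], 1)."""
    normalized = []
    for spec in specs:
        is_explicit = (
            isinstance(spec, tuple)
            and len(spec) == 2
            and isinstance(spec[1], int)
            and not isinstance(spec[0], int)
        )
        if is_explicit:
            indices, n_steps = spec  # explicit (indices, n) form
        else:
            indices, n_steps = spec, 1  # bare list of constraint indices
        normalized.append((list(indices), n_steps))
    return normalized

def schedule(specs):
    """Order of (substep id, constraint indices) inside one outer epoch."""
    return [(k, idx) for k, (idx, n) in enumerate(normalize_substeps(specs)) for _ in range(n)]

def effective_steps(epochs, specs):
    """Total gradient steps: outer epochs times the steps in one epoch."""
    return epochs * len(schedule(specs))

print(normalize_substeps([[0, 1], [2, 3]]))  # [([0, 1], 1), ([2, 3], 1)]
print(schedule([([0, 1], 2), ([2, 3], 2)]))
print(effective_steps(3_000, [[0, 1], [2, 3]]), effective_steps(1_500, [([0, 1], 2), ([2, 3], 2)]))  # 6000 6000
```

Both solve() calls in this section therefore take the same total number of gradient steps; they differ only in how many consecutive steps each model takes before control passes to the other.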
Results
The physics model, guided by both the PDE and alignment with the data-fitted synthetic model, closely matches the exact solution (the script below asserts a relative L2 error under 5%). The synthetic model, guided by 7 noisy observations and alignment with the physics model, settles on a smooth, physically consistent curve instead of interpolating the noise in the raw data (asserted relative L2 error under 10%).
What To Notice
- jno.fn.stop_gradient is the single syntactic addition that turns a standard two-model PINN into a cooperative system.
- substeps=[[0, 1], [2, 3]] expresses the alternating schedule HyCo requires; no manual outer loop is needed.
- Passing sample= to domain.variable registers the sensor points under their own tag on the same domain object as the collocation points.
- Tune β to control how strongly each model is pulled toward the other, and α to set the weight of the data fit.
Going Further
- Replace the dense collocation grid with adaptive resampling — see Adaptive Resampling.
- Use the synthetic model as a warm-start for a harder PDE where the PINN struggles to find the solution from scratch.
- Extend to 2D by replacing the 1D collocation grid with a 2D rect domain and the BC factor with \(x(1-x)y(1-y)\).
- See Liverani et al. (2025) for analysis of the Gray-Scott and Helmholtz cases.
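For the 2D extension, a NumPy check that the factor \(x(1-x)y(1-y)\) vanishes on all four edges of the unit square, so the ansatz keeps hard zero Dirichlet BCs there. The random values are a hypothetical stand-in for network outputs.

```python
import numpy as np

def bc_factor_2d(x, y):
    # Zero on every edge of the unit square: x = 0, x = 1, y = 0, y = 1
    return x * (1 - x) * y * (1 - y)

rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 21)
zeros, ones = np.zeros_like(t), np.ones_like(t)

# Boundary points, edge by edge: x = 0, x = 1, y = 0, y = 1
xb = np.concatenate([zeros, ones, t, t])
yb = np.concatenate([t, t, zeros, ones])

# Random values stand in for network outputs N(x, y)
net_out = rng.normal(size=xb.shape)
u_boundary = net_out * bc_factor_2d(xb, yb)

print(np.abs(u_boundary).max())  # 0.0
print(bc_factor_2d(0.5, 0.5))    # 0.0625, positive in the interior
```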
Script Snippet
"""05 — HyCo: Hybrid-Cooperative Learning for PINNs (1-D Poisson)"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import jno
π = jno.np.pi
rng = np.random.default_rng(0)
domain = jno.domain.line(mesh_size=0.02)
x, _ = domain.variable("interior")
# Sparse noisy sensors, registered as a named point set on the same domain.
x_sen = np.linspace(0.1, 0.9, 7).reshape(-1, 1)
u_sen = np.sin(np.pi * x_sen) + rng.normal(0, 0.05, x_sen.shape)
(x_s,) = domain.variable("sensors", sample=x_sen, split=True, point_data=True)
# Two networks, one optimiser each.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
u_phy_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k1))
u_syn_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k2))
for net in (u_phy_net, u_syn_net):
    net.optimizer(optax.adam(1e-3))
# Hard Dirichlet ansatz; .bind(x=x) so the PDE residual reads as u.xx.
u_phy = (u_phy_net(x) * x * (1 - x)).scalar.bind(x=x)
u_syn = (u_syn_net(x) * x * (1 - x)).scalar.bind(x=x)
# Loss components
L_pde = (u_phy.xx + π**2 * jno.np.sin(π * x)).mse
u_syn_at_sensors = u_syn_net(x_s) * x_s * (1 - x_s)
L_data = (u_syn_at_sensors - jno.np.array(u_sen)).mse
L_int_phy = (u_phy - u_syn.stop_gradient).mse
L_int_syn = (u_syn - u_phy.stop_gradient).mse
α, β = 1.0, 1.0
crux = jno.core([L_pde, β * L_int_phy, α * L_data, β * L_int_syn])
# Alternating updates: first u_phy (constraints 0, 1), then u_syn (constraints 2, 3).
crux.solve(3_000, substeps=[[0, 1], [2, 3]])
u_exact_expr = jno.np.sin(π * x)
_u_phy, _u_syn, _u_exact = crux.eval([u_phy, u_syn, u_exact_expr])
rel_phy = float(jnp.linalg.norm(_u_phy - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
rel_syn = float(jnp.linalg.norm(_u_syn - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"u_phy rel-L2 : {rel_phy:.3e} u_syn rel-L2 : {rel_syn:.3e}")
assert rel_phy < 0.05, f"Physical model error too large: {rel_phy:.3e}"
assert rel_syn < 0.10, f"Synthetic model error too large: {rel_syn:.3e}"
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"05_coupled_and_inverse/hyco_poisson_1d.py | epochs=3000 | alpha={α} | beta={β}"
        f" | rel_L2_phy={rel_phy:.6e} | rel_L2_syn={rel_syn:.6e}\n"
    )