
Laplace 1D

This is the smallest complete jNO example. It solves the 1D Poisson equation (the Laplace operator with a nonzero source term) under homogeneous Dirichlet boundary conditions and compares the learned solution against the analytical one.

Problem Setup

We solve

-u''(x) = sin(pi x),   x in [0, 1],   u(0) = u(1) = 0

with exact solution

u(x) = sin(pi x) / pi^2
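As a quick sanity check that does not involve jNO, the short script below confirms numerically that this exact solution satisfies both the equation and the boundary conditions. It uses only the standard library and a central finite difference, u''(x) ≈ (u(x+h) - 2u(x) + u(x-h)) / h^2, so the residual is small but not exactly zero. The step size h and the sample points are arbitrary choices.

```python
import math

def u_exact(x: float) -> float:
    """Exact solution u(x) = sin(pi x) / pi^2."""
    return math.sin(math.pi * x) / math.pi**2

def second_derivative(f, x: float, h: float = 1e-4) -> float:
    """Central finite-difference approximation of f''(x)."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / h**2

# Residual of -u'' - sin(pi x) at a few interior points
sample_points = [0.1 * i for i in range(1, 10)]
residuals = [
    -second_derivative(u_exact, x) - math.sin(math.pi * x) for x in sample_points
]
max_residual = max(abs(r) for r in residuals)

# Boundary values: zero up to floating-point round-off in sin(pi)
boundary_values = (u_exact(0.0), u_exact(1.0))
print(f"max |residual| = {max_residual:.2e}, (u(0), u(1)) = {boundary_values}")
```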

Step 1: Create the Domain

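The script initializes jNO, creates a 1D line domain, and extracts the interior points. In the step snippets, pick(a, b) chooses between two settings; the full script at the end of this page uses the first value of each pair (for example mesh_size=0.01 and solve(5_000)).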

π = jnn.pi
dire = jno.setup(__file__)

domain = jno.domain(constructor=jno.domain.line(mesh_size=pick(0.01, 0.1)))
x, _ = domain.variable("interior")

What this does:

  • jno.setup(__file__) creates the run directory for outputs.
  • jno.domain.line(...) builds the 1D geometry.
  • domain.variable("interior") gives the collocation points used for the PDE residual.
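The domain constructor generates the points for you. Purely to illustrate what "interior collocation points" means, the sketch below builds a uniform grid on [0, 1] with spacing equal to mesh_size and drops the two endpoints. The helper line_interior_points is hypothetical, and jNO's mesher may place, order, or count points differently, so treat this only as a mental model.

```python
def line_interior_points(mesh_size: float, a: float = 0.0, b: float = 1.0) -> list[float]:
    """Hypothetical helper: uniform nodes on [a, b], excluding the endpoints."""
    n_intervals = round((b - a) / mesh_size)
    step = (b - a) / n_intervals
    # Nodes a + i*step for i = 1..n_intervals-1 lie strictly inside (a, b)
    return [a + i * step for i in range(1, n_intervals)]

fine = line_interior_points(0.01)   # first mesh size passed to pick(...)
coarse = line_interior_points(0.1)  # second mesh size passed to pick(...)
print(len(fine), len(coarse))
```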

Step 2: Define the Analytical Reference

The exact solution is only used for monitoring error, not for training supervision.

u_exact = jnn.sin(π * x) / π**2

This is useful because you can track whether the PINN is converging toward the known solution.

Step 3: Build the Network with Hard Boundary Conditions

The model is a small MLP. Boundary conditions are enforced by multiplying the network output by x(1-x).

u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=32,
    num_layers=3,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, pick(5_000, 10), 1e-5))

u = u_net(x) * x * (1 - x)

Why this matters:

  • u_net(x) is unconstrained.
  • Multiplying by x(1-x) forces u(0)=u(1)=0 exactly.
  • This removes the need for a separate boundary loss term.
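The hard-constraint trick does not depend on what the network computes. The sketch below replaces u_net with a tiny randomly initialised tanh MLP written in plain Python (not jNO's mlp, and with arbitrary sizes and seeds), then checks that the wrapped output vanishes at both endpoints for every seed, even though the raw network does not.

```python
import math
import random

def make_random_mlp(hidden: int = 32, seed: int = 0):
    """Toy 1 -> hidden -> 1 tanh MLP with random weights (stand-in for u_net)."""
    rng = random.Random(seed)
    w1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    b1 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    w2 = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    b2 = rng.gauss(0.0, 1.0)

    def net(x: float) -> float:
        return sum(w * math.tanh(a * x + b) for w, a, b in zip(w2, w1, b1)) + b2

    return net

def hard_bc(net):
    """Wrap an unconstrained net so that u(0) = u(1) = 0 by construction."""
    return lambda x: net(x) * x * (1.0 - x)

results = []
for seed in range(5):
    net = make_random_mlp(seed=seed)
    u = hard_bc(net)
    # (raw net at x=0, wrapped u at x=0, wrapped u at x=1)
    results.append((net(0.0), u(0.0), u(1.0)))
print(results)
```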

Step 4: Build the PDE Residual and Error Tracker

The residual uses automatic differentiation twice to compute the second derivative.

pde = -jnn.grad(jnn.grad(u, x), x) - jnn.sin(π * x)
error = jnn.tracker((u - u_exact).mse, interval=pick(100, 1))

This gives you:

  • pde.mse: the physics loss to minimize
  • error: a tracked metric that reports the solution error during training
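In plain JAX, the same two ingredients look roughly like the sketch below: applying jax.grad twice to a scalar function gives u'', jax.vmap maps the residual over the collocation points, and .mse corresponds to a mean of squared values. This is an analogy written with standard JAX calls, not a description of how jnn.grad or jnn.tracker are implemented. The candidate field is the exact solution plus a small bump 1e-3·x(1-x), chosen because its residual is exactly the constant 2e-3.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a clean check

def u_exact(x):
    return jnp.sin(jnp.pi * x) / jnp.pi**2

def u_candidate(x):
    # Exact solution plus a small bump that still satisfies u(0) = u(1) = 0
    return u_exact(x) + 1e-3 * x * (1.0 - x)

def pde_residual(u):
    """Return r(x) = -u''(x) - sin(pi x), with u'' from nested jax.grad."""
    u_xx = jax.grad(jax.grad(u))
    return lambda x: -u_xx(x) - jnp.sin(jnp.pi * x)

xs = jnp.linspace(0.0, 1.0, 101)[1:-1]  # 99 interior points

r = jax.vmap(pde_residual(u_candidate))(xs)
pde_mse = jnp.mean(r**2)  # analogue of pde.mse
# analogue of (u - u_exact).mse
error_mse = jnp.mean((jax.vmap(u_candidate)(xs) - jax.vmap(u_exact)(xs)) ** 2)
print(float(pde_mse), float(error_mse))
```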

Step 5: Solve the Problem

crux = jno.core([pde.mse, error], domain)
history = crux.solve(pick(5_000, 10))
history.plot(f"{dire}/training_history.png")

This is the standard jNO flow:

  1. Bundle constraints and tracked metrics into jno.core(...)
  2. Call solve(...)
  3. Use the returned history for diagnostics
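To make the loop behind solve(...) concrete, here is a deliberately tiny stand-in in plain Python. Instead of an MLP it uses the one-parameter ansatz u(x) = a·sin(pi x), which already satisfies the boundary conditions, so the residual is (a·pi^2 - 1)·sin(pi x) and gradient descent on its mean square should drive a to 1/pi^2. A tracker-style error is recorded every interval steps into a history list without influencing the update. The function names, the learning rate, and logging at step 0 are illustrative assumptions; jNO's optimizer, logging schedule, and history format may differ.

```python
import math

PI = math.pi
xs = [i / 100 for i in range(1, 100)]      # interior collocation points
s2 = [math.sin(PI * x) ** 2 for x in xs]   # sin^2(pi x) at each point

def pde_loss(a: float) -> float:
    """Mean squared residual of -u'' - sin(pi x) for u = a sin(pi x)."""
    return sum((a * PI**2 - 1.0) ** 2 * s for s in s2) / len(s2)

def pde_loss_grad(a: float) -> float:
    """Analytic derivative of pde_loss with respect to a."""
    return sum(2.0 * (a * PI**2 - 1.0) * PI**2 * s for s in s2) / len(s2)

def solve(n_steps: int, lr: float = 0.01, interval: int = 100):
    a = 0.0       # initial parameter
    history = []  # entries: (step, loss, tracked error)
    for step in range(n_steps):
        if step % interval == 0:
            # Tracked metric: distance to the exact coefficient (monitoring only)
            history.append((step, pde_loss(a), abs(a - 1.0 / PI**2)))
        a -= lr * pde_loss_grad(a)  # gradient-descent update
    return a, history

a_final, history = solve(500)
print(a_final, 1.0 / PI**2, len(history))
```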

Step 6: Evaluate and Plot

After training, the script sorts the interior points by x, evaluates the learned and exact fields on them, and saves the training-history and solution plots.

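# x-coordinates of the interior collocation points
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)  # sort so the curves plot left to right
xs = pts[idx]
# evaluate learned and exact fields on the same points, then apply the same sort
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]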

Both files are written to the run directory created by jno.setup(__file__):

  • training_history.png
  • solution.png
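The argsort step matters because collocation points are not guaranteed to come back in ascending order, and the same permutation has to be applied to pts, pred, and true so that each value stays paired with its coordinate. The sketch below reproduces that pattern with NumPy on synthetic, shuffled data (the arrays are made up for illustration, not produced by jNO) and also computes a relative L² error, a common scalar summary that the script itself does not report.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic, shuffled collocation points (stand-in for the domain context)
pts = rng.permutation(np.linspace(0.0, 1.0, 101)[1:-1])

# Values evaluated in the same (unsorted) order as pts
true_unsorted = np.sin(np.pi * pts) / np.pi**2
pred_unsorted = true_unsorted + 1e-4 * pts * (1.0 - pts)  # fake "PINN" output

# One permutation, applied to every array, keeps (x, value) pairs aligned
idx = np.argsort(pts)
xs = pts[idx]
pred = pred_unsorted[idx]
true = true_unsorted[idx]

# Relative discrete L2 error: ||pred - true|| / ||true||
rel_l2 = np.linalg.norm(pred - true) / np.linalg.norm(true)
print(f"relative L2 error = {rel_l2:.2e}")
```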

What To Notice

  • This example enforces the boundary conditions as hard constraints, so the loss contains only the PDE residual.
  • The exact solution is not part of the PDE loss, only the tracker.
  • For many introductory PDEs, this is the cleanest pattern to start from.

Script Snippet

"""01 — 1-D Laplace / Poisson equation  (simplest possible PINN)

Problem
-------
    −u''(x) = sin(πx),   x ∈ [0, 1],   u(0) = u(1) = 0

Analytical solution
-------------------
    u(x) = sin(πx) / π²

Techniques shown
----------------
* Homogeneous Dirichlet BCs via hard constraint:  u = net(x) · x (1−x)
* Automatic-differentiation gradient  (jnn.grad)
* Tracker to log the L²-error against the exact solution
* Single-phase Adam with exponential LR decay
"""

import jax
import jno
import jno.jnp_ops as jnn
import optax
from jno import LearningRateSchedule as lrs
import matplotlib.pyplot as plt
import numpy as np

π = jnn.pi
dire = jno.setup(__file__)

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")

# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.sin(π * x) / π**2

# ── Network with hard-enforced BCs ────────────────────────────────────────────
u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=32,
    num_layers=3,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))

u = u_net(x) * x * (1 - x)  # hard BC: u(0) = u(1) = 0

# ── Constraints ───────────────────────────────────────────────────────────────
pde = -jnn.grad(jnn.grad(u, x), x) - jnn.sin(π * x)  # should be 0
error = jnn.tracker((u - u_exact).mse, interval=100)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, error], domain)
history = crux.solve(5_000)
history.plot(f"{dire}/training_history.png")

# ── Plot ──────────────────────────────────────────────────────────────────────
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.set_title("Solution")
ax1.plot(xs, pred, label="PINN")
ax1.plot(xs, true, "--", label="exact")
ax1.set_xlabel("x")
ax1.legend()

ax2.set_title("Pointwise |error|")
ax2.plot(xs, np.abs(pred - true))
ax2.set_xlabel("x")

plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")