Poisson 1D

This example solves nearly the same equation as Laplace 1D, but makes two different implementation choices: it enforces the boundary conditions with soft constraints and computes the second derivative with finite differences.

Problem Setup

We solve

−u''(x) = sin(πx),   x ∈ [0, 1],   u(0) = u(1) = 0

with exact solution

u(x) = sin(πx) / π²
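A quick way to convince yourself of this formula, independent of jno, is a plain-NumPy check: approximate u'' of the candidate solution with a central difference and confirm that the residual of −u'' − sin(πx) is small.

```python
import numpy as np

# Verify numerically that u(x) = sin(pi x) / pi^2 satisfies -u'' = sin(pi x)
# on a fine grid, using the standard 3-point central difference.
x = np.linspace(0.0, 1.0, 1001)
h = x[1] - x[0]
u = np.sin(np.pi * x) / np.pi**2

# u''(x_i) ~= (u[i-1] - 2*u[i] + u[i+1]) / h^2 on the interior points
d2u = (u[:-2] - 2.0 * u[1:-1] + u[2:]) / h**2
residual = -d2u - np.sin(np.pi * x[1:-1])
print(np.abs(residual).max())  # small, O(h^2)
```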

Step 1: Create Interior and Boundary Variables

Unlike the hard-constraint example, this script explicitly asks for both interior and boundary points.

domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

Why both are needed:

  • x is used for the PDE residual.
  • xb is used to define the boundary-condition loss.
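For intuition, the two point sets for a line on [0, 1] can be sketched in plain NumPy. This is purely illustrative: the names and the way jno actually stores its mesh internally may differ.

```python
import numpy as np

# Illustrative sketch (not jno's internals) of the two point sets a 1-D line
# domain with mesh_size = 0.01 provides.
mesh_size = 0.01
points = np.linspace(0.0, 1.0, int(round(1.0 / mesh_size)) + 1)

interior = points[1:-1]     # fed to the PDE residual
boundary = points[[0, -1]]  # fed to the boundary-condition loss
print(interior.size, boundary)
```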

Step 2: Define the Reference Solution

u_exact = jnn.sin(π * x) / π**2

As in the previous example, this is only used to track model quality.

Step 3: Build the Network

This version keeps the network output unconstrained.

u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=64,
    num_layers=4,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 10_000, 1e-5))

u = u_net(x)

The consequence is that boundary conditions must now be enforced through an explicit loss term.
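The difference between the two styles can be sketched outside of jno. With a hard constraint, the network output is multiplied by x(1 − x) so the boundary values are zero by construction; with a soft constraint, the raw output is simply penalized at the boundary points. The `net` function below is a hypothetical stand-in for a trained network, not jno's API.

```python
import numpy as np

def net(x, w=1.0):
    # hypothetical stand-in for the raw network output
    return np.tanh(w * x)

# Hard constraint: multiply by x*(1 - x), so u(0) = u(1) = 0 by construction.
def u_hard(x):
    return x * (1.0 - x) * net(x)

# Soft constraint: leave the output free and add a penalty at the boundary.
xb = np.array([0.0, 1.0])
bc_loss = np.mean(net(xb) ** 2)  # the optimizer drives this toward zero

print(u_hard(xb))  # exactly [0, 0] at the boundary
print(bc_loss)     # nonzero until training pushes it down
```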

Step 4: Define PDE and Boundary Losses

pde = -u.d2(x, scheme="finite_difference") - jnn.sin(π * x)
bc = u_net(xb)
error = jnn.tracker((u - u_exact).mse, interval=100)

Key difference from Laplace 1D:

  • u.d2(..., scheme="finite_difference") computes the second derivative numerically.
  • bc = u_net(xb) is minimized toward zero, which softly enforces the boundary values.
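The 3-point central-difference stencil is the standard choice for a second derivative on a uniform grid. The exact stencil behind `scheme="finite_difference"` is not spelled out in this example, but the sketch below shows how such a stencil compares with JAX autodiff on a toy function.

```python
import jax
import jax.numpy as jnp

# Toy function standing in for the network output.
f = lambda x: jnp.sin(jnp.pi * x)

# Autodiff second derivative, applied pointwise.
d2_auto = jax.vmap(jax.grad(jax.grad(f)))

# 3-point central difference on a uniform grid (interior points only).
x = jnp.linspace(0.0, 1.0, 101)
h = x[1] - x[0]
u = f(x)
d2_fd = (u[:-2] - 2.0 * u[1:-1] + u[2:]) / h**2

err = jnp.abs(d2_fd - d2_auto(x[1:-1])).max()
print(err)  # O(h^2) truncation error plus float32 round-off
```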

Step 5: Solve With Multiple Constraints

crux = jno.core([pde.mse, bc.mse, error], domain)
history = crux.solve(10_000)

Now the optimization balances:

  • PDE residual in the interior
  • boundary mismatch at the endpoints
  • tracked error against the exact solution
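Conceptually, the solver is minimizing a sum of these terms. How jno weights and aggregates them internally is not shown in this example; assuming equal weights, the combined objective looks like:

```python
import numpy as np

# Hypothetical residual values, purely for illustration.
pde_residuals = np.array([0.02, -0.01, 0.005])  # -u'' - sin(pi x) at interior points
bc_values = np.array([0.003, -0.002])           # u at x = 0 and x = 1

# Equal-weight combination of the two mean-squared terms.
loss = np.mean(pde_residuals**2) + np.mean(bc_values**2)
print(loss)
```

In practice, relative weighting of the PDE and boundary terms can matter a great deal for soft-constraint training; if the boundary loss is underweighted, the solution may drift away from u(0) = u(1) = 0.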

Step 6: Evaluate Error and Plot

After solving, the script computes the mean absolute error and saves the solution plot.

mae = np.abs(pred - true).mean()
print(f"Mean absolute error vs exact: {mae:.6e}")

This gives a simple scalar quality check in addition to the saved figures.

What To Notice

  • Soft constraints are more flexible, especially when hard constraints are awkward to encode.
  • Finite differences provide an alternative to automatic differentiation for computing the residual's derivatives.
  • This is a good template when you need explicit control over boundary and residual terms.

Script Snippet

"""01 — 1-D Poisson equation  (soft boundary constraints + FD Laplacian)

Problem
-------
    −u''(x) = sin(πx),   x ∈ [0, 1],   u(0) = u(1) = 0

Analytical solution
-------------------
    u(x) = sin(πx) / π²

Compared to laplace_1d.py this example uses:
* Finite-difference second derivative  (u.d2)
* Soft boundary constraints  (separate boundary tag)
* A tracker that logs the MSE error every 100 epochs
"""

import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
from jno import LearningRateSchedule as lrs

π = jnn.pi
dire = jno.setup(__file__)

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.sin(π * x) / π**2

# ── Network ───────────────────────────────────────────────────────────────────
u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=64,
    num_layers=4,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 10_000, 1e-5))

u = u_net(x)

# ── Constraints ───────────────────────────────────────────────────────────────
pde = -u.d2(x, scheme="finite_difference") - jnn.sin(π * x)
bc = u_net(xb)  # soft: u(0) = u(1) = 0
error = jnn.tracker((u - u_exact).mse, interval=100)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc.mse, error], domain)
history = crux.solve(10_000)
history.plot(f"{dire}/training_history.png")

# ── Plot ──────────────────────────────────────────────────────────────────────
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]

mae = np.abs(pred - true).mean()
print(f"Mean absolute error vs exact: {mae:.6e}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.set_title("Solution")
ax1.plot(xs, pred, label="PINN")
ax1.plot(xs, true, "--", label="exact")
ax1.set_xlabel("x")
ax1.legend()

ax2.set_title("Pointwise |error|")
ax2.plot(xs, np.abs(pred - true))
ax2.set_xlabel("x")

plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")