API Reference

This page is auto-generated from in-source docstrings via mkdocstrings. When something here looks wrong, fix the docstring — not this file.

For conceptual prose (what these objects are and why they exist), see the Getting Started guide and the Concepts page.


jno.core

The top-level solver. Wraps a list of constraint expressions and a domain, compiles them once, then exposes solve() / eval().

jno.core.core

core(constraints: List[Placeholder], mesh: Optional[Tuple[int, ...]] = (1, 1), resume_from: Optional[str] = None, *, domain: Optional[domain] = None)

core solver using traced operations.

Initialize core solver.

The random seed is read from config — JNO_SEED (env), else [jno] seed in .jno.toml / ~/.jno/config.toml, else 42 (via jno.get_seed); it is not a constructor argument.
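The lookup order described above can be sketched in a few lines. This is an illustration only, not jNO's implementation: `resolve_seed` and its arguments are hypothetical, the config files are represented as already-parsed dictionaries, and ranking `.jno.toml` ahead of `~/.jno/config.toml` is an assumption taken from the order in which they are listed. In real code, use jno.get_seed.

```python
from typing import Mapping, Optional

DEFAULT_SEED = 42  # fallback value stated in the docstring above


def resolve_seed(
    env: Mapping[str, str],
    project_cfg: Optional[Mapping[str, Mapping[str, int]]] = None,
    user_cfg: Optional[Mapping[str, Mapping[str, int]]] = None,
) -> int:
    """Hypothetical helper: environment first, then config files, then 42."""
    if "JNO_SEED" in env:
        return int(env["JNO_SEED"])
    # Assumed order: project-level .jno.toml before ~/.jno/config.toml.
    for cfg in (project_cfg, user_cfg):
        if cfg and "seed" in cfg.get("jno", {}):
            return int(cfg["jno"]["seed"])
    return DEFAULT_SEED


print(resolve_seed({"JNO_SEED": "7"}, {"jno": {"seed": 3}}))  # 7
print(resolve_seed({}, {"jno": {"seed": 3}}))                 # 3
print(resolve_seed({}))                                       # 42
```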

PARAMETER DESCRIPTION
constraints

List of constraint expressions defining the problem to solve. Each constraint represents an equation or condition that should be minimized during training (e.g., PDE residuals, boundary conditions, data fitting terms).

TYPE: List[Placeholder]

domain

Optional domain override. When omitted (None), the domain is auto-discovered by walking constraints and collecting the unique Variable._domain reference. Pass domain= explicitly when the constraint tree contains no standard Variable nodes — e.g. FEM/VPINN weak-form assemblies or pure-parametric inverse losses built from jno.domain.from_array.

TYPE: Optional[domain] DEFAULT: None

mesh

Shape of the device mesh for hybrid parallelism as a tuple (batch, model). Controls how computation is distributed across multiple GPUs/TPUs.

  • First dimension (batch): Number of devices for data parallelism. Data is split across these devices, each processes different samples. Parameters are replicated on all devices.

  • Second dimension (model): Number of devices for model parallelism. Model parameters are sharded across these devices. Use when model is too large to fit on a single device.

Examples:

  • (1, 1): No parallelism, single device (default)
  • (2, 1): Pure data parallelism on 2 GPUs (2x throughput)
  • (1, 2): Pure model parallelism on 2 GPUs (fit 2x larger models)
  • (4, 1): Data parallelism on 4 GPUs (4x throughput)
  • (2, 2): Hybrid parallelism on 4 GPUs (2x data, 2x model)
  • (4, 2): Hybrid parallelism on 8 GPUs (4x data, 2x model)

Note: batch * model must equal the total number of available devices.

Recommendations:

  • Model fits on 1 GPU: use (n_devices, 1) for maximum throughput.
  • Model doesn't fit on 1 GPU: use (1, n_devices) for model sharding.
  • Large model + large data: use hybrid, e.g. (2, 2) on 4 GPUs.

Default: (1, 1), automatically expanded to (n_devices, 1) for pure data parallelism when multiple devices are available.

TYPE: Optional[Tuple[int, ...]] DEFAULT: (1, 1)

resume_from

Path to a checkpoint directory written by jno.utils.callbacks.CheckpointCallback. When provided, model parameters, optimizer states, and the RNG key are restored from the latest checkpoint at the start of the next solve() call. Requires the optional orbax-checkpoint package.

TYPE: Optional[str] DEFAULT: None
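The two rules stated for mesh (the product must match the device count, and the default (1, 1) expands to pure data parallelism) can be checked with a small helper. `resolve_mesh` is a hypothetical name used for illustration; it is a sketch of the documented behaviour, not the code jNO runs.

```python
from typing import Tuple


def resolve_mesh(mesh: Tuple[int, int], n_devices: int) -> Tuple[int, int]:
    """Hypothetical check mirroring the documented mesh rules."""
    if mesh == (1, 1) and n_devices > 1:
        # Documented default: expand to (n_devices, 1), pure data parallelism.
        return (n_devices, 1)
    batch, model = mesh
    if batch * model != n_devices:
        raise ValueError(
            f"mesh {mesh} needs {batch * model} devices, found {n_devices}"
        )
    return mesh


print(resolve_mesh((1, 1), 4))  # (4, 1)
print(resolve_mesh((2, 2), 4))  # (2, 2)
```

Calling `resolve_mesh((2, 2), 8)` raises ValueError in this sketch, because 2 * 2 does not equal 8.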

compile

compile(mesh: Optional[Tuple[int, ...]] = (1, 1))

solve

solve(epochs: int = 1000, batchsize: Optional[int] = None, checkpoint_gradients: bool = False, offload_data: bool = False, inner_steps: int = 1, accumulation_steps: int = 1, min_consecutive: Optional[int] = 1, profile: bool = False, callbacks: Optional[List] = None, substeps: list | None = None) -> statistics

Train using per-model optimizers attached via model.optimizer().

Every model used in the constraints must have an optimizer attached before calling solve(). Models can optionally be frozen (model.freeze()) or have LoRA enabled (model.lora(rank, alpha)).

PARAMETER DESCRIPTION
epochs

Number of training epochs.

TYPE: int DEFAULT: 1000

batchsize

Mini-batch size (None for full-batch).

TYPE: Optional[int] DEFAULT: None

checkpoint_gradients

If True, wrap each constraint's forward pass in jax.checkpoint (gradient checkpointing / activation rematerialisation). Trades ~30 % extra compute for significantly lower activation memory. Default False.

TYPE: bool DEFAULT: False

offload_data

If True, keep the full training dataset in host (CPU) memory and stream only the current mini-batch to the device each step. Requires batchsize to be set. Default False.

TYPE: bool DEFAULT: False

inner_steps

Number of gradient steps to fuse into a single jax.lax.fori_loop call, amortising Python dispatch overhead. Must evenly divide epochs. Default 1.

TYPE: int DEFAULT: 1

accumulation_steps

Number of micro-batches whose gradients are averaged before a single optimizer update. The effective batch size becomes batchsize * accumulation_steps while peak activation memory stays proportional to batchsize. Requires batchsize to be set. Default 1.

TYPE: int DEFAULT: 1

min_consecutive

Minimum number of consecutive time steps fed to each constraint evaluation. None means use all available time steps. Default 1.

TYPE: Optional[int] DEFAULT: 1

profile

If True, capture a JAX profiler trace for a short window of steady-state training steps. The trace is written to <logger.path>/traces. Default False.

TYPE: bool DEFAULT: False

callbacks

Optional list of jno.utils.callbacks.Callback instances. on_epoch_end is called after every outer step; on_training_end is called once after the loop finishes.

TYPE: Optional[List] DEFAULT: None

substeps

Optional list of substep specs for alternating optimisation. Each entry is either a plain list of constraint indices [i, j, ...] (1 gradient step) or a tuple ([i, j, ...], n) (n gradient steps sharing the same optimizer state). Each substep runs sequentially per outer epoch and has its own independent optimizer states, so Adam momentum accumulates only for actively trained models.

Example — HyCo alternating::

crux = jno.core(
    [L_pde, beta * L_int_phy, alpha * L_data, beta * L_int_syn],
)
crux.solve(1_500, substeps=[[0, 1], [2, 3]])
# 1500 outer epochs × 2 substeps = 3000 effective gradient steps

TYPE: list | None DEFAULT: None

RETURNS DESCRIPTION
statistics

Training history with .plot() convenience.

TYPE: statistics
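The bookkeeping implied by inner_steps, accumulation_steps, and substeps can be sketched as below. `plan_solve` is a hypothetical helper, not part of jNO. It assumes that the number of fused loop calls equals epochs // inner_steps, and that a substep entry is either a plain list (one gradient step) or a tuple ([i, j, ...], n).

```python
from typing import List, Optional, Sequence, Tuple, Union

Substep = Union[Sequence[int], Tuple[Sequence[int], int]]


def plan_solve(
    epochs: int,
    batchsize: Optional[int] = None,
    inner_steps: int = 1,
    accumulation_steps: int = 1,
    substeps: Optional[List[Substep]] = None,
) -> dict:
    """Hypothetical planner for the documented solve() constraints."""
    if epochs % inner_steps != 0:
        raise ValueError("inner_steps must evenly divide epochs")
    if accumulation_steps > 1 and batchsize is None:
        raise ValueError("accumulation_steps requires batchsize")

    steps_per_epoch = 1
    if substeps:
        steps_per_epoch = 0
        for spec in substeps:
            # A tuple ([i, j, ...], n) means n gradient steps; a list means 1.
            steps_per_epoch += spec[1] if isinstance(spec, tuple) else 1

    effective_batch = None if batchsize is None else batchsize * accumulation_steps
    return {
        "fused_calls": epochs // inner_steps,  # assumption, see lead-in
        "effective_batch": effective_batch,
        "gradient_steps": epochs * steps_per_epoch,
    }


print(plan_solve(1500, substeps=[[0, 1], [2, 3]])["gradient_steps"])  # 3000
print(plan_solve(1000, batchsize=32, accumulation_steps=4)["effective_batch"])  # 128
```

The first call reproduces the count in the HyCo example: 1500 outer epochs with two single-step substeps give 3000 gradient steps.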

eval

eval(operation: Union[List[BinaryOp], BinaryOp], domain: Optional[domain] = None, min_consecutive: Optional[int] = 1, key: Any = None, samples: str = 'auto')

Evaluate an operation (or list of operations) on the current models.

PARAMETER DESCRIPTION
operation

Expression(s) to evaluate.

TYPE: Union[List[BinaryOp], BinaryOp]

domain

Override the stored domain.

TYPE: Optional[domain] DEFAULT: None

min_consecutive

Consecutive-time-step window for time-dependent expressions.

TYPE: Optional[int] DEFAULT: 1

key

Optional PRNG key for stochastic ops.

TYPE: Any DEFAULT: None

samples

How to handle Bayesian models in the dependency graph:

  • "auto" (default) — per expression: if any model it depends on has posterior_samples set, vmap the evaluator over the chain (output shape (n_samples, *original_shape)); otherwise evaluate at the point value.
  • "chain" — force chain evaluation; raises if no Bayesian model appears in any expression's dependency graph.
  • "point" — force point evaluation for every expression (last sample for Bayesian models, trained value for optax models). Use for a quick look without paying the vmap cost.

The default flips to chain automatically because a single last-sample evaluation of a nonlinear function of Bayesian weights is, in general, not a meaningful summary of the posterior (f(mean(θ)) ≠ mean(f(θ))).

TYPE: str DEFAULT: 'auto'
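The inequality f(mean(θ)) ≠ mean(f(θ)) is easy to see on a toy chain. The numbers below are made up and the "model" is a single scalar weight, so this only illustrates why a point evaluation and a chain evaluation can disagree; it does not use jNO.

```python
# Toy posterior chain for a single weight theta (made-up values).
chain = [-1.0, 0.0, 1.0, 2.0]


def f(theta: float) -> float:
    """A nonlinear function of the weight."""
    return theta ** 2


mean_theta = sum(chain) / len(chain)
f_of_mean = f(mean_theta)                          # evaluate at the mean weight
mean_of_f = sum(f(t) for t in chain) / len(chain)  # average over the chain
f_of_last = f(chain[-1])                           # "point" style: last sample only

print(f_of_mean)  # 0.25
print(mean_of_f)  # 1.5
print(f_of_last)  # 4.0
```

All three summaries differ, which is the motivation given above for defaulting to chain evaluation when posterior samples exist.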

sweep

sweep(space: ArchSpace, optimizer: Union[str, type], budget: int, devices: Union[None, int, str, List[int], DeviceConfig] = None) -> statistics

Run architecture and hyperparameter search with optional parallelism.

PARAMETER DESCRIPTION
space

ArchSpace defining the search space (architecture + training params)

TYPE: ArchSpace

optimizer

Nevergrad optimizer name (e.g., "NGOpt", "OnePlusOne", "CMA"), class, or None for exhaustive grid search

TYPE: Union[str, type]

budget

Number of configurations to try (ignored for grid search)

TYPE: int

devices

Device specification for parallel execution:

  • None: auto-detect and use all available devices
  • int: use this many devices
  • str: device type ("gpu", "cpu", "tpu")
  • List[int]: specific device indices to use
  • DeviceConfig: explicit device configuration

TYPE: Union[None, int, str, List[int], DeviceConfig] DEFAULT: None

RETURNS DESCRIPTION
statistics

Training statistics from the best configuration

print_shapes

print_shapes(min_consecutive: Optional[int] = 1)

Print shape-annotated expression trees to stdout.

Can be called any time after compile() or solve() has run. Useful for troubleshooting shape mismatches::

crux = jno.core([pde.mse, ini.mse])
crux.print_shapes()

Domain

jno.domain is the entry point for spatial geometry, mesh management, sampling, and tensor tags.

jno.domain.domain module-attribute

domain = _domain

jno.domain.csg

Lazy Shapely-backed 2D polygon domain with true CSG operators.

The class preserves the jno.domain variable/context contract but does not create a mesh. Point sets are materialized only when variable or sample is called with an explicit sample count.

from_polygons classmethod

from_polygons(polygons: Mapping[str, Sequence[Sequence[float]]], *, time: Optional[Tuple[float, float, int]] = None, compute_mesh_connectivity: bool = False, mesh_size: Optional[float] = None, sampler: Optional[Any] = None, samplers: Optional[Mapping[str, Any]] = None, resampling_strategy: Optional[Any] = None, resampling_strategies: Optional[Mapping[str, Any]] = None) -> 'PolygonDomain'

Create one CSG domain from a mapping of region names to vertices.

from_regions classmethod

from_regions(regions: Mapping[str, BaseGeometry], *, time: Optional[Tuple[float, float, int]] = None, compute_mesh_connectivity: bool = False, mesh_size: Optional[float] = None, sampler: Optional[Any] = None, samplers: Optional[Mapping[str, Any]] = None, resampling_strategy: Optional[Any] = None, resampling_strategies: Optional[Mapping[str, Any]] = None) -> 'PolygonDomain'

Create one CSG domain from named Shapely polygonal regions.

region

region(name: str, where: Any, *, kind: Optional[str] = None) -> 'PolygonDomain'

Define a named sub-region addressable via domain.variable(name) and usable as a FEM boundary-condition location.

Parameters

name: Region tag name (e.g. "inlet", "right_top").

where: One of:

  • a shapely geometry: a boundary LineString/MultiLineString or an area Polygon/MultiPolygon, registered directly;
  • a point predicate f(x, y) -> bool selecting the part of the active boundary where it holds (evaluated per analytic edge-segment midpoint, so it selects whole polygon edges);
  • a str aliasing an already-registered tag.

kind: "boundary" or "interior". None auto-detects from the geometry (area → interior, line → boundary); predicates are boundary-only for now.

Returns self for chaining.
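Because a predicate is evaluated at edge midpoints, it selects whole edges rather than partial ones. The sketch below shows that behaviour on a unit square in plain Python; `select_edges` is a hypothetical stand-in and does not reflect how PolygonDomain stores or splits its boundary.

```python
from typing import Callable, List, Tuple

Point = Tuple[float, float]
Edge = Tuple[Point, Point]


def select_edges(vertices: List[Point], where: Callable[[float, float], bool]) -> List[Edge]:
    """Keep each polygon edge whose midpoint satisfies `where` (illustrative)."""
    edges = []
    n = len(vertices)
    for i in range(n):
        a, b = vertices[i], vertices[(i + 1) % n]
        mid_x, mid_y = (a[0] + b[0]) / 2, (a[1] + b[1]) / 2
        if where(mid_x, mid_y):
            edges.append((a, b))
    return edges


unit_square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]

right = select_edges(unit_square, lambda x, y: x > 0.99)
print(right)  # [((1.0, 0.0), (1.0, 1.0))]

# x > 0.75 does not pick the right quarter of the top edge: the top edge's
# midpoint is at x = 0.5, so the whole edge is rejected.
print(len(select_edges(unit_square, lambda x, y: x > 0.75)))  # 1
```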

add_boundary_segments

add_boundary_segments(tag: str, segments: Sequence[Sequence[Sequence[float]]], *, normal_geometry: Optional[Any] = None) -> 'PolygonDomain'

Register an additional boundary tag from explicit line segments.

This is intended for imported boundary-condition/radiation surfaces that are subsets of component boundaries rather than whole closed polygons.

compute_enclosure_view_factor

compute_enclosure_view_factor(tags: Sequence[str], opaque_tags: Optional[Sequence[str]] = None, medium_tags: Optional[Sequence[str]] = None)

Compute cross-tag polygon boundary view factors for radiative BCs.

All tags must be polygon boundary tags that have already been sampled with normals. The method ray-traces line-of-sight against all known polygon boundary segments, then stores one visibility block and one view-factor block for every source/target tag pair:

v_<source>__<target> and f_<source>__<target>.

PARAMETER DESCRIPTION
tags

Boundary tags participating in the radiation enclosure.

TYPE: Sequence[str]

opaque_tags

Accepted for API compatibility. PolygonDomain uses all known polygon boundaries as opaque blockers, so this argument is currently informational.

TYPE: Optional[Sequence[str]] DEFAULT: None

medium_tags

Region names whose union is the radiating medium. Normals are oriented to point into this medium before computing view factors. If omitted and regions named Gas or Air exist, those are used automatically.

TYPE: Optional[Sequence[str]] DEFAULT: None

draw_candidates

draw_candidates(tag: str)

Return (points, normals_or_None) candidate pool for resampling.

Generates fresh candidate points from the polygon geometry on each call (10× the currently-sampled count, min 1000) so that resampling strategies can explore the full domain rather than being confined to the initial sample.

stack classmethod

stack(*batched_domains: 'PolygonDomain', n_interior: int = 256, n_boundary: int = 64) -> 'domain'

Stack multiple PolygonDomains into one batched domain for multi-geometry training.

Use n * dom to set how many independent samplings of a geometry appear in the training batch before passing it here. Points are drawn by rejection sampling so n_interior / n_boundary are exact regardless of mesh size.

PARAMETER DESCRIPTION
*batched_domains

PolygonDomain instances, typically n * dom.

TYPE: 'PolygonDomain' DEFAULT: ()

n_interior

Interior collocation points per geometry sample.

TYPE: int DEFAULT: 256

n_boundary

Boundary points per geometry sample.

TYPE: int DEFAULT: 64

RETURNS DESCRIPTION
'domain'

A jno.domain with per-sample point pools: variable("interior") yields (B_total, 1, n_interior, 1) per coordinate, one independent sampling per geometry instance.

Example::

from shapely.geometry import box, Point
import jno

d1 = jno.domain(box(0, 0, 1, 1))
d2 = jno.domain(Point(0.5, 0.5).buffer(0.5))
dom = jno.domain.stack(100 * d1, 100 * d2, n_interior=512, n_boundary=128)
x, y = dom.variable("interior")   # (200, 1, 512, 1) each
xb, yb = dom.variable("boundary") # (200, 1, 128, 1) each
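
Rejection sampling is what makes the point counts exact: candidates are drawn until the requested number has been accepted. A minimal stdlib sketch, assuming a disc described by an inside-test and a bounding box (`sample_interior` is a hypothetical helper, not the sampler jNO uses):

```python
import random
from typing import Callable, List, Tuple


def sample_interior(
    inside: Callable[[float, float], bool],
    bounds: Tuple[float, float, float, float],
    n: int,
    seed: int = 0,
) -> List[Tuple[float, float]]:
    """Draw exactly n points inside a region by rejection sampling."""
    rng = random.Random(seed)
    xmin, ymin, xmax, ymax = bounds
    points = []
    while len(points) < n:
        x, y = rng.uniform(xmin, xmax), rng.uniform(ymin, ymax)
        if inside(x, y):  # reject candidates outside the region
            points.append((x, y))
    return points


def in_disc(x: float, y: float) -> bool:
    """Disc of radius 0.5 centred at (0.5, 0.5)."""
    return (x - 0.5) ** 2 + (y - 0.5) ** 2 <= 0.25


pts = sample_interior(in_disc, (0.0, 0.0, 1.0, 1.0), n=512)
print(len(pts))  # 512
```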

build_mesh

build_mesh(mesh_size: float = 0.1, *, algorithm: int = 6, region_mesh_sizes: Optional[Mapping[str, float]] = None, sizes: Optional[Mapping[str, float]] = None, interpolate: bool = True) -> 'PolygonDomain'

Generate a gmsh mesh from the active Shapely CSG geometry.

After this call, self.mesh_connectivity and self._boundary_registry are populated and downstream operations that need a mesh (expr.integrate(), scheme="finite_difference" derivatives) become available. The lazy sampling path is untouched: previously materialized collocation samples in self.context survive and automatic-differentiation derivatives keep using them.

PARAMETER DESCRIPTION
mesh_size

Default target element size for points that don't fall on any per-region boundary override.

TYPE: float DEFAULT: 0.1

algorithm

pygmsh algorithm (passed through to generate_mesh).

TYPE: int DEFAULT: 6

region_mesh_sizes

Per-source-region mesh size overrides keyed by the names used to construct the source regions (e.g. the name argument of jno.domain or keys of from_polygons / from_regions).

TYPE: Optional[Mapping[str, float]] DEFAULT: None

interpolate

Controls how region_mesh_sizes is enforced inside each region (default True).

  • True (smooth interpolation, original behaviour) — the per-region size is set only on gmsh vertex points that lie on the region boundary; gmsh then smoothly interpolates the size field through the domain. For small or irregular regions the interior can end up substantially coarser than the requested region_mesh_sizes[name].
  • False (uniform refinement) — a per-region gmsh Box size field enforces the requested h_inner uniformly inside each region's axis-aligned bounding box. This produces dense homogeneous refinement (matches the requested density to within a few percent) at the cost of slight over-refinement for non-rectangular polygons (the Box field covers the bounding box, not the polygon shape). Use this when you actually need many FD-stencil nodes inside a small region.

TYPE: bool DEFAULT: True

Re-calling build_mesh re-meshes from scratch and clears the integral weight cache.
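With interpolate=False the refinement box covers a region's bounding box rather than its polygon. A rough way to gauge how much extra area that box covers is to compare the two areas. The helpers below are illustrative and independent of jNO and gmsh.

```python
def polygon_area(vertices):
    """Shoelace formula for a simple polygon given as (x, y) vertices."""
    total = 0.0
    n = len(vertices)
    for i in range(n):
        x0, y0 = vertices[i]
        x1, y1 = vertices[(i + 1) % n]
        total += x0 * y1 - x1 * y0
    return abs(total) / 2.0


def bbox_area(vertices):
    """Area of the axis-aligned bounding box of the vertices."""
    xs = [x for x, _ in vertices]
    ys = [y for _, y in vertices]
    return (max(xs) - min(xs)) * (max(ys) - min(ys))


triangle = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ratio = bbox_area(triangle) / polygon_area(triangle)
print(ratio)  # 2.0
```

A ratio of 1.0 means the box matches the region exactly (an axis-aligned rectangle); for this right triangle the box covers twice the region's area.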


Neural-network controls

jno.nn.wrap lifts a plain Equinox / foundax module into a jNO Model so it can participate in the trace and accept per-model optimisers, masks, LoRA, freezing, and so on.

jno.architectures.models.nn

Neural network wrapping class for integrating modules into the jno pipeline.

Use nn.wrap(module) (or the shorthand nn(module)) to wrap an Equinox, Flax Linen, or Flax NNX module into a Model that works with jno.core.

Architecture factories have moved to the foundax package::

import foundax
model = jno.nn.wrap(foundax.mlp(2, hidden_dims=64, key=key))

wrap classmethod

wrap(module, space: None = ..., name: str = ..., weight_path: str = ...) -> Model
wrap(module, space: ArchSpace, name: str = ..., weight_path: str = ...) -> TunableModule
wrap(module: Any, space: ArchSpace = None, name: str = '', weight_path: str = None) -> Union[Model, TunableModule]

Wrap a module for use in the jno pipeline.

This is the primary method for integrating custom architectures into the jno framework. It handles both standard wrapping and architecture search scenarios.

PARAMETER DESCRIPTION
module

An eqx.Module instance (for standard use), a legacy Flax nn.Module, or a class (for architecture search).

TYPE: Any

space

Optional ArchSpace for hyperparameter tuning. When provided, module must be a class, not an instance.

TYPE: ArchSpace DEFAULT: None

name

Optional display name.

TYPE: str DEFAULT: ''

weight_path

Optional path to pretrained weights.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
Model

Standard wrapped module (when space=None).

TYPE: Union[Model, TunableModule]

TunableModule

Tunable module for architecture search (when space provided).

TYPE: Union[Model, TunableModule]

RAISES DESCRIPTION
ValueError

If space is provided but module is an instance.

Example

# Wrap a custom Equinox module
import foundax
model = nn.wrap(foundax.mlp(2, output_dim=1, key=jax.random.PRNGKey(0)))

jno.trace.Model

Model(module: Any, name: str = '', weight_path: str | None = None)

Wrapper for user-defined Equinox models.

Allows using any Equinox module within the jNO tracing system. The module is initialized lazily when the input dimension is known.

Example - Direct call style (module takes separate arguments):

class MLP(eqx.Module):
    ...
    def __call__(self, x, y, *, key=None):
        z = jnp.concat([x, y], axis=-1)
        ...
        return z

uv_net = jno.nn.wrap(MLP(..., key=key))
u = uv_net(x, y)[..., 0]

Create a Model wrapper.

PARAMETER DESCRIPTION
module

An Equinox module instance (already constructed), or a callable / Flax nn.Module for backward compatibility.

TYPE: Any

optimizer

optimizer(opt_fn: Any)

Attach an optimizer to this model.

When preceded by mask(param_mask), the optimizer applies only to matching parameters; everything else uses the global optimizer (set via a bare optimizer() call)::

NN.mask(mask_decoder).optimizer(optax.adam)  # decoder group
NN.mask(mask_encoder).optimizer(optax.sgd)   # encoder group
NN.optimizer(optax.adam)                   # global fallback

Bake the learning rate into the optax optimizer (e.g. optax.adam(1e-3)); use scale() to multiply it, e.g. with a dlrs(...) schedule for loss-adaptive learning-rate scaling. mask(...) is one-shot, so to scale a masked group call mask(...) again before scale(...)::

NN.mask(mask_decoder).optimizer(optax.adam(1e-3))
NN.mask(mask_decoder).scale(my_schedule)

A bare/global call (not preceded by mask(...)) replaces any previously configured parameter groups.

PARAMETER DESCRIPTION
opt_fn

An optax optimizer factory, e.g. optax.adam, or an already-constructed transform.

TYPE: Any

freeze

freeze()

Mark this model as frozen (not trained).

When preceded by mask(...), only the currently selected parameters are frozen and everything else remains trainable::

NN.mask(param_mask).freeze()         # True leaves frozen, False leaves trainable
NN.freeze()                          # whole model frozen

Order matters: mask() must be called before freeze().

unfreeze

unfreeze()

Unfreeze this model so it is trained normally.

mask

mask(param_mask=None)

Set the current mask scope using an explicit boolean pytree mask.

param_mask must mirror the parameter tree structure and contain boolean leaves where True selects leaves in the masked scope.

This scope is consumed by grouped optimizer/lr calls and by mask(...).freeze(). It is also read by u.grad(net.mask(...)) to restrict the Jacobian to only the selected parameters.

Example::

import equinox as eqx, jax

all_false = jax.tree_util.tree_map(lambda _: False, model.module)
param_mask = eqx.tree_at(
    lambda m: (m.layers[0].weight, m.layers[0].bias),
    all_false, (True, True),
)
model.mask(param_mask).optimizer(optax.adam(1e-3))
J = crux.eval([u.grad(model.mask(param_mask))])[0]  # (N, P_selected)

lora

lora(rank: int = 4, alpha: float = 1.0, *, target: str | None = None, wrapper: type[LoRAWrapper] | Sequence[type[LoRAWrapper]] | None = None, specs: list[dict] | None = None)

Enable LoRA fine-tuning for this model.

Two calling conventions:

  1. Uniform::

    NN.lora(rank=8, alpha=16)
    NN.lora(rank=4, wrapper=MyConvAdapter)         # custom adapter
    NN.lora(rank=4, wrapper=[LoRALinear, MyConv])  # tried in order

  2. Per-target: different rank/alpha/adapter per layer group::

    NN.lora(specs=[
        {"target": "encoder", "rank": 4, "alpha": 1.0},
        {"target": "conv", "rank": 8, "alpha": 2.0, "wrapper": MyConvAdapter},
    ])

Each target is a regex matched against the pytree path. The first matching spec wins.

By default only the low-rank adapters are trained; base weights are frozen. Layers that are NOT wrapped by LoRA remain fully trainable. Call freeze() before lora() to also freeze any parameters outside LoRA-wrapped layers::

NN.freeze().lora(rank=8, alpha=16)

Use mask(M) to restrict which layers receive LoRA adapters::

NN.mask(M).lora(rank=8, alpha=16)  # only M-selected layers are wrapped

PARAMETER DESCRIPTION
rank

LoRA rank (uniform mode).

TYPE: int DEFAULT: 4

alpha

LoRA scaling factor (uniform mode).

TYPE: float DEFAULT: 1.0

target

Regex to restrict which layers get LoRA adapters (uniform mode only). Layers whose pytree path does not match are left completely untouched. Use specs= for per-group targeting.

TYPE: str | None DEFAULT: None

wrapper

Adapter class or list of classes to try in order. Defaults to (LoRALinear, LoRAConv) — wraps both linear and conv layers. Pass a single class or list to override.

TYPE: type[LoRAWrapper] | Sequence[type[LoRAWrapper]] | None DEFAULT: None

specs

Per-target specs (per-target mode). Each dict has keys target (str regex), rank (int), alpha (float), and optionally wrapper.

TYPE: list[dict] | None DEFAULT: None
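The "first matching spec wins" rule for per-target mode can be sketched with the standard re module. This is an illustration under stated assumptions: pytree paths are shown as dotted strings, and matching uses re.search (a match anywhere in the path). Neither detail is confirmed by the docstring, and `match_spec` is a hypothetical helper.

```python
import re
from typing import List, Optional

specs = [
    {"target": "encoder", "rank": 4, "alpha": 1.0},
    {"target": "conv", "rank": 8, "alpha": 2.0},
]


def match_spec(path: str, specs: List[dict]) -> Optional[dict]:
    """Return the first spec whose regex matches the path, else None."""
    for spec in specs:
        if re.search(spec["target"], path):  # assumption: search, not fullmatch
            return spec
    return None


print(match_spec("encoder.conv1.weight", specs)["rank"])  # 4
print(match_spec("decoder.conv2.weight", specs)["rank"])  # 8
print(match_spec("decoder.linear.weight", specs))         # None
```

The first path matches both patterns, and the "encoder" spec is used because it is listed first. Ordering specs from most to least specific avoids surprises.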

dtype

dtype(dtype: Any) -> 'Model'

Set this model's working dtype (parameters and compute).

Casts all floating-point parameters to dtype and — at the forward seam — casts the model's inputs to match, so the network actually computes in dtype rather than promoting back to float32. The cast is symmetric: it lowers (float32 → bfloat16) and promotes (load a bfloat16 checkpoint, then .dtype(jnp.float32)), and applies to both training and inference. Integer arrays (e.g. indices) are left unchanged.

This is the model-precision knob. Data precision (float32 vs float64) is JAX's jax_enable_x64 flag — not a jNO setting. Enable it before building models/domains (JAX_ENABLE_X64=1 or jax.config.update("jax_enable_x64", True)).

PARAMETER DESCRIPTION
dtype

A JAX floating dtype object, e.g. jnp.bfloat16, jnp.float16, jnp.float32 or jnp.float64.

TYPE: Any

Caveats
  • bfloat16 compute degrades autodiff derivatives (.laplacian / .hessian) — keep derivative-critical (PINN) models in float32 and opt only data-loss / operator backbones into bf16.
  • bfloat16 parameters mean the optimizer update also runs in bfloat16, which can stall on very small updates.

Example::

backbone.dtype(jnp.bfloat16)   # real bf16 compute for this model
pinn_net.dtype(jnp.float32)    # keep its derivatives full precision
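
The stalling caveat can be reproduced without JAX by rounding values to bfloat16 precision by hand. The sketch below simulates bfloat16 by keeping the top 16 bits of a float32 with round-to-nearest-even; `to_bfloat16` is a hypothetical helper written for this illustration (NaN and infinity are not handled), not jnp.bfloat16 itself.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits = (bits + bias) & 0xFFFF0000   # keep sign, exponent, 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


weight, update = 1.0, 1e-3

# Parameter update carried out at bfloat16 precision.
bf16_result = to_bfloat16(to_bfloat16(weight) - to_bfloat16(update))

# Same update at float32 precision.
f32_result = struct.unpack(">f", struct.pack(">f", weight - update))[0]

print(bf16_result)       # 1.0
print(f32_result < 1.0)  # True
```

Near 1.0 the spacing between bfloat16 values is about 0.004, so an update of 0.001 is rounded away and the weight does not move, while float32 registers the change.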

constrain

constrain(transform: Callable) -> 'Model'

Apply a paramax reparameterization to trainable parameter leaves.

Parameters are stored in their unconstrained form and transformed by transform before every forward pass via paramax.unwrap(), which jno's training loop calls automatically.

When preceded by mask(...), only leaves where the mask is True are wrapped — all other leaves remain unconstrained::

k_net.mask(output_mask).constrain(jax.nn.softplus)  # output layer only
k_net.constrain(jax.nn.softplus)                    # all parameters

PARAMETER DESCRIPTION
transform

A jit-compatible callable (e.g. jax.nn.softplus, jax.nn.sigmoid).

TYPE: Callable

RETURNS DESCRIPTION
'Model'

self (for chaining)
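The idea behind a reparameterization is that the optimizer updates an unconstrained value while the forward pass only ever sees the transformed one. The stdlib sketch below shows this for softplus, which maps any real number to a positive one. It illustrates the principle only and does not use paramax or jNO.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Unconstrained storage: the optimizer may push these anywhere on the real line.
raw_values = [-5.0, 0.0, 3.0]

# Values the forward pass would see after the transform.
constrained = [softplus(r) for r in raw_values]

print(all(c > 0.0 for c in constrained))  # True
```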

initialize

initialize(weights: Any, *, key: Any = None) -> 'Model'

Load pretrained weights into this model at init time.

Accepted weights inputs:

  • str / Path: load from a checkpoint path. Supports Equinox .eqx files and Orbax checkpoint directories (optionally "<path>::<model_key>").
  • Pytree object: copy array leaves directly from the provided tree.
  • Callable initializer: apply a JAX initializer function to every floating-point array leaf at compile time.

Examples:

net.initialize("./weights.eqx")
net.initialize("./runs/ckpts/2000::1")
net.initialize(other_model.module)

p = jno.np.parameter((1,), key=jax.random.PRNGKey(0))
p.initialize(jax.nn.initializers.ones)

PARAMETER DESCRIPTION
weights

File path / pytree / callable initializer.

TYPE: Any

key

Optional PRNG key used when weights is callable.

TYPE: Any DEFAULT: None

RETURNS DESCRIPTION
'Model'

self (for chaining).

tune

tune(*, freeze: list | None = None, lora: list | None = None, optimizer: list | None = None, lr: list | None = None, dtype: list | None = None) -> 'Model'

Declare per-model tunable options for hyperparameter sweeps.

Each argument accepts a list of candidate values. During a sweep the tuner searches over all combinations.

PARAMETER DESCRIPTION
freeze

List of bool, e.g. [True, False].

TYPE: list | None DEFAULT: None

lora

List of (rank, alpha) tuples or None values, e.g. [(4, 1.0), (8, 1.0), None].

TYPE: list | None DEFAULT: None

optimizer

List of optax factories, e.g. [optax.adam].

TYPE: list | None DEFAULT: None

lr

List of LearningRateSchedule objects.

TYPE: list | None DEFAULT: None

dtype

List of dtypes, e.g. [jnp.float32, jnp.bfloat16].

TYPE: list | None DEFAULT: None

RETURNS DESCRIPTION
'Model'

self (for chaining).

Example::

backbone = nn.poseidon(...)
backbone.initialize("weights.msgpack")
backbone.tune(
    freeze=[True, False],
    lora=[(4, 1.0), None],
    optimizer=[optax.adam],
    lr=[lrs.constant(1e-4), lrs.constant(1e-5)],
)
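
The size of the search space declared by tune() grows as the product of the list lengths. A sketch of the enumeration with itertools.product, using plain floats as stand-ins for learning-rate schedule objects (how jNO's tuner actually samples or enumerates these combinations is not shown here):

```python
from itertools import product

# Candidate values per option; floats stand in for schedule objects.
options = {
    "freeze": [True, False],
    "lora": [(4, 1.0), None],
    "lr": [1e-4, 1e-5],
}

names = list(options)
combos = [dict(zip(names, values)) for values in product(*options.values())]

print(len(combos))  # 8
print(combos[0])    # {'freeze': True, 'lora': (4, 1.0), 'lr': 0.0001}
```

Adding one more two-valued option doubles the count, which is worth keeping in mind when choosing a sweep budget.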

Symbolic math (jno.np)

A NumPy-compatible namespace that returns traced placeholders instead of concrete arrays. Use it inside any expression that you intend to feed into jno.core(...).

jno.jnp_ops

sin module-attribute

sin = _unary(jnp.sin)

cos module-attribute

cos = _unary(jnp.cos)

exp module-attribute

exp = _unary(jnp.exp)

log module-attribute

log = _unary(jnp.log)

sqrt module-attribute

sqrt = _unary(jnp.sqrt)

abs module-attribute

abs = _unary(jnp.abs)

pi module-attribute

pi = jnp.pi

concat

concat(items, axis: int = -1) -> FunctionCall

Concatenate placeholders along an axis (always axis=-1 at eval time).

grad

grad(target: Placeholder, variable: Variable, scheme: str = 'automatic_differentiation') -> Jacobian

Compute the gradient of target with respect to variable.

Implemented as a single-variable Jacobian.

Prefer the method-style shorthand on the target expression::

u_x  = u.d(x)          # ∂u/∂x
u_xx = u.d(x).d(x)     # ∂²u/∂x² (chainable)

PARAMETER DESCRIPTION
target

Expression to differentiate

TYPE: Placeholder

variable

Variable to differentiate with respect to

TYPE: Variable

scheme

'automatic_differentiation' (default) or 'finite_difference'

TYPE: str DEFAULT: 'automatic_differentiation'

RETURNS DESCRIPTION
Jacobian

Jacobian placeholder representing ∂target/∂variable

Example

u_x = jno.np.grad(u(x, y), x) # ∂u/∂x


Function helpers and loss balancers (jno.fn)

jno.fn provides PDE-named helpers (heat, wave, burgers_1d, ...), loss reductions (mse, mae, rmse, huber, log_cosh, ...), and the adaptive loss balancers under jno.fn.adaptive.*.

jno.fn

Functional helpers for traced expressions: jno.fn.sin(u), jno.fn.mse(pred, target).

This module is callable — jno.fn(my_func, [arg1, arg2]) wraps an arbitrary function into the tracing graph (replaces jno.np.function).

Sections
  • Math: sin, cos, exp, log, sqrt, abs, …
  • Losses: mse, mae, rmse, huber, log_cosh, relative_l2
  • PDEs: poisson, heat, wave, burgers_1d, navier_stokes_incompressible_2d, …

Examples

import jno
pde = jno.fn.sin(u) + jno.fn.exp(-x)
loss = jno.fn.mse(pred, target)
custom = jno.fn(lambda a, b: a ** 2 + b, [u, v])

_module_call

_module_call(fn: Callable, args: list = [], name: str = '', reduces_axis: Optional[int] = None) -> FunctionCall

Wrap an arbitrary function into the tracing graph.

PARAMETER DESCRIPTION
fn

Any callable (*args) -> array.

TYPE: Callable

args

Traced placeholder arguments.

TYPE: list DEFAULT: []

name

Optional display name in the expression tree.

TYPE: str DEFAULT: ''

reduces_axis

If the function reduces an axis, specify it here.

TYPE: Optional[int] DEFAULT: None

RETURNS DESCRIPTION
FunctionCall

FunctionCall placeholder node.

Example::

custom = jno.fn(lambda a, b: a ** 2 + b, [u, v], name="my_op")

Training history

solve() returns a statistics object. The most common operations:

jno.utils.statistics.statistics

statistics(logs)

Training history returned by core.solve().

Access patterns::

history = crux.solve(...)
history.total_loss            # final scalar total loss
history.total_loss_history    # 1-D array of total loss per epoch
history.training_logs         # list of per-solve()-call dicts
history.training_logs[-1]["total_loss"]   # full array from last call
history.plot("./runs/loss.png")           # quick visualization
history.summary()                          # printed run summary

total_loss property

total_loss

Final scalar total loss (last value across all solve() calls).

Returns None when no training has been recorded.

total_loss_history property

total_loss_history: ndarray

1-D array of total loss concatenated across all solve() calls.
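The relationship between training_logs, total_loss_history, and total_loss can be sketched with plain lists. The "total_loss" key comes from the access patterns shown above; everything else (list values instead of arrays, the made-up numbers) is an assumption for illustration.

```python
# One dict per solve() call; lists stand in for the per-call loss arrays.
training_logs = [
    {"total_loss": [1.0, 0.5, 0.25]},  # first solve() call
    {"total_loss": [0.2, 0.1]},        # second solve() call
]

# Concatenate across calls, as total_loss_history is described to do.
total_loss_history = [v for log in training_logs for v in log["total_loss"]]

# Last value overall, or None when nothing was recorded.
total_loss = total_loss_history[-1] if total_loss_history else None

print(total_loss_history)  # [1.0, 0.5, 0.25, 0.2, 0.1]
print(total_loss)          # 0.1
```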

plot

plot(path: str = None) -> statistics

Plot training statistics from all solve() calls.

Creates a multi-panel figure showing:

  • Constraint losses over time (individual lines; total added when >1 constraint)
  • Tracker values over time (if any trackers were defined)
  • Step time in milliseconds (derived from log timestamps)

PARAMETER DESCRIPTION
path

Path to save the figure (e.g. "./runs/training.png").

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
statistics

self (for chaining)

save

save(path: str)

load classmethod

load(filepath: str) -> statistics

Load a trained core model from a file.

Restores all trained parameters, operations, domain, and history.

PARAMETER DESCRIPTION
filepath

Path to saved model file

TYPE: str

RETURNS DESCRIPTION
statistics

core instance with trained parameters

Example

sol = core.load("trained_model.pkl")


Differential and integral operators

These provide the residuals you put inside constraints (u.laplacian(x, y), u.d(x), (grad_u * n).integrate()).

Scheme strings

Every differential operator (.d, .diff, .d2, .dd, .laplacian, .hessian) accepts a scheme= kwarg that selects the backend:

Scheme | Backend
"automatic_differentiation" (default) | global default; see jno.setup(diff_type=..., hessian_type=...)
"automatic_differentiation:forward" | first-order via jax.jacfwd
"automatic_differentiation:reverse" | first-order via jax.jacrev
"automatic_differentiation:fwd-over-rev" | second-order jacfwd(jacrev(f)) (= historical jax.hessian)
"automatic_differentiation:fwd-over-fwd" | second-order jacfwd(jacfwd(f))
"automatic_differentiation:rev-over-rev" | second-order jacrev(jacrev(f))
"automatic_differentiation:rev-over-fwd" | second-order jacrev(jacfwd(f))
"finite_difference" | central-difference stencils on mesh (with :lsq / :uniform / :inverse_distance / :cotangent sub-schemes)

Forward-mode is typically cheaper when the input dim (≤ 3 spatial dims for PINNs) is ≤ the output dim; reverse-mode is cheaper for scalar losses with many inputs. Set the project-wide default once via .jno.toml:

[jno]
diff_type    = "forward"        # default for first-order operators
hessian_type = "fwd-over-rev"   # default for second-order operators

or per script via jno.setup(__file__, diff_type="forward"). Per-call scheme= always overrides the default.
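
The AD suffixes in the table map onto plain JAX transforms. The sketch below uses only jax.jacfwd and jax.jacrev on a hypothetical scalar field u; it does not call jNO, and it only shows that all compositions agree numerically while differing in cost.

```python
import jax
import jax.numpy as jnp

def u(xy):
    # Hypothetical scalar field u(x, y) = sin(x) * cos(y); its Laplacian is -2 * u.
    return jnp.sin(xy[0]) * jnp.cos(xy[1])

p = jnp.array([0.3, 0.7])

# First order: forward and reverse mode return the same gradient.
grad_fwd = jax.jacfwd(u)(p)
grad_rev = jax.jacrev(u)(p)

# Second order: the four compositions named by the scheme suffixes.
hessians = {
    "fwd-over-rev": jax.jacfwd(jax.jacrev(u))(p),
    "fwd-over-fwd": jax.jacfwd(jax.jacfwd(u))(p),
    "rev-over-rev": jax.jacrev(jax.jacrev(u))(p),
    "rev-over-fwd": jax.jacrev(jax.jacfwd(u))(p),
}

# The Laplacian is the trace of the Hessian, whichever composition built it.
laplacians = {name: jnp.trace(h) for name, h in hessians.items()}
```

Because the results coincide, the choice of suffix is a performance decision rather than a correctness one.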

jno.differential_operators.DifferentialOperators

Static collection of mesh-based FD operators (1-D, 2-D, 3-D).

All public methods are static — the class is used purely as a namespace. See the module docstring for full method descriptions.

compute_fd_gradient_1d_simple staticmethod

compute_fd_gradient_1d_simple(u_values: ndarray, points: ndarray, lines: ndarray, method: str = 'area_weighted') -> jnp.ndarray

Gradient on a 1-D line mesh.

PARAMETER DESCRIPTION
u_values

Function values at mesh points, shape (N,).

TYPE: ndarray

points

Mesh point coordinates, shape (N, 1) or (N,).

TYPE: ndarray

lines

Line element connectivity, shape (M, 2).

TYPE: ndarray

method

Weighting strategy — one of "area_weighted" (default), "uniform", "inverse_distance", "least_squares".

TYPE: str DEFAULT: 'area_weighted'

RETURNS DESCRIPTION
ndarray

du/dx at each point, shape (N,).
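
As a rough illustration of what a length-weighted nodal gradient on a line mesh can look like, the NumPy sketch below averages per-element slopes onto the nodes. The function name and the exact weighting are assumptions; the library's "area_weighted" method may differ in detail.

```python
import numpy as np

def fd_gradient_1d_sketch(u_values, points, lines):
    """Length-weighted average of element slopes at each node (illustrative)."""
    x = np.asarray(points, dtype=float).reshape(-1)
    u = np.asarray(u_values, dtype=float)
    a, b = lines[:, 0], lines[:, 1]
    length = np.abs(x[b] - x[a])             # element lengths, shape (M,)
    slope = (u[b] - u[a]) / (x[b] - x[a])    # constant slope per element
    num = np.zeros_like(x)
    den = np.zeros_like(x)
    for node in (a, b):                      # scatter-add to both endpoints
        np.add.at(num, node, slope * length)
        np.add.at(den, node, length)
    return num / den

# Non-uniform 1-D mesh with a linear field u = 2x + 1.
points = np.array([0.0, 0.1, 0.4, 0.5, 1.0])
lines = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
grad = fd_gradient_1d_sketch(2.0 * points + 1.0, points, lines)
```

For a linear field every element slope equals 2, so any weighted average reproduces it at all nodes, including the two boundary nodes.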

compute_fd_laplacian_1d_simple staticmethod

compute_fd_laplacian_1d_simple(u_values: ndarray, points: ndarray, lines: ndarray, method: str = 'gradient_of_gradient') -> jnp.ndarray

Laplacian on a 1-D line mesh.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 1) or (N,).

TYPE: ndarray

lines

Line connectivity, shape (M, 2).

TYPE: ndarray

method

"gradient_of_gradient" (only option in 1-D).

TYPE: str DEFAULT: 'gradient_of_gradient'

RETURNS DESCRIPTION
ndarray

d²u/dx² at each point, shape (N,).
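
The name "gradient_of_gradient" suggests applying a nodal gradient twice. The sketch below does exactly that with a hypothetical length-weighted nodal gradient; it is an assumption about the approach, not a copy of the library code.

```python
import numpy as np

def nodal_gradient_1d(u, x, lines):
    # Length-weighted average of element slopes at each node (illustrative).
    a, b = lines[:, 0], lines[:, 1]
    length = np.abs(x[b] - x[a])
    slope = (u[b] - u[a]) / (x[b] - x[a])
    num, den = np.zeros_like(x), np.zeros_like(x)
    for node in (a, b):
        np.add.at(num, node, slope * length)
        np.add.at(den, node, length)
    return num / den

def laplacian_1d_sketch(u, x, lines):
    # d2u/dx2 approximated as the nodal gradient of the nodal gradient.
    return nodal_gradient_1d(nodal_gradient_1d(u, x, lines), x, lines)

x = np.linspace(0.0, 3.0, 7)                               # uniform mesh, h = 0.5
lines = np.stack([np.arange(6), np.arange(1, 7)], axis=1)  # (M, 2) connectivity
lap = laplacian_1d_sketch(x**2, x, lines)
```

For u = x² the nodes at least two elements away from either end recover the exact value 2. Nodes at or next to the boundary are less accurate, because the one-sided boundary slopes feed into them.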

compute_fd_hessian_1d_simple staticmethod

compute_fd_hessian_1d_simple(u_values: ndarray, points: ndarray, lines: ndarray, var_dims: list | None = None) -> jnp.ndarray

Hessian (= d²u/dx²) on a 1-D line mesh.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 1) or (N,).

TYPE: ndarray

lines

Line connectivity, shape (M, 2).

TYPE: ndarray

var_dims

Optional [(i, vi_dim, j, vj_dim), …].

TYPE: list | None DEFAULT: None

RETURNS DESCRIPTION
ndarray

Hessian, shape (N, 1, 1).

compute_fd_gradient_2d_simple staticmethod

compute_fd_gradient_2d_simple(u_values: ndarray, points: ndarray, triangles: ndarray, dim: int, method: str = 'area_weighted') -> jnp.ndarray

Gradient on a 2-D triangular mesh.

PARAMETER DESCRIPTION
u_values

Function values at mesh points, shape (N,).

TYPE: ndarray

points

Mesh point coordinates, shape (N, 2).

TYPE: ndarray

triangles

Triangle connectivity, shape (M, 3).

TYPE: ndarray

dim

Spatial dimension to differentiate (0 = x, 1 = y).

TYPE: int

method

"area_weighted" (default), "uniform", "inverse_distance", "least_squares".

TYPE: str DEFAULT: 'area_weighted'

RETURNS DESCRIPTION
ndarray

∂u/∂x_dim at each point, shape (N,).

compute_gradient_2d_lsq staticmethod

compute_gradient_2d_lsq(u_values: ndarray, points: ndarray, triangles: ndarray, dim: int) -> jnp.ndarray

Least-squares gradient on a 2-D triangular mesh.

For each node i the gradient is estimated by solving a 2×2 area-weighted least-squares problem built from incident triangle centroids. The 2×2 system is solved via Cramer's rule so the entire computation uses only JAX scatter-add operations with no per-node Python loops.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 2).

TYPE: ndarray

triangles

Triangle connectivity, shape (M, 3).

TYPE: ndarray

dim

0 → ∂u/∂x, 1 → ∂u/∂y.

TYPE: int

RETURNS DESCRIPTION
ndarray

Gradient component at each node, shape (N,).
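
The description above (area-weighted least squares over incident triangle centroids, solved by Cramer's rule with scatter-adds) can be sketched in NumPy as follows. The function name, the use of the mean vertex value as the centroid value, and the return of both components at once are assumptions made for illustration.

```python
import numpy as np

def lsq_gradient_2d_sketch(u, points, triangles):
    p = points[triangles]                                  # (M, 3, 2)
    e1, e2 = p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]
    area = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    centroid = p.mean(axis=1)                              # (M, 2)
    u_c = u[triangles].mean(axis=1)                        # value at centroid
    n = len(points)
    a11, a12, a22, b1, b2 = (np.zeros(n) for _ in range(5))
    for k in range(3):                                     # loop over corners, not nodes
        idx = triangles[:, k]
        d = centroid - points[idx]                         # node -> centroid offset
        du = u_c - u[idx]
        np.add.at(a11, idx, area * d[:, 0] ** 2)
        np.add.at(a12, idx, area * d[:, 0] * d[:, 1])
        np.add.at(a22, idx, area * d[:, 1] ** 2)
        np.add.at(b1, idx, area * d[:, 0] * du)
        np.add.at(b2, idx, area * d[:, 1] * du)
    det = a11 * a22 - a12 ** 2                             # Cramer's rule, 2x2
    gx = (b1 * a22 - a12 * b2) / det
    gy = (a11 * b2 - a12 * b1) / det
    return np.stack([gx, gy], axis=1)

# Unit square split into four triangles around its centre.
points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.5, 0.5]])
triangles = np.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
u = 3.0 * points[:, 0] - 2.0 * points[:, 1] + 1.0
grad = lsq_gradient_2d_sketch(u, points, triangles)
```

A linear field is reproduced exactly, so every row of grad equals (3, -2). Note that a node touching only one triangle would give a singular 2×2 system in this sketch.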

compute_fd_laplacian_2d_simple staticmethod

compute_fd_laplacian_2d_simple(u_values: ndarray, points: ndarray, triangles: ndarray, dims: tuple, method: str = 'gradient_of_gradient') -> jnp.ndarray

Laplacian on a 2-D triangular mesh.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 2).

TYPE: ndarray

triangles

Triangle connectivity, shape (M, 3).

TYPE: ndarray

dims

Spatial dimensions to sum over, e.g. (0, 1).

TYPE: tuple

method

"gradient_of_gradient" (default), "cotangent", or "lsq_of_gradient".

TYPE: str DEFAULT: 'gradient_of_gradient'

RETURNS DESCRIPTION
ndarray

Laplacian, shape (N,).

compute_laplacian_2d_cotangent staticmethod

compute_laplacian_2d_cotangent(u_values: ndarray, points: ndarray, triangles: ndarray) -> jnp.ndarray

Cotangent-weight (Laplace–Beltrami) Laplacian on a 2-D mesh.

For each triangle (i, j, k) the cotangent of each interior angle is used to weight the edge contributions::

lap[i] += (1/A_i) * [ cot_k*(u_j - u_i) + cot_j*(u_k - u_i) ]

with cot_k = cotangent of the angle at vertex k (opposite edge (i,j)), and A_i = (1/3) * Σ area.

This is second-order accurate and isotropic on regular meshes; it is a standard choice for PDE discretisation on unstructured 2-D triangular meshes.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 2).

TYPE: ndarray

triangles

Triangle connectivity, shape (M, 3).

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

Laplacian at each point, shape (N,).
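
A minimal NumPy sketch of the per-triangle accumulation described above, not the library implementation. The function name is hypothetical, and the sketch applies the textbook factor 1/2 to each cotangent weight. That factor is an assumption: the formula as printed does not show it, and it may be absorbed elsewhere in the library.

```python
import numpy as np

def cotangent_laplacian_2d_sketch(u, points, triangles):
    n = len(points)
    lap = np.zeros(n)
    node_area = np.zeros(n)
    p = points[triangles]
    e1, e2 = p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]
    area = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    for c in range(3):
        k = triangles[:, c]                  # vertex whose angle is used
        i = triangles[:, (c + 1) % 3]        # endpoints of the opposite edge
        j = triangles[:, (c + 2) % 3]
        a, b = points[i] - points[k], points[j] - points[k]
        cot_k = (a * b).sum(axis=1) / (2.0 * area)   # dot / |cross|
        np.add.at(lap, i, 0.5 * cot_k * (u[j] - u[i]))
        np.add.at(lap, j, 0.5 * cot_k * (u[i] - u[j]))
        np.add.at(node_area, k, area / 3.0)  # A_i = (1/3) * sum of incident areas
    return lap / node_area

# 3x3 grid with spacing 0.5, each cell split into two right triangles.
n = 3
xs = 0.5 * np.arange(n)
X, Y = np.meshgrid(xs, xs)
points = np.stack([X.ravel(), Y.ravel()], axis=1)    # node index = row * n + col
tris = []
for row in range(n - 1):
    for col in range(n - 1):
        a = row * n + col
        tris += [(a, a + 1, a + n + 1), (a, a + n + 1, a + n)]
triangles = np.array(tris)
u = points[:, 0] ** 2 + points[:, 1] ** 2
lap = cotangent_laplacian_2d_sketch(u, points, triangles)
```

On this regular grid the weights reduce to the 5-point stencil at the centre node (index 4), so lap[4] should equal 4 up to rounding for u = x² + y². Boundary nodes see only part of their neighbourhood and are not expected to be accurate.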

compute_fd_hessian_2d_simple staticmethod

compute_fd_hessian_2d_simple(u_values: ndarray, points: ndarray, triangles: ndarray, var_dims: list) -> jnp.ndarray

Hessian on a 2-D triangular mesh (area-weighted FD).

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 2).

TYPE: ndarray

triangles

Triangle connectivity, shape (M, 3).

TYPE: ndarray

var_dims

List of (i, vi_dim, j, vj_dim) tuples.

TYPE: list

RETURNS DESCRIPTION
ndarray

Hessian, shape (N, n_vars, n_vars).

compute_fd_gradient_3d_simple staticmethod

compute_fd_gradient_3d_simple(u_values: ndarray, points: ndarray, tetrahedra: ndarray, dim: int, method: str = 'area_weighted') -> jnp.ndarray

Gradient on a 3-D tetrahedral mesh.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Mesh point coordinates, shape (N, 3).

TYPE: ndarray

tetrahedra

Tet connectivity, shape (M, 4).

TYPE: ndarray

dim

Spatial dimension (0, 1 or 2).

TYPE: int

method

"area_weighted" (default), "uniform", "inverse_distance", "least_squares".

TYPE: str DEFAULT: 'area_weighted'

RETURNS DESCRIPTION
ndarray

∂u/∂x_dim at each point, shape (N,).

compute_gradient_3d_lsq staticmethod

compute_gradient_3d_lsq(u_values: ndarray, points: ndarray, tetrahedra: ndarray, dim: int) -> jnp.ndarray

Least-squares gradient on a 3-D tetrahedral mesh.

Analogous to compute_gradient_2d_lsq, but solves a 3×3 normal-equation system at each node via Cramer's rule.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 3).

TYPE: ndarray

tetrahedra

Tet connectivity, shape (M, 4).

TYPE: ndarray

dim

0 → ∂u/∂x, 1 → ∂u/∂y, 2 → ∂u/∂z.

TYPE: int

RETURNS DESCRIPTION
ndarray

Gradient component at each node, shape (N,).

compute_fd_laplacian_3d_simple staticmethod

compute_fd_laplacian_3d_simple(u_values: ndarray, points: ndarray, tetrahedra: ndarray, dims: tuple, method: str = 'gradient_of_gradient') -> jnp.ndarray

Laplacian on a 3-D tetrahedral mesh.

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 3).

TYPE: ndarray

tetrahedra

Tet connectivity, shape (M, 4).

TYPE: ndarray

dims

Spatial dimensions to sum over, e.g. (0, 1, 2).

TYPE: tuple

method

"gradient_of_gradient" (default) or "lsq_of_gradient".

TYPE: str DEFAULT: 'gradient_of_gradient'

RETURNS DESCRIPTION
ndarray

Laplacian, shape (N,).

compute_fd_hessian_3d_simple staticmethod

compute_fd_hessian_3d_simple(u_values: ndarray, points: ndarray, tetrahedra: ndarray, var_dims: list) -> jnp.ndarray

Hessian on a 3-D tetrahedral mesh (volume-weighted FD).

PARAMETER DESCRIPTION
u_values

Function values, shape (N,).

TYPE: ndarray

points

Coordinates, shape (N, 3).

TYPE: ndarray

tetrahedra

Tet connectivity, shape (M, 4).

TYPE: ndarray

var_dims

List of (i, vi_dim, j, vj_dim) tuples.

TYPE: list

RETURNS DESCRIPTION
ndarray

Hessian, shape (N, n_vars, n_vars).

parse_fd_scheme staticmethod

parse_fd_scheme(scheme: str) -> tuple[str, str, str]

Parse a scheme string into (main_scheme, grad_method, lap_method).

Supported formats::

"finite_difference"                  → fd, "area_weighted", "gradient_of_gradient"
"finite_difference:lsq"              → fd, "least_squares", "lsq_of_gradient"
"finite_difference:cotangent"        → fd, "area_weighted", "cotangent"
"finite_difference:uniform"          → fd, "uniform",       "gradient_of_gradient"
"finite_difference:inverse_distance" → fd, "inverse_distance", "gradient_of_gradient"
"automatic_differentiation"          → ad, None, None

RETURNS DESCRIPTION
tuple[str, str, str]

Tuple (main_scheme, grad_method, lap_method).
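
The table of supported formats can be mirrored by a small lookup. The function below is a hypothetical re-implementation for illustration; in particular, raising ValueError on unknown input and returning ("ad", None, None) for suffixed AD strings are assumptions.

```python
def parse_fd_scheme_sketch(scheme):
    """Map a scheme string to (main_scheme, grad_method, lap_method)."""
    main, _, sub = scheme.partition(":")
    if main == "automatic_differentiation":
        return ("ad", None, None)            # AD carries no FD sub-methods
    if main != "finite_difference":
        raise ValueError(f"unknown scheme: {scheme!r}")
    table = {
        "": ("area_weighted", "gradient_of_gradient"),
        "lsq": ("least_squares", "lsq_of_gradient"),
        "cotangent": ("area_weighted", "cotangent"),
        "uniform": ("uniform", "gradient_of_gradient"),
        "inverse_distance": ("inverse_distance", "gradient_of_gradient"),
    }
    if sub not in table:
        raise ValueError(f"unknown finite-difference sub-scheme: {sub!r}")
    grad_method, lap_method = table[sub]
    return ("fd", grad_method, lap_method)

result = parse_fd_scheme_sketch("finite_difference:cotangent")
```

Here result is ("fd", "area_weighted", "cotangent"): the cotangent sub-scheme changes only the Laplacian method and leaves the gradient method at its default.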

jno.integration_operators.IntegrationOperators

Static namespace for mesh-based numerical integration.

Works on the mesh_connectivity dict produced by the domain class. Boundary weights (nodal_ds) are already stored there; this class adds volume weights (nodal_volumes) computed on the fly.

nodal_volumes staticmethod

nodal_volumes(mesh_connectivity: dict) -> np.ndarray

Per-node volume weights for interior integration.

Returns mesh_connectivity["nodal_volumes"] if it was precomputed during domain setup (the normal path). Otherwise computes on the fly (fallback for manually constructed mesh_connectivity dicts).

Each node receives a share of surrounding element volumes:

  • 1-D: ½ × sum of adjacent segment lengths (trapezoidal rule)
  • 2-D: ⅓ × sum of incident triangle areas
  • 3-D: ¼ × sum of incident tetrahedron volumes

PARAMETER DESCRIPTION
mesh_connectivity

Preprocessed mesh connectivity from the domain class.

TYPE: dict

RETURNS DESCRIPTION
ndarray

vols, shape (n_points,).
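
For the 2-D case, the "⅓ × sum of incident triangle areas" rule can be sketched directly from point and triangle arrays. The function name and its inputs are assumptions; the real method reads a mesh_connectivity dict whose layout is not shown here.

```python
import numpy as np

def nodal_volumes_2d_sketch(points, triangles):
    p = points[triangles]                                  # (M, 3, 2)
    e1, e2 = p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]
    area = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])
    vols = np.zeros(len(points))
    for k in range(3):
        np.add.at(vols, triangles[:, k], area / 3.0)       # one third to each corner
    return vols

# Unit square split along its diagonal.
points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
triangles = np.array([[0, 1, 2], [0, 2, 3]])
vols = nodal_volumes_2d_sketch(points, triangles)

# Nodal quadrature: integral of f over the domain ~ sum_i vols[i] * f(x_i).
f_values = points[:, 0] + points[:, 1]
integral = np.sum(vols * f_values)
```

The weights sum to the domain area (1 here), and the quadrature is exact for fields that are linear on each triangle, so the integral of x + y over the unit square comes out as 1.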

jno.utils.ad_mode

AD mode (forward / reverse) selection for jNO operators.

Two layers, in order of precedence:

  1. Per-call scheme suffix on the operator.

     First-order (.d, .diff, d/dt)::

         u.d(x, scheme="automatic_differentiation:forward")
         u.d(x, scheme="automatic_differentiation:reverse")

     Second-order (.laplacian, .hessian, .d2, .dd)::

         u.laplacian(x, y, scheme="automatic_differentiation:fwd-over-rev")
         u.laplacian(x, y, scheme="automatic_differentiation:fwd-over-fwd")
         u.laplacian(x, y, scheme="automatic_differentiation:rev-over-rev")
         u.laplacian(x, y, scheme="automatic_differentiation:rev-over-fwd")

  2. Global default, set via jno.setup::

         jno.setup(__file__, diff_type="forward", hessian_type="fwd-over-fwd")

     or via .jno.toml::

         [jno]
         diff_type    = "forward"        # first-order default
         hessian_type = "fwd-over-rev"   # second-order default

The plain string "automatic_differentiation" (no suffix) resolves to the current global default. Defaults match historical behaviour: first-order reverse (was jax.jacobian = jacrev); second-order fwd-over-rev (was jax.hessian = jacfwd ∘ jacrev).
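
The two-layer precedence can be sketched in plain Python. All names below are hypothetical stand-ins for the module's functions; the only behaviour taken from the text is that an explicit suffix wins and that the unsuffixed string falls back to a global default that starts as "reverse".

```python
_GLOBAL_AD_MODE = "reverse"  # documented historical first-order default

def set_ad_mode_sketch(mode):
    """Change the global first-order default (illustrative)."""
    global _GLOBAL_AD_MODE
    if mode not in ("forward", "reverse"):
        raise ValueError(f"unknown AD mode: {mode!r}")
    _GLOBAL_AD_MODE = mode

def parse_ad_scheme_sketch(scheme):
    """Resolve a scheme string to 'forward' or 'reverse'."""
    prefix, _, suffix = scheme.partition(":")
    if prefix != "automatic_differentiation":
        raise ValueError(f"not an AD scheme: {scheme!r}")
    if suffix == "":
        return _GLOBAL_AD_MODE               # layer 2: global default
    if suffix in ("forward", "reverse"):
        return suffix                        # layer 1: per-call suffix wins
    raise ValueError(f"unknown AD suffix: {suffix!r}")

before = parse_ad_scheme_sketch("automatic_differentiation")
set_ad_mode_sketch("forward")
after = parse_ad_scheme_sketch("automatic_differentiation")
explicit = parse_ad_scheme_sketch("automatic_differentiation:reverse")
```

Changing the global default affects only calls that use the plain string; the explicit ":reverse" suffix still resolves to "reverse".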

set_ad_mode

set_ad_mode(mode: str) -> None

Set the global default for first-order AD.

get_ad_mode

get_ad_mode() -> str

Return the global default for first-order AD.

set_hessian_mode

set_hessian_mode(mode: str) -> None

Set the global default for second-order AD.

get_hessian_mode

get_hessian_mode() -> str

Return the global default for second-order AD.

parse_ad_scheme

parse_ad_scheme(scheme: str) -> str

Resolve a first-order scheme string to "forward" or "reverse".

Supported::

"automatic_differentiation"          → global default (get_ad_mode())
"automatic_differentiation:forward"  → "forward"
"automatic_differentiation:reverse"  → "reverse"

parse_hessian_scheme

parse_hessian_scheme(scheme: str) -> tuple[str, str]

Resolve a second-order scheme string to (outer, inner) AD modes.

The result composes as outer(inner(f)). E.g. ("forward", "reverse") means jax.jacfwd(jax.jacrev(f)) — the historical jax.hessian path.

Supported::

"automatic_differentiation"                → global default (get_hessian_mode())
"automatic_differentiation:fwd-over-rev"   → ("forward", "reverse")
"automatic_differentiation:fwd-over-fwd"   → ("forward", "forward")
"automatic_differentiation:rev-over-rev"   → ("reverse", "reverse")
"automatic_differentiation:rev-over-fwd"   → ("reverse", "forward")

First-order suffixes forward/reverse are accepted as shorthand for the matching same-mode composition (forward → fwd-over-fwd, reverse → rev-over-rev).

ad_fn

ad_fn(mode: str)

Return jax.jacfwd or jax.jacrev for the given mode.
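
How an (outer, inner) pair turns into a Hessian can be sketched with two small helpers. Both helper names are hypothetical; only jax.jacfwd, jax.jacrev and jax.hessian are real JAX functions.

```python
import jax
import jax.numpy as jnp

def ad_fn_sketch(mode):
    # Pick the Jacobian transform for a mode name (illustrative).
    if mode == "forward":
        return jax.jacfwd
    if mode == "reverse":
        return jax.jacrev
    raise ValueError(f"unknown AD mode: {mode!r}")

def hessian_from_modes(f, outer, inner):
    # Compose as outer(inner(f)), matching the (outer, inner) convention.
    return ad_fn_sketch(outer)(ad_fn_sketch(inner)(f))

def f(x):
    # Hypothetical test function with Hessian [[2*x1, 2*x0], [2*x0, 0]].
    return x[0] ** 2 * x[1]

x = jnp.array([1.5, -2.0])
H = hessian_from_modes(f, "forward", "reverse")(x)
H_ref = jax.hessian(f)(x)
```

With ("forward", "reverse") the composition matches jax.hessian. At x = (1.5, -2.0) the analytic Hessian is [[-4, 3], [3, 0]].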


Tracing primitives

Most users never instantiate these directly — they are what the expression-building API returns. Documented here for reference and for authors of custom operators.

jno.trace.Variable

Variable(tag: str, dim: list, domain: Any, axis: str = 'spatial', fem_meta: dict | None = None)

Independent variable placeholder (e.g., x, y, t).

Carries the domain tag and dimension index so the solver can bind sampled coordinates when evaluating traced expressions.

For time-dependent problems, spatial variables (axis='spatial') index into the spatial context array context[tag] shaped (N, D_spatial) (after the outer B and T vmaps peel off their axes). The temporal variable (axis='temporal') reads from a separate context["__time__"] entry that is a scalar (after the T vmap).

jno.trace.Integral

Integral(target: 'Placeholder', integration_var: 'Variable | None' = None)

Mesh-based integral reduction of an expression over its domain region.

Created by Placeholder.integrate. The region (boundary vs volume) is auto-detected at evaluation time from the Variable tags inside target via domain._boundary_registry.

When integration_var is set (the outer/collocation Variable), the evaluator uses jax.vmap to return an (N, 1) array instead of a scalar, enabling non-separable Fredholm kernels.
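
The vmap pattern for a Fredholm-type integral, one quadrature sum per collocation point, can be sketched in plain JAX. The kernel, the integrand, and the trapezoidal weights below are all hypothetical choices for illustration; the sketch does not use the Integral class itself.

```python
import jax
import jax.numpy as jnp

# 1-D mesh on [0, 1] with trapezoidal nodal weights (half-weights at the ends).
y = jnp.linspace(0.0, 1.0, 11)
h = y[1] - y[0]
w = jnp.full_like(y, h).at[0].set(h / 2).at[-1].set(h / 2)

def kernel(x, y_nodes):
    # Hypothetical kernel K(x, y) = |x - y|, which does not factor as a(x) * b(y).
    return jnp.abs(x - y_nodes)

def u(y_nodes):
    # Hypothetical integrand u(y) = 1.
    return jnp.ones_like(y_nodes)

def integral_at(x):
    # Quadrature for one collocation point: sum_j w_j * K(x, y_j) * u(y_j).
    return jnp.sum(w * kernel(x, y) * u(y))

# vmap over the outer (collocation) variable gives one value per point.
result = jax.vmap(integral_at)(y)[:, None]   # shape (N, 1)
```

Because each collocation point is a mesh node, the integrand is piecewise linear between nodes and the trapezoidal sum matches the closed form x² - x + 1/2 up to rounding.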

jno.trace.Noise

Noise(distribution: str, **params)

Stochastic noise term regenerated every training step.

Created by jno.noise. Produces an array of shape (N, ndim) where N is inferred at evaluation time from the number of active spatial points and ndim (default 1) controls the trailing dimension.

The realisation is derived from the solver's step PRNG key via jax.random.fold_in, so it is fully reproducible when the global seed is fixed (via jno.setup or .jno.toml).

PARAMETER DESCRIPTION
distribution

'gaussian', 'uniform', or 'laplace'.

TYPE: str

**params

Distribution-specific kwargs: std, low, high, ndim.
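
The reproducibility property can be illustrated with jax.random.fold_in. The function below is a hypothetical sketch: how the solver actually derives its step key, and what it folds in, is not documented here and is assumed.

```python
import jax
import jax.numpy as jnp

def gaussian_noise_sketch(seed, step, n_points, std=1.0, ndim=1):
    base_key = jax.random.PRNGKey(seed)              # fixed global seed
    step_key = jax.random.fold_in(base_key, step)    # assumed: one key per step
    return std * jax.random.normal(step_key, (n_points, ndim))

a = gaussian_noise_sketch(seed=42, step=0, n_points=5)
b = gaussian_noise_sketch(seed=42, step=0, n_points=5)
c = gaussian_noise_sketch(seed=42, step=1, n_points=5)
```

The same seed and step give an identical realisation (a equals b), while a different step gives a fresh one (c differs from a), which is what makes a run repeatable yet stochastic across steps.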