
LoRA

LoRA inserts trainable low-rank adapter matrices into matching layers while keeping base weights frozen. By default both linear and conv layers are wrapped automatically.
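For a single linear layer, the idea fits in a few lines of NumPy. This is a minimal sketch, not jNO's implementation. It assumes the common LoRA convention of a random A, a zero-initialised B, and the α/r scale listed for LoRALinear in the zoo table. The names lora_forward and lora_merge are hypothetical helpers.

```python
import numpy as np


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x), with W frozen and A, B trainable.

    W: (out, in) frozen base weight
    A: (r, in) down-projection
    B: (out, r) up-projection
    """
    return W @ x + (alpha / r) * (B @ (A @ x))


def lora_merge(W, A, B, alpha, r):
    """Fold the adapter into one dense (out, in) weight for inference."""
    return W + (alpha / r) * (B @ A)


rng = np.random.default_rng(0)
out_dim, in_dim, r, alpha = 6, 5, 2, 4.0
W = rng.normal(size=(out_dim, in_dim))
A = rng.normal(size=(r, in_dim)) * 0.01
B = np.zeros((out_dim, r))  # zero-init B: the adapted layer starts identical to the base
x = rng.normal(size=in_dim)
assert np.allclose(lora_forward(W, A, B, x, alpha, r), W @ x)

B = rng.normal(size=(out_dim, r))  # stand-in for B after some training steps
assert np.allclose(lora_merge(W, A, B, alpha, r) @ x, lora_forward(W, A, B, x, alpha, r))
```

The merged weight has the same shape as W, so after merging, inference costs exactly as much as the unadapted layer.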

Build Your Own Adapter

Any subclass of LoRAWrapper is a valid adapter. It must implement:

  • applies_to(cls, leaf): a classmethod that returns True for the leaf types the adapter should wrap.
  • __init__(base, rank, alpha, *, key): builds the adapter parameters around the wrapped module.
  • __call__(x): the forward pass, combining the base output with the adapter contribution.
  • merge(): folds the adapter into the base weight and returns a plain, unwrapped module for inference.

from jno.lora import LoRAWrapper, LoRALinear, LoRAConv
import equinox as eqx
import jax, jax.numpy as jnp

class MyEmbeddingAdapter(LoRAWrapper):
    adapter_fields = ("delta",)
    base: eqx.Module
    delta: jax.Array
    rank: int = eqx.field(static=True)
    alpha: float = eqx.field(static=True)

    @classmethod
    def applies_to(cls, leaf):
        return isinstance(leaf, eqx.nn.Embedding) and not isinstance(leaf, LoRAWrapper)

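    def __init__(self, base, rank, alpha, *, key):
        self.base, self.rank, self.alpha = base, rank, alpha
        # Dense, zero-initialised delta, so the wrapped layer starts out identical to
        # the base. For simplicity this adapter is not low-rank: rank only enters
        # through the alpha / rank scale, and key is unused (nothing is random).
        self.delta = jnp.zeros_like(base.weight)

    def __call__(self, x):
        # x is a token index: base embedding row plus the scaled delta row
        return self.base(x) + self.delta[x] * (self.alpha / self.rank)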

    def merge(self):
        w = self.base.weight + self.delta * (self.alpha / self.rank)
        return eqx.tree_at(lambda m: m.weight, self.base, w)

net.lora(rank=4, wrapper=MyEmbeddingAdapter)                          # embeddings only
net.lora(rank=4, wrapper=[LoRALinear, LoRAConv, MyEmbeddingAdapter])  # linear, conv and embeddings
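The merge() in the example is valid because an embedding lookup is linear in its table. Indexing the merged table (W + sΔ) at row x gives W[x] + sΔ[x], which is exactly what __call__ computes. The NumPy check below is a sketch: plain arrays stand in for base.weight and a trained self.delta, so no Equinox is needed.

```python
import numpy as np

rng = np.random.default_rng(0)
num_embeddings, dim = 10, 4
rank, alpha = 4, 8.0
scale = alpha / rank  # 2.0

weight = rng.normal(size=(num_embeddings, dim))  # stands in for base.weight
delta = rng.normal(size=(num_embeddings, dim))   # stands in for a trained self.delta


def adapted_lookup(i):
    # Mirrors MyEmbeddingAdapter.__call__: base row plus scaled delta row.
    return weight[i] + delta[i] * scale


merged = weight + delta * scale  # mirrors MyEmbeddingAdapter.merge()

for i in range(num_embeddings):
    assert np.allclose(adapted_lookup(i), merged[i])
```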

Per-target specs may carry their own "wrapper" key:

net.lora(
    specs=[
        {"target": "linear", "rank": 4, "alpha": 1.0},
        {"target": "embed",  "rank": 8, "alpha": 2.0, "wrapper": MyEmbeddingAdapter},
    ]
)

The built-in adapters in the LoRA Zoo below all subclass LoRAWrapper themselves — they cover the published variants (DoRA, PiSSA, VeRA, …); your own subclass plugs in the same way.

Selecting Layers

net.lora(rank=8, target="encoder")                         # path-regex
net.mask(encoder_mask).lora(rank=8)                        # boolean mask
net.mask(encoder_mask).lora(rank=8, target="encoder")      # both combined
| | target= / specs= | mask(M).lora() |
|---|---|---|
| Selects by | regex on pytree path string | boolean pytree |
| Use when | you know the layer names up front | you have a precomputed mask |
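The "both combined" case can be read as an intersection. A layer is wrapped only if its mask entry is True and the target regex matches its path. The sketch below assumes that reading and assumes re.search semantics. It is not jNO's code. A flat dict keyed by made-up path strings stands in for the boolean pytree, and selected_layers is a hypothetical helper.

```python
import re


def selected_layers(paths, mask=None, target=None):
    """Hypothetical helper: return the paths that would be wrapped.

    Assumes mask and target combine as an intersection; either may be omitted.
    """
    chosen = []
    for path in paths:
        if mask is not None and not mask.get(path, False):
            continue  # excluded by the boolean mask
        if target is not None and re.search(target, path) is None:
            continue  # excluded by the path regex
        chosen.append(path)
    return chosen


paths = ["encoder/linear1", "encoder/linear2", "decoder/linear1"]
my_mask = {"encoder/linear1": True, "encoder/linear2": False, "decoder/linear1": True}

print(selected_layers(paths, target="encoder"))                # ['encoder/linear1', 'encoder/linear2']
print(selected_layers(paths, mask=my_mask))                    # ['encoder/linear1', 'decoder/linear1']
print(selected_layers(paths, mask=my_mask, target="encoder"))  # ['encoder/linear1']
```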

Uniform and Per-Target Specs

net.lora(rank=8, alpha=16)                    # all layers
net.lora(rank=8, alpha=16, target="encoder")  # restricted to a subset

net.lora(
    specs=[
        {"target": "encoder", "rank": 4,  "alpha": 1.0},
        {"target": "decoder", "rank": 16, "alpha": 4.0},
    ]
)

target is regex-matched against the slash-joined pytree path. The first matching spec wins.
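"First match wins" means the order of specs matters whenever targets overlap. The minimal sketch below assumes three things: re.search semantics, made-up path strings, and that a spec without a "wrapper" key falls back to a default. Wrappers are plain strings here so the block runs on its own. resolve_spec is a hypothetical helper, not jNO API.

```python
import re


def resolve_spec(path, specs, default_wrapper="LoRALinear"):
    """Hypothetical helper: (rank, alpha, wrapper) from the first matching spec, else None."""
    for spec in specs:
        if re.search(spec["target"], path):
            return spec["rank"], spec["alpha"], spec.get("wrapper", default_wrapper)
    return None  # no spec matched: the layer is left unwrapped


specs = [
    {"target": "encoder", "rank": 4, "alpha": 1.0},
    {"target": "embed",   "rank": 8, "alpha": 2.0, "wrapper": "MyEmbeddingAdapter"},
]

print(resolve_spec("encoder/embed", specs))   # (4, 1.0, 'LoRALinear'): "encoder" is listed first
print(resolve_spec("decoder/embed", specs))   # (8, 2.0, 'MyEmbeddingAdapter')
print(resolve_spec("decoder/linear", specs))  # None
```

Under these assumptions, listing narrower targets before broader ones avoids the first case, where an embedding inside the encoder picks up the encoder spec instead of the embedding spec.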

Combining with Mask and Freeze

net.mask(encoder_mask).lora(rank=8, alpha=16)           # only mask-selected layers
net.freeze().lora(rank=8, alpha=16)                     # freeze all; only adapters train
net.freeze().mask(encoder_mask).lora(rank=8, alpha=16)  # wrap mask-selected layers; only their adapters train
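With freeze(), the adapters are the only trainable parameters, so the trainable budget is easy to estimate. The worked count below uses a hypothetical model of four 1024 × 1024 linear layers, ignores biases, and applies the standard LoRALinear count r·(in + out) from the zoo table.

```python
def lora_linear_params(in_features, out_features, rank):
    """Trainable params of one standard LoRA adapter: A is (rank, in), B is (out, rank)."""
    return rank * (in_features + out_features)


# Hypothetical model: four 1024 x 1024 linear layers (biases ignored), all wrapped at rank 8.
n_layers, d, rank = 4, 1024, 8
frozen = n_layers * d * d                              # base weights, frozen by .freeze()
trainable = n_layers * lora_linear_params(d, d, rank)  # adapter weights, the only trainable part
fraction = trainable / frozen

print(frozen, trainable, fraction)  # 4194304 65536 0.015625
```

Here about 1.6% of the parameter count is trained.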

Default Layer Types

Without wrapper=, jNO wraps Linear and Conv layers (ConvTranspose excluded):

net.lora(rank=4, alpha=1.0)           # Linear + Conv1d/2d/3d
net.lora(rank=4, wrapper=LoRALinear)  # linear only
net.lora(rank=4, wrapper=LoRAConv)    # conv only

LoRA Zoo — Linear

from jno.lora import (
    LoRALinear,    # standard LoRA (default)
    rsLoRALinear,  # rank-stabilized
    LoRAFALinear,  # frozen A — fewer trainable params
    DoRALinear,    # weight-decomposed
    PiSSALinear,   # SVD principal-component init; often converges faster on pretrained models
    LoRAXSLinear,  # extra-small r×r core
    VeRALinear,    # frozen random A,B; only b,d vectors trained
    MiLoRALinear,  # minor SVD components — preserves pretrained knowledge
    IA3Linear,     # output scaling vector — no low-rank matrices
    LoKrLinear,    # Kronecker product adapter
    OFTLinear,     # block-diagonal orthogonal fine-tuning
)
| Class | Trainable params | Key idea |
|---|---|---|
| LoRALinear | r·(in + out) | Standard LoRA; scale = α/r |
| rsLoRALinear | r·(in + out) | Scale = α/√r; stable across ranks (rsLoRA) |
| LoRAFALinear | r·out | Frozen A; roughly halves adapter params, exactly when in = out (LoRAFA) |
| DoRALinear | r·(in + out) + out | Magnitude + direction decomposition (DoRA) |
| PiSSALinear | r·(in + out) | SVD principal-component init (PiSSA) |
| LoRAXSLinear | r² | Frozen A,B; trainable r×r core only (LoRA-XS) |
| VeRALinear | out + r | Seed-based frozen A,B; only b,d vectors trained (VeRA) |
| MiLoRALinear | r·(in + out) | Adapts minor SVD components (MiLoRA) |
| IA3Linear | out | Per-output scale vector; no rank hyperparameter (IA³) |
| LoKrLinear | r² + ⌈out/r⌉·⌈in/r⌉ | Kronecker product adapter (LoKr) |
| OFTLinear | n_blocks·r² | Orthogonal fine-tuning via Cayley map (OFT) |
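The formulas in the table can be evaluated for a concrete layer to compare budgets. This sketch transcribes them directly, with two assumptions. First, LoRA-XS trains only its r×r core, which gives r². Second, OFT uses r×r blocks that tile the output dimension, so n_blocks = ⌈out/r⌉. trainable_params is a hypothetical helper.

```python
import math


def trainable_params(variant, in_f, out_f, r):
    """Trainable-parameter count for one linear layer, per the table's formulas."""
    n_blocks = math.ceil(out_f / r)  # assumption: OFT's r x r blocks tile the output dim
    counts = {
        "LoRALinear":   r * (in_f + out_f),
        "rsLoRALinear": r * (in_f + out_f),
        "LoRAFALinear": r * out_f,
        "DoRALinear":   r * (in_f + out_f) + out_f,
        "PiSSALinear":  r * (in_f + out_f),
        "LoRAXSLinear": r * r,
        "VeRALinear":   out_f + r,
        "MiLoRALinear": r * (in_f + out_f),
        "IA3Linear":    out_f,
        "LoKrLinear":   r * r + math.ceil(out_f / r) * math.ceil(in_f / r),
        "OFTLinear":    n_blocks * r * r,
    }
    return counts[variant]


for name in ("LoRALinear", "LoRAFALinear", "DoRALinear", "LoRAXSLinear",
             "VeRALinear", "IA3Linear", "LoKrLinear", "OFTLinear"):
    print(f"{name:14s} {trainable_params(name, in_f=512, out_f=1024, r=8):>6d}")
```

For comparison, the dense 512 → 1024 weight itself has 524,288 entries.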

When to use which:

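  • rsLoRALinear — drop-in upgrade over LoRALinear; the α/√r scale keeps update magnitudes stable as rank grows, so higher ranks remain effective.
  • LoRAFALinear — memory-constrained settings; freezing A roughly halves the adapter parameter count (exactly half when in = out).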
  • DoRALinear — pretrained models where preserving weight norms matters.
  • PiSSALinear — pretrained models; adapters start at the most informative directions.
  • LoRAXSLinear — extreme parameter efficiency.
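  • VeRALinear — very few trainable params (out + r); A and B are regenerated from a seed, so they are not stored in checkpoints.
  • MiLoRALinear — adapts the minor singular directions and leaves the principal ones untouched, preserving pretrained knowledge.
  • IA3Linear — no rank hyperparameter and the smallest footprint (one out-sized vector per layer); well suited to fast probing.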
  • LoKrLinear — large layers where Kronecker factorisation is more efficient.
  • OFTLinear — orthogonal weight updates to preserve geometry.
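PiSSA and MiLoRA differ only in which singular directions seed the adapter. The NumPy sketch below rests on two assumptions: the adapter is initialised as B @ A with no extra scale factor, and the chosen directions are subtracted into a frozen residual weight. svd_split is a hypothetical helper, and jNO may scale or split the factors differently.

```python
import numpy as np


def svd_split(W, r, principal=True):
    """Split W into a frozen residual plus a rank-r adapter factored as B @ A.

    principal=True puts the top-r singular triplets in the adapter (PiSSA-style);
    principal=False puts the bottom-r ones there (MiLoRA-style).
    """
    U, S, Vt = np.linalg.svd(W, full_matrices=False)  # S is sorted in descending order
    idx = np.arange(r) if principal else np.arange(len(S) - r, len(S))
    sqrt_s = np.sqrt(S[idx])
    B = U[:, idx] * sqrt_s         # (out, r)
    A = sqrt_s[:, None] * Vt[idx]  # (r, in)
    W_res = W - B @ A              # frozen residual weight
    return W_res, A, B


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 6))
for principal in (True, False):
    W_res, A, B = svd_split(W, r=2, principal=principal)
    # At initialisation the adapted layer reproduces the pretrained weight exactly.
    assert np.allclose(W_res + B @ A, W)
```

In both cases training starts from the pretrained function. The only difference is whether the trainable factors hold the largest directions (PiSSA) or the smallest (MiLoRA).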

LoRA Zoo — Conv

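Matching conv variants for eqx.nn.Conv1d/2d/3d. Each flattens the weight to (out_ch, flat_in), where flat_in is the product of the remaining weight dimensions (input channels per group times the number of kernel elements). It then applies the same adapter logic as its linear counterpart and reshapes back.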
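Flattening works because a convolution output at one spatial position is a dot product between the kernel and the flattened receptive field. The (out_ch, flat_in) view therefore behaves like a linear layer, and the usual low-rank update applies unchanged. This NumPy sketch assumes Equinox's (out_channels, in_channels // groups, *kernel_size) weight layout with groups=1 and C-order flattening. The variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
out_ch, in_ch, kh, kw, r, alpha = 16, 3, 3, 3, 4, 4.0
weight = rng.normal(size=(out_ch, in_ch, kh, kw))  # Conv2d weight, groups=1

flat_in = int(np.prod(weight.shape[1:]))  # in_ch * kh * kw = 27
W_flat = weight.reshape(out_ch, flat_in)  # (out_ch, flat_in), C-order

A = rng.normal(size=(r, flat_in)) * 0.01  # (r, flat_in)
B = rng.normal(size=(out_ch, r))          # (out_ch, r); nonzero so the check is meaningful
delta = (alpha / r) * (B @ A)             # same low-rank update as the linear case

merged = (W_flat + delta).reshape(weight.shape)  # back to (out_ch, in_ch, kh, kw)

# One output position: dot product of each kernel with one receptive field.
patch = rng.normal(size=(in_ch, kh, kw))
y_4d = np.einsum("ocij,cij->o", merged, patch)
y_flat = (W_flat + delta) @ patch.ravel()
assert np.allclose(y_4d, y_flat)
```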

from jno.lora import (
    LoRAConv, rsLoRAConv, LoRAFAConv, DoRAConv, PiSSAConv,
    LoRAXSConv, VeRAConv, MiLoRAConv, IA3Conv, LoKrConv, OFTConv,
)
| Class | Trainable params | Key idea |
|---|---|---|
| LoRAConv | r·(flat_in + out_ch) | Standard LoRA on flattened conv weight |
| rsLoRAConv | r·(flat_in + out_ch) | Rank-stabilized scaling α/√r |
| LoRAFAConv | r·out_ch | Frozen A; only B trained |
| DoRAConv | r·(flat_in + out_ch) + out_ch | Magnitude + direction decomposition |
| PiSSAConv | r·(flat_in + out_ch) | SVD principal-component init |
| LoRAXSConv | r² | Frozen A,B from SVD; trainable r×r core |
| VeRAConv | out_ch + r | Seed-based frozen A,B; only b,d vectors trained |
| MiLoRAConv | r·(flat_in + out_ch) | SVD minor components |
| IA3Conv | out_ch | Per-output-channel scale vector |
| LoKrConv | r² + ⌈out_ch/r⌉·⌈flat_in/r⌉ | Kronecker product adapter |
| OFTConv | n_blocks·r² | Block-diagonal Cayley map on output channels |
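The OFTConv row can be illustrated with a block-diagonal Cayley map. This minimal NumPy sketch assumes n_blocks = out_ch / r. It also assumes each block stores a full r×r matrix whose skew-symmetric part S enters the Cayley transform R = (I − S)(I + S)⁻¹, which matches the n_blocks·r² count in the table. jNO's exact parameterisation may differ. Because R is orthogonal, rotating the flattened weight along its output-channel axis leaves the inner products between its columns unchanged.

```python
import numpy as np


def cayley(Q):
    """Map a square matrix to an orthogonal one via its skew-symmetric part."""
    S = 0.5 * (Q - Q.T)                        # skew-symmetric: S.T == -S
    eye = np.eye(Q.shape[0])
    return (eye - S) @ np.linalg.inv(eye + S)  # eye + S is always invertible for skew S


def block_diag(blocks):
    """Place square blocks along the diagonal of a zero matrix."""
    n = sum(b.shape[0] for b in blocks)
    R = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        R[i:i + k, i:i + k] = b
        i += k
    return R


rng = np.random.default_rng(0)
out_ch, flat_in, r = 8, 12, 4
n_blocks = out_ch // r                                       # 2 blocks of size r x r
params = [rng.normal(size=(r, r)) for _ in range(n_blocks)]  # n_blocks * r**2 trainable numbers
R = block_diag([cayley(Q) for Q in params])                  # (out_ch, out_ch), orthogonal

W = rng.normal(size=(out_ch, flat_in))  # flattened conv weight
W_new = R @ W                           # rotate along the output-channel axis

assert np.allclose(R.T @ R, np.eye(out_ch))
assert np.allclose(W_new.T @ W_new, W.T @ W)  # column inner products preserved
```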

Mix linear and conv adapters per layer group:

from jno.lora import PiSSALinear, rsLoRALinear, rsLoRAConv

net.lora(
    specs=[
        {"target": "encoder", "rank": 8,  "alpha": 16,  "wrapper": PiSSALinear},
        {"target": "decoder", "rank": 4,  "alpha": 1.0, "wrapper": rsLoRALinear},
        {"target": "conv",    "rank": 4,  "alpha": 1.0, "wrapper": rsLoRAConv},
    ]
)