Tensor types
NEML2 evaluates constitutive models on batched tensor data. The tensor
backend is PyTorch — every tensor in the system is a torch.Tensor at
the storage level — but NEML2 wraps each tensor in a typed wrapper
that carries a fixed mathematical structure (a scalar, a vector, a
symmetric second-order tensor, …) and a small amount of batching
metadata.
This page documents the wrappers, the shape conventions, and the batching rules. It is the reference complement to Vectorization, which works through a single worked example.
Shape decomposition
Every typed wrapper exposes its underlying tensor as the .data
attribute. The shape of that tensor splits into three contiguous
regions:
data.shape == (*dynamic_batch_shape, *sub_batch_shape, *base_shape)
└── leading ──┘└── middle (static) ──┘└─ BASE_NDIM ─┘
- Base shape: the trailing BASE_NDIM axes. Fixed by the wrapper type (e.g. (6,) for SR2, () for Scalar); these encode the mathematical structure and never participate in broadcasting.
- Sub-batch shape: the next sub_batch_ndim axes. A small, static batching region used to express per-site structure (lookup-table axes, finite-volume cells, slip systems, …). The chain-rule machinery treats sub-batch dims specially so derivatives stay consistent when models on different sites are composed. Default 0; most models don’t need it.
- Dynamic batch shape: everything left over. Free-form, sized at call time, traced as dynamic by torch.export so a single AOTI artifact handles every batch size from 1 to roughly a million without recompilation.
wrapper.batch_shape returns dynamic_batch_shape + sub_batch_shape
— “everything that isn’t base”. wrapper.batch.ndim,
wrapper.dynamic_batch.ndim, and wrapper.sub_batch_ndim provide the
corresponding counts.
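The three-region split can be sketched on plain shape tuples. This is a minimal illustration, not NEML2 code: `split_shape` and `ShapeRegions` are hypothetical names introduced here, and the only inputs are the two counts the text describes (the base rank and the sub-batch rank).

```python
from typing import NamedTuple, Tuple


class ShapeRegions(NamedTuple):
    """Hypothetical container for the three contiguous shape regions."""
    dynamic_batch: Tuple[int, ...]
    sub_batch: Tuple[int, ...]
    base: Tuple[int, ...]


def split_shape(shape, base_ndim, sub_batch_ndim=0):
    """Split a full data shape into (dynamic batch, sub-batch, base)."""
    shape = tuple(shape)
    batch_ndim = len(shape) - base_ndim
    if base_ndim < 0 or sub_batch_ndim < 0 or batch_ndim < sub_batch_ndim:
        raise ValueError("shape is too short for the requested regions")
    dynamic_ndim = batch_ndim - sub_batch_ndim
    return ShapeRegions(
        dynamic_batch=shape[:dynamic_ndim],
        sub_batch=shape[dynamic_ndim:batch_ndim],
        base=shape[batch_ndim:],
    )


# An SR2-like layout: two dynamic axes, one sub-batch axis, base (6,).
regions = split_shape((50, 200, 20, 6), base_ndim=1, sub_batch_ndim=1)
print(regions)

# "Everything that isn't base" is the concatenation of the two batch regions.
batch_shape = regions.dynamic_batch + regions.sub_batch
print(batch_shape)
```

With the default `sub_batch_ndim=0`, the sub-batch region is empty and the whole batch shape is dynamic.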
Fixed-base-shape tensor types
The following wrappers ship in neml2.types. The class hierarchy is:
TensorWrapper (abstract: shape decomposition + region views)
└── PrimitiveTensor (concrete intermediate: generic ops + factories)
    ├── Scalar
    └── Vec, R2, SR2, WR2, Rot, SSR4, MillerIndex
PrimitiveTensor is the layer where the generic arithmetic operators
(+, -, *, /, -x) and shape factories (zeros, ones, full,
empty, fill) are defined. Each concrete leaf below it adds any
class-specific factories — e.g. R2.identity, SSR4.identity_sym,
SR2.fill with Mandel √2 scaling.
| Type | Base shape | Storage / convention |
|---|---|---|
| Scalar | () | A single number per batch entry. The wrapper exists so mixed operations like |
| Vec | (3,) | 3-vector. |
| Rot | (3,) | Modified Rodrigues parameters (MRPs) representing a 3D rotation: |
| MillerIndex | (3,) | Integer-coordinate crystallographic direction or plane normal, stored as float for differentiability. |
| R2 | (3, 3) | Full second-order tensor (no symmetry). |
| WR2 | (3,) | Skew-symmetric second-order tensor stored as an axial 3-vector |
| SR2 | (6,) | Symmetric second-order tensor in Mandel notation: |
| SSR4 | (6, 6) | Fourth-order tensor with both pairs of minor symmetries (the symmetry class of the elasticity tensor |
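For readers unfamiliar with Modified Rodrigues parameters, the sketch below converts between an axis-angle pair and an MRP triple using the textbook definition (unit axis scaled by tan(θ/4)). It assumes that standard definition only; the function names are hypothetical and it says nothing about how Rot is implemented internally.

```python
import math


def axis_angle_to_mrp(axis, angle):
    """Textbook MRP: unit rotation axis scaled by tan(angle / 4)."""
    norm = math.sqrt(sum(a * a for a in axis))
    if norm == 0.0:
        raise ValueError("rotation axis must be non-zero")
    scale = math.tan(angle / 4.0) / norm
    return [scale * a for a in axis]


def mrp_to_axis_angle(mrp):
    """Invert the definition above: angle = 4 * atan(|mrp|)."""
    mag = math.sqrt(sum(m * m for m in mrp))
    if mag == 0.0:
        # Zero rotation: the axis is arbitrary, so pick z by convention.
        return [0.0, 0.0, 1.0], 0.0
    return [m / mag for m in mrp], 4.0 * math.atan(mag)


# A 90-degree rotation about z needs only three stored numbers.
mrp = axis_angle_to_mrp([0.0, 0.0, 1.0], math.pi / 2.0)
axis, angle = mrp_to_axis_angle(mrp)
print(mrp, axis, angle)
```

Three parameters per rotation is what lets the representation fit a 3-component base.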
Scalar defaults to torch.float64 for both Python-literal construction
and its factory methods; the other wrappers accept a dtype= kwarg and
otherwise fall back to torch’s global default dtype. Construct a
wrapper from a raw torch.Tensor (SR2(torch.randn(10_000, 6) * 0.01)),
from a Python literal (Scalar(200e3)), or via the inherited factories
(Vec.zeros(N), SR2.fill(σ11, σ22, σ33, σ23, σ13, σ12), etc.).
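The Mandel √2 scaling mentioned for SR2 can be illustrated in plain Python. The sketch assumes the component order implied by the SR2.fill argument list above (11, 22, 33, 23, 13, 12); `to_mandel` and `from_mandel` are hypothetical helpers, not NEML2 functions. The point of the scaling is that the ordinary dot product of two 6-vectors equals the full double contraction of the 3×3 tensors.

```python
import math

SQRT2 = math.sqrt(2.0)


def to_mandel(s):
    """Pack a symmetric 3x3 nested list into 6 Mandel components."""
    return [
        s[0][0], s[1][1], s[2][2],
        SQRT2 * s[1][2], SQRT2 * s[0][2], SQRT2 * s[0][1],
    ]


def from_mandel(m):
    """Unpack 6 Mandel components back into a symmetric 3x3 nested list."""
    s23, s13, s12 = m[3] / SQRT2, m[4] / SQRT2, m[5] / SQRT2
    return [
        [m[0], s12, s13],
        [s12, m[1], s23],
        [s13, s23, m[2]],
    ]


# Illustrative values only.
stress = [
    [100.0, 30.0, 0.0],
    [30.0, 50.0, -20.0],
    [0.0, -20.0, 10.0],
]
mandel = to_mandel(stress)

# Full double contraction over all nine components ...
full_contraction = sum(stress[i][j] * stress[i][j] for i in range(3) for j in range(3))
# ... matches the plain dot product of the six Mandel components.
mandel_dot = sum(x * x for x in mandel)
print(full_contraction, mandel_dot)
```

Without the √2 factor, each off-diagonal term would be counted once instead of twice and the two numbers would differ.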
Batching
A model’s forward operator dispatches one PyTorch kernel per
mathematical operation. Those kernels broadcast across leading batch
axes natively, so the same forward(x) body handles a single state
and a multi-million-state batch with no rewrites — the only difference
is the shape of the input tensors.
The canonical pattern:
import torch
from neml2.types import SR2

# `model` is assumed to be an already-constructed NEML2 model mapping strain to stress.
strain_single = SR2.fill(0.01, 0, 0, 0, 0, 0)  # base only
stress = model(strain_single)
# stress.data.shape == (6,)

strain_batch = SR2(torch.randn(10_000, 6) * 0.01)  # one leading batch dim
stress = model(strain_batch)
# stress.data.shape == (10_000, 6)

strain_grid = SR2(torch.randn(50, 200, 6) * 0.01)  # two leading batch dims
stress = model(strain_grid)
# stress.data.shape == (50, 200, 6)
Leading batch dims are completely free-form. A Python loop around a single-state model call leaves a lot of throughput on the table — pass the whole batch as one tensor whenever you can. Vectorization shows the timing difference.
Broadcasting
Inside forward, binary operators between typed wrappers broadcast
their batch regions using PyTorch’s standard rules:
- Right-aligned by axis (so a (6,) SR2 and a (N, 6) SR2 broadcast to (N, 6)).
- Size-1 axes are stretched to match.
- Mismatched non-1 sizes are an error.
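The three rules can be written out as a small pure-Python function on batch shapes. This is a sketch of the standard broadcasting rule, not of NEML2 internals; `broadcast_batch_shapes` is a hypothetical name. Only batch regions are passed in, so a (6,) SR2 contributes the empty batch shape () and an (N, 6) SR2 contributes (N,).

```python
def broadcast_batch_shapes(a, b):
    """Combine two batch shapes using right-aligned broadcasting."""
    a, b = tuple(a), tuple(b)
    ndim = max(len(a), len(b))
    # Right-align by padding the shorter shape with leading size-1 axes.
    a = (1,) * (ndim - len(a)) + a
    b = (1,) * (ndim - len(b)) + b
    out = []
    for size_a, size_b in zip(a, b):
        if size_a == size_b or size_b == 1:
            out.append(size_a)
        elif size_a == 1:
            # A size-1 axis is stretched to match the other operand.
            out.append(size_b)
        else:
            raise ValueError(f"cannot broadcast sizes {size_a} and {size_b}")
    return tuple(out)


print(broadcast_batch_shapes((), (1000,)))       # single state with a batch
print(broadcast_batch_shapes((50, 1), (200,)))   # size-1 axis stretched
```

Appending the base shape to the result gives the shape of the output tensor, since the base region never takes part in broadcasting.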
Algebraic operators preserve sub-batch metadata: every binary op
unifies the two operands’ sub-batch widths and tags the result
accordingly. The upshot is that a “global” Scalar parameter and a
“per-site” Scalar field combine cleanly at any dynamic batch size.
Dynamic vs sub-batch dimensions
These are the two batching regions the framework treats differently:
| Region | Sized at… | Traced as… | Broadcasts with… | Typical use |
|---|---|---|---|---|
| Dynamic batch | call time | dynamic dim | everything | every “ordinary” batch: N material points, time steps. |
| Sub-batch | construction time | static shape | other sub-batches of matching width | per-site structure: interpolation-table axis, FV cell index, slip-system axis. |
Operationally:
- Default sub_batch_ndim = 0. Models that don’t need the distinction ignore it; everything sits in the dynamic region.
- Promote axes to sub-batch with .sub_batch.retag(n) (n = number of trailing batch axes to mark). The most common case is n = 1, marking a lookup-table or per-cell axis.
- Sub-batch dims do NOT participate in dynamic-batch broadcasting. They behave like a small extra structural region the chain-rule machinery accumulates over.
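The bookkeeping behind retagging can be sketched without any tensor data. The class below is a hypothetical stand-in that tracks only a batch shape and a sub-batch count; it assumes that retag(n) sets the number of trailing batch axes treated as sub-batch (rather than adding to it), which is one reading of the description above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class BatchLayout:
    """Hypothetical record of a batch shape and its sub-batch count."""
    batch_shape: Tuple[int, ...]
    sub_batch_ndim: int = 0

    @property
    def dynamic_batch_shape(self):
        return self.batch_shape[: len(self.batch_shape) - self.sub_batch_ndim]

    @property
    def sub_batch_shape(self):
        return self.batch_shape[len(self.batch_shape) - self.sub_batch_ndim:]

    def retag(self, n):
        """Mark the trailing n batch axes as sub-batch (assumed semantics)."""
        if not 0 <= n <= len(self.batch_shape):
            raise ValueError("n must be between 0 and the number of batch axes")
        return BatchLayout(self.batch_shape, n)


# By default every batch axis is dynamic.
layout = BatchLayout((20,))
print(layout.dynamic_batch_shape, layout.sub_batch_shape)

# After retagging, the same axis is reported as sub-batch; the data is untouched.
tagged = layout.retag(1)
print(tagged.dynamic_batch_shape, tagged.sub_batch_shape)
```

Retagging changes only the metadata: the total batch shape is the same before and after.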
A representative use, from the Kocks–Mecking shear-modulus lookup table:
from neml2.types import Scalar
import torch
T_controls = Scalar.linspace(300.0, 1200.0, 20).sub_batch.retag(1)
This marks the trailing length-20 axis as the sub-batch (interpolation
control points), so any model consuming T_controls accumulates its
chain-rule contribution across that axis without conflating it with
the dynamic per-state batch.
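To see why the control-point axis is kept apart from the per-state batch, consider a plain-Python linear interpolation. Everything here is illustrative: the helper names and the modulus values are hypothetical, and the explicit loop is for clarity only (a real model would evaluate all states in one batched call). The table axis has length 20 regardless of how many states are queried, and it is consumed by the lookup instead of appearing in the output.

```python
from bisect import bisect_right


def linspace(start, end, steps):
    """Evenly spaced values from start to end (steps >= 2)."""
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def interp(x, xs, ys):
    """Piecewise-linear lookup with clamping outside the table range."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    i = bisect_right(xs, x) - 1
    t = (x - xs[i]) / (xs[i + 1] - xs[i])
    return ys[i] + t * (ys[i + 1] - ys[i])


# Table axis: fixed length, known when the model is built.
T_controls = linspace(300.0, 1200.0, 20)
# Hypothetical shear-modulus values at the control points.
mu_controls = [80e3 - 20.0 * (T - 300.0) for T in T_controls]

# Per-state axis: any length, known only at call time.
T_states = [350.0, 700.0, 1150.0]
mu_states = [interp(T, T_controls, mu_controls) for T in T_states]
print(len(T_controls), len(mu_states))
```

The output has one entry per state, so the table length never leaks into the per-state batch shape.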
.sub_batch.retag(n) is also accepted inside a [Tensors] HIT block:
[Tensors]
  [T_controls]
    type = Python
    expr = 'Scalar.linspace(300.0, 1200.0, 20).sub_batch.retag(1)'
  []
[]
Region views
Shape-manipulation methods live on four region-view properties so the intent of any reshape, broadcast, or reduction is unambiguous:
- t.batch: the combined dynamic_batch + sub_batch region. Read-only .shape / .ndim; shape-changing ops raise so callers pick dynamic_batch or sub_batch explicitly. The free function cat in neml2.types accepts a batch view if you do need to concatenate across the combined region.
- t.dynamic_batch: dynamic batch only. Ops preserve sub_batch_ndim.
- t.sub_batch: sub-batch only. Ops adjust sub_batch_ndim.
- t.base: the base region. Read-only except for transpose on the square-base types (R2, SSR4).
Every mutable view exposes the same surface: .shape, .ndim,
.unsqueeze(dim), .squeeze(dim), .expand(*shape). Concatenation
along a region axis goes through the free function cat in
neml2.types (see below).
The view methods return a fresh wrapper, so calls chain cleanly:
broadcast = SR2.fill(0.1, -0.05, -0.05, 0, 0, 0).dynamic_batch.expand(20)
# Construct an SR2 of base shape (6,), then broadcast it to (20, 6).
retagged = Scalar.linspace(0, 1, 5).sub_batch.retag(1)
# Mark the trailing length-5 axis as sub-batch.
tr_R = R.base.transpose(-2, -1) # Transpose the (3, 3) base of an R2.
The companion free functions sum, mean, diff in neml2.types
take a view argument and dispatch on its kind:
from neml2.types import sum, mean, diff
avg = mean(t.sub_batch, dim=0) # Reduce sub-batch axis 0.
total = sum(t.dynamic_batch, [0, 1]) # Reduce two dynamic axes.
delta = diff(t.sub_batch, n=1, dim=-1)
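One way to picture region-relative indexing is as an offset into the underlying tensor's axes. The function below is a hypothetical sketch, assuming the regions are laid out in the order dynamic batch, sub-batch, base as described under Shape decomposition; it is not taken from the NEML2 source.

```python
def absolute_axis(region, dim, dynamic_ndim, sub_batch_ndim, base_ndim):
    """Map a region-relative axis index to an axis of the full tensor."""
    sizes = {
        "dynamic_batch": dynamic_ndim,
        "sub_batch": sub_batch_ndim,
        "base": base_ndim,
    }
    offsets = {
        "dynamic_batch": 0,
        "sub_batch": dynamic_ndim,
        "base": dynamic_ndim + sub_batch_ndim,
    }
    if region not in sizes:
        raise ValueError(f"unknown region {region!r}")
    ndim = sizes[region]
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a region with {ndim} axes")
    # Negative indices count from the end of the region, not of the tensor.
    return offsets[region] + (dim % ndim)


# Layout: 2 dynamic axes, 1 sub-batch axis, 1 base axis, e.g. (50, 200, 20, 6).
print(absolute_axis("sub_batch", 0, 2, 1, 1))
print(absolute_axis("dynamic_batch", -1, 2, 1, 1))
```

Because negative indices are resolved inside the region, dim=-1 on the dynamic view can never reach into the sub-batch or base axes.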
Construction surface
Beyond raw-tensor construction, every primitive inherits a small
factory family from PrimitiveTensor:
- <T>.zeros(*batch, dtype=None, device=None): zero-filled wrapper of dynamic shape batch and base T.BASE_SHAPE.
- <T>.ones(*batch, ...), <T>.full(*batch, fill_value=..., ...), <T>.empty(*batch, ...).
- <T>.fill(*components, ...): reshape prod(T.BASE_SHAPE) scalars into the base. SR2.fill overrides this with Mandel-aware 1 / 3 / 6 component overloads (the √2 shear scaling is internal).
- <T>.identity(...) where mathematically meaningful (R2, SR2, WR2, Rot, SSR4’s several projector variants).
Scalar adds the torch-analogue factories:
- Scalar(<float>) / Scalar(<int>) / Scalar(<list>): direct literal coercion, defaults to torch.float64.
- Scalar.zeros, Scalar.ones, Scalar.full: override the PrimitiveTensor defaults to keep float64.
- Scalar.linspace(start, end, steps), Scalar.arange(start, end, step): mirror the torch creation API.
- Scalar.from_value(x, like=other_wrapper): promote a Python literal inheriting dtype / device from an existing wrapper. Useful inside leaf forward to build in-place neutrals.
Every constructor accepts an optional device= / dtype= kwarg; see
Evaluation device for the device story.
See also
- Vectorization: the everyday user-facing view of batching with a timed loop-vs-batched comparison.
- Evaluation device: moving wrappers across devices with .to(device=...).
- Input files: the [Tensors] section that constructs wrappers from HIT.