neml2.models.model

Model — base class for Python-native NEML2 models.

Most subclasses declare a class-level hit = HitSchema(...); the base class derives input_spec / output_spec (variable name → type) from its input(...) / output(...) fields and provides a default from_hit. Dynamic-I/O models may still set input_spec / output_spec directly and override from_hit. Every subclass implements forward().

Unified forward contract:

# Pure forward — returns outputs only.
out = model(*inputs)

# Forward + Jacobian pushforward —
# ``v[in_name][leaf_name]`` is a typed wrapper with leading seed axis K:
# data shape ``(K, *B, *sub, *base_in)``.
# Returns ``(*outputs, v_out_dict)`` where ``v_out_dict[out_name][leaf_name]``
# is the output's typed wrapper with the same leading K. No explicit
# Jacobian block is materialised inside the chain-rule graph.
*vals, v_out = model(*inputs, v={"name": {"leaf": sensitivity_matrix, ...}})

# Forward + first- and second-order Jacobian pushforward (opt-in) —
# ``v2[in_name][seed_a][seed_b]`` is a typed wrapper with two leading seed
# axes ``(N_a, N_b, *B, *sub, *base_in)``.
# Returns ``(*outputs, v_out, v2_out)``. Only models that may appear inside
# a Normality wrap implement v2; callers passing v2 must also pass v.
*vals, v_out, v2_out = model(*inputs, v=..., v2=...)

Variable names in input_spec / output_spec are plain strings with no hierarchical prefix, e.g. "strain", "stress", "plastic_strain".
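The calling convention above can be illustrated with a toy stand-in. This sketch does not import NEML2: the class `ToyScale`, its `factor` argument, and the use of plain floats in place of typed wrappers are assumptions made for illustration, and the leading seed axis K is modelled as a Python list.

```python
class ToyScale:
    """Hypothetical stand-in for a Model with one input and one output."""

    input_spec = {"strain": float}   # key order = positional argument order
    output_spec = {"stress": float}  # key order = return order

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, strain, v=None):
        stress = self.factor * strain
        if v is None:
            return stress  # pure forward: outputs only
        # Pushforward: d(stress)/d(strain) = factor, applied to every seed
        # leaf of the "strain" input. Inputs absent from v are structural zeros.
        v_out = {"stress": {}}
        for leaf, seeds in v.get("strain", {}).items():
            v_out["stress"][leaf] = [self.factor * s for s in seeds]
        return stress, v_out


model = ToyScale(2.0)
out = model(3.0)                                    # outputs only
*vals, v_out = model(3.0, v={"strain": {"strain": [1.0, 0.5]}})
```

The sensitivity dict comes back as the final element, so the `*vals, v_out = ...` unpacking works regardless of how many outputs the model declares.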

neml2.models.model.ChainRuleAction

alias of Callable[[…], TensorWrapper]

neml2.models.model.ChainRuleDict

alias of dict[str, dict[str, TensorWrapper]]

class neml2.models.model.Model(**hit_values)[source]

Bases: Module, ABC

Base class for Python-native NEML2 models with declared variable names.

Subclasses declare input_spec and output_spec as class-level dicts (static models) or instance attributes set in __init__ (models whose output count depends on constructor arguments).

input_spec key order matches forward()’s positional argument order; output_spec key order matches the return-tuple / return-value order.

When called without v, forward returns the outputs directly (a single typed wrapper or a tuple thereof). When called with v, it additionally returns the sensitivity dict as the final element of the tuple.

SECTION: ClassVar[str] = 'Models'

HIT section every registered subclass belongs to. Inherited; subclasses that deliberately live elsewhere (none today) can override.

SUPPORTS_SECOND_ORDER: bool = False

Opt-in flag. True iff this Model’s forward accepts v2 and vh kwargs and propagates them via apply_chain_rule_2(). Defaults to False, since most leaves only implement the first-order chain rule and don’t need v2. A leaf must set this to True if it may appear inside a Normality wrap (directly or transitively through a ComposedModel); Normality’s constructor walks the inner chain and raises if any leaf has this flag unset.

Plain attribute (not ClassVar) so ComposedModel can shadow it on instances based on whether all its children support v2.

apply_chain_rule(v, output_name, actions, *, output=None)[source]

Apply local chain-rule actions and accumulate by seed leaf.

actions maps each input variable name to a function that transforms an incoming sensitivity block for that input into its contribution to output_name. Missing input/leaf sensitivities are structural zeros.

When two actions contribute to the same seed leaf but with different sub-batch structure — typical of a per-crystal output that mixes a global input (e.g. d, w) and a per-crystal input (e.g. dp, e) — the accumulator pads the lower-ndim contribution with singleton axes at the start of its sub-batch region so the sum broadcasts correctly. This mirrors the C++ chain_rule helper’s du_dx_f.intmd_unsqueeze(...) step (src/neml2/tensors/functions/chain_rule.cxx).
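The padding step can be pictured on shapes alone. The sketch below is an assumption-laden illustration, not the accumulator's code: `pad_sub_batch` and its arguments are hypothetical names, and the data layout `(K, *B, *sub, *base)` is taken from the forward contract described at the top of this page.

```python
def pad_sub_batch(shape, sub_start, sub_ndim, target_sub_ndim):
    """Insert singleton axes at the start of the sub-batch region of shape."""
    missing = target_sub_ndim - sub_ndim
    return shape[:sub_start] + (1,) * missing + shape[sub_start:]


# Assumed layout: K=4 seeds, batch B=(10,), base of 6 components.
global_shape = (4, 10, 6)          # contribution with no sub-batch axes
per_crystal_shape = (4, 10, 8, 6)  # contribution with one sub-batch axis (8 crystals)

# The lower-ndim contribution gains a singleton axis right after K and B.
padded = pad_sub_batch(global_shape, sub_start=2, sub_ndim=0, target_sub_ndim=1)
```

After padding, every axis of the global contribution is either equal to the per-crystal axis or has size 1, which is the condition for the sum to broadcast.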

When output (the leaf’s forward result wrapper) is supplied, each accumulated contribution is retagged with output.sub_batch_ndim so the action body never has to declare sub-batch metadata explicitly. This is the foundational-op equivalent of the C++ side encoding sub-batch entirely in data.shape — the leaf does math, the accumulator owns the metadata.

tangents are ordinary typed wrappers with K as the leading batch dim. Seeds that arrive as raw tensors (tests / export seeding) are wrapped as the input variable’s type. Accumulation is plain typed + (align_sub_batch under the hood).

Parameters:
Return type:

dict[str, dict[str, TensorWrapper]]
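The accumulate-by-seed-leaf behaviour described for apply_chain_rule can be sketched without NEML2. Here `accumulate_by_seed` is a hypothetical free function, tangents are plain floats rather than typed wrappers, and sub-batch retagging is left out entirely.

```python
def accumulate_by_seed(v, output_name, actions):
    """Sum each input's contribution to output_name, grouped by seed leaf."""
    acc = {}
    for input_name, action in actions.items():
        # An input with no entry in v is a structural zero: nothing to add.
        for leaf, tangent in v.get(input_name, {}).items():
            contribution = action(tangent)
            acc[leaf] = acc[leaf] + contribution if leaf in acc else contribution
    return {output_name: acc}


# Local Jacobians of z = 3*x + 5*y, expressed as actions on incoming tangents.
actions = {"x": lambda t: 3.0 * t, "y": lambda t: 5.0 * t}

# x and y both depend on seed "p"; only y depends on seed "q".
v = {"x": {"p": 1.0}, "y": {"p": 2.0, "q": 1.0}}
v_out = accumulate_by_seed(v, "z", actions)

# Dropping "y" from v treats its sensitivity as a structural zero.
v_out_x_only = accumulate_by_seed({"x": {"p": 1.0}}, "z", actions)
```

Each action only encodes a local derivative; the summation over inputs that share a seed leaf happens in one place.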

apply_chain_rule_2(v, v2, output_name, actions_1, actions_2, vh=None)[source]

Propagate a second-order JVP through this leaf’s local Jacobian.

Implements $(g \circ f)''[a, b] = g''(f) \cdot (f'[a], f'[b]) + g'(f) \cdot f''[a, b]$:

  • The g'' term iterates input pairs (i, j) and combines incoming first-order tangents v[i][a] (slot 1) and vh[j][b] (slot 2) into a two-leading-axis typed output tangent via actions_2[(i, j)]. When vh is None it defaults to v (the original symmetric all-pairs behaviour). When vh is provided, only (v[i] × vh[j]) pairs are iterated — used by Normality to compute Hessian-applied-to-outer directly without materialising the full Hessian.

  • The g' term applies the existing first-order action to incoming second-order tangents v2[i][a][b]. Re-uses actions_1 so the inner-input → outer-seed contraction matches the first-order path exactly.

Missing input pairs in actions_2 are treated as f''=0; missing entries in v / vh / v2 are structural zeros. The resulting dict carries one outer key (output_name).

Parameters:
Return type:

dict[str, dict[str, dict[str, TensorWrapper]]]
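The second-order identity that apply_chain_rule_2 implements can be checked numerically in the scalar case. This is a sketch under stated assumptions: `second_order_pushforward` is a hypothetical helper, the inner map is f(x) = x², the outer map is g(y) = sin(y), and a, b are scalar seed directions.

```python
import math


def second_order_pushforward(g1, g2, f0, f1_a, f1_b, f2_ab):
    """Return g''(f) * f'[a] * f'[b] + g'(f) * f''[a, b] for scalar maps."""
    return g2(f0) * f1_a * f1_b + g1(f0) * f2_ab


x, a, b = 0.7, 1.5, -2.0

# Inner map f(x) = x**2 with its first- and second-order tangents.
f0 = x * x
f1_a, f1_b = 2 * x * a, 2 * x * b
f2_ab = 2 * a * b

# Outer map g = sin, so g' = cos and g'' = -sin.
result = second_order_pushforward(
    math.cos, lambda y: -math.sin(y), f0, f1_a, f1_b, f2_ab
)

# Closed form: d2/dx2 sin(x**2) = 2*cos(x**2) - 4*x**2*sin(x**2), times a*b.
expected = (2 * math.cos(x * x) - 4 * x * x * math.sin(x * x)) * a * b
```

For a linear outer map the g'' term vanishes and only the g' · f'' term survives, which corresponds to the case where actions_2 has no entry for an input pair.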

call_by_name(state)[source]

Call forward() (pure, no v) with values keyed by variable name.

Accepts typed wrappers (preferred per rule 1) or raw tensors (wrapped via the input_spec for caller convenience). Always returns typed wrappers – consumers never have to re-attach metadata.

Parameters:

state (Mapping[str, TensorWrapper | Tensor])

Return type:

dict[str, TensorWrapper]
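A minimal sketch of the name-to-position mapping that call_by_name implies, assuming a hypothetical free function `call_by_name_sketch` and plain floats rather than typed wrappers. The wrapping of raw tensors via input_spec is omitted.

```python
def call_by_name_sketch(forward, input_spec, output_spec, state):
    """Order named inputs by input_spec, call forward, then re-key the outputs."""
    args = [state[name] for name in input_spec]  # dict order = positional order
    result = forward(*args)
    if not isinstance(result, tuple):  # a single output is returned bare
        result = (result,)
    return dict(zip(output_spec, result))


def forward(strain, modulus):
    return modulus * strain


input_spec = {"strain": float, "modulus": float}
output_spec = {"stress": float}

# The caller may supply the inputs in any order.
named = call_by_name_sketch(
    forward, input_spec, output_spec, {"modulus": 10.0, "strain": 0.5}
)
```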

property consumed_items: frozenset[str]

declare_typed_buffer(name, spec, type_cls, *, factory=None)[source]

Resolve spec as a constant value and register it as a typed buffer.

Buffer-flavored sibling of declare_typed_parameter(). Accepts the same literal / [Tensors]-cross-ref spec shapes (modes 1 and 2) but not the input-promotion modes (3 / 4): a buffer is a constant baked into the model, so promoting it to a chain-rule input would contradict its semantics.

Resolution order:

  1. TensorWrapper / torch.Tensor / float / int — wrap as type_cls and register via register_typed_buffer().

  2. str:

    1. Try parse as a whitespace-separated list of floats — register as a typed buffer (HIT literal).

    2. If a factory is available, try factory.get_tensor(spec) — register as a typed buffer ([Tensors] cross-ref).

Raises ValueError on any string spec that resolves to neither a literal nor a [Tensors] entry.

Parameters:
Return type:

None
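The resolution order for declare_typed_buffer can be sketched as a plain function. Everything here is illustrative: `resolve_constant_spec` is a hypothetical name, a dict stands in for the factory's [Tensors] lookup, and lists of floats stand in for typed buffers.

```python
def resolve_constant_spec(spec, tensors=None):
    """Resolve spec in the documented order: value, float literal, named tensor."""
    if isinstance(spec, (int, float)):
        return [float(spec)]  # step 1: an already-loaded value
    tokens = spec.split()
    if tokens:
        try:
            # step 2.1: whitespace-separated list of floats
            return [float(tok) for tok in tokens]
        except ValueError:
            pass
    if tensors is not None and spec in tensors:
        return tensors[spec]  # step 2.2: cross-reference by name
    raise ValueError(f"{spec!r} is neither a literal nor a [Tensors] entry")


literal = resolve_constant_spec("1.0 2.5")
cross_ref = resolve_constant_spec("E", tensors={"E": [210e3]})
scalar = resolve_constant_spec(3)
```

Because the literal parse is attempted first, a string that happens to be numeric is never looked up as a tensor name.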

declare_typed_parameter(name, spec, type_cls, *, factory=None, allow_nonlinear=False)[source]

Resolve spec and register it as a parameter or promote it to an input.

Python mirror of C++ ParameterStore::declare_parameter (src/neml2/models/ParameterStore.cxx). Resolution order:

  1. TensorWrapper / torch.Tensor / float / int — wrap as type_cls and call register_typed_parameter() (mode 1 with an already-loaded literal/batched value).

  2. str:

    1. Try parse as a whitespace-separated list of floats — register as a typed parameter (mode 1, literal HIT value).

    2. If a factory is available, try factory.get_tensor(spec) — register as a typed parameter (mode 2, [Tensors] cross-ref).

    3. If allow_nonlinear: parse the string as a variable specifier (model_name / model_name.var / var). If it matches a [Models] entry, pull the provider and record the input promotion + provider in _nl_params (mode 3). Otherwise treat the string as a bare variable name and add the input without a provider (mode 4).

The host model’s input_spec is extended in modes 3 + 4 with an entry keyed by the chosen input variable name (the provider’s output name in mode 3, or the bare variable name in mode 4), appended after the fixed structural inputs. Inside forward — declared as forward(self, <structural inputs...>, *nl_params) — fetch the value with _get_param(), which resolves a static slot from self or a promoted slot from the *nl_params pack uniformly.

Parameters:
Return type:

None
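The `forward(self, <structural inputs...>, *nl_params)` pattern can be sketched with a toy host. This is not the NEML2 implementation: `ToyHost`, its `static` and `promoted` dicts, and the public `get_param` method are hypothetical stand-ins for the bookkeeping the text attributes to `_nl_params` and `_get_param()`.

```python
class ToyHost:
    """Hypothetical host: 'strain' is structural, 'E' is static or promoted."""

    def __init__(self, E=None):
        self.input_spec = {"strain": float}
        self.static = {}
        self.promoted = {}  # parameter name -> slot in the *nl_params pack
        if E is None:
            # Promotion (modes 3/4): append an input after the structural ones.
            self.promoted["E"] = len(self.promoted)
            self.input_spec["E"] = float
        else:
            self.static["E"] = E  # modes 1/2: value stored on the model

    def get_param(self, name, nl_params):
        # Uniform lookup: promoted slot from the pack, else the static value.
        if name in self.promoted:
            return nl_params[self.promoted[name]]
        return self.static[name]

    def forward(self, strain, *nl_params):
        return self.get_param("E", nl_params) * strain


static_host = ToyHost(E=200.0)
promoted_host = ToyHost()
```

The body of `forward` is identical in both cases; only the call site differs, since the promoted host receives `E` as a trailing positional argument.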

classmethod from_hit(node, factory)[source]

Construct this model from its declarative HitSchema.

Every schema field — options, dependencies, input/output renames, var_inputs lists, and parameters — flows through _store_schema_values during construction; a leaf whose only state is its schema needs no __init__ at all. Models with dynamic I/O or non-trivial construction logic may still override this method.

Parameters:
  • node (nmhit.Node)

  • factory (_NativeInputFile)

Return type:

Any

hit: HitSchema

input_spec: dict[str, type[TensorWrapper]]

output_priorities: dict[str, str | None] = {}

HIT-bound output name → priority claim ("high" / "low" / None) sourced from each output() field’s priority= kwarg. The DependencyResolver reads this to lift the duplicate-provider error when sibling models provide the same name with disambiguating priorities, and to add low < default < high ordering edges so the highest-priority writer runs last. Names absent from the dict default to None.
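The effect of the priority claims on execution order can be sketched as a sort. This simplifies the DependencyResolver, which adds graph edges rather than sorting a list; `RANK` and `writer_order` are hypothetical names, and treating None as a middle "default" tier between "low" and "high" is an assumption.

```python
# Assumed ranking: None acts as the default tier between "low" and "high".
RANK = {"low": 0, None: 1, "high": 2}


def writer_order(claims):
    """Return model names so that higher-priority writers run later."""
    return sorted(claims, key=lambda model: RANK[claims[model]])


# Three sibling models that all provide the same output name.
claims = {"override": "high", "baseline": "low", "plain": None}
order = writer_order(claims)
```

The last entry in `order` is the writer whose value survives, which matches the statement that the highest-priority writer runs last.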

output_spec: dict[str, type[TensorWrapper]]

propagate_tangents(v, output_name, actions_1, *, output=None, v2=None, actions_2=None, vh=None)[source]

Dispatch v / v2 / vh through the local chain-rule actions.

Wraps the boilerplate every second-order-aware leaf otherwise has to spell out: call apply_chain_rule() for v (always), apply_chain_rule_2() for v2 / vh (when requested), and return the right-length tuple. The return shape mirrors what the leaf was asked for:

  • v2 is None and vh is None → (v_out,)

  • v2 is set, vh is None → (v_out, v2_out)

  • vh is set (v2 may be None, treated as {}) → (v_out, v2_out, vh_out)

Linear leaves (LinearCombination, YieldFunction, …) call this with no actions_2 — the second-order pass collapses to applying actions_1 to v2 entries (g'' = 0). Non-linear leaves (SR2Invariant, …) pass an explicit actions_2 map.

Usage:

return out, *self.propagate_tangents(
    v, self._to, actions_1, output=out, v2=v2, vh=vh
)

Parameters:
Return type:

tuple
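The three return shapes listed for propagate_tangents can be sketched as a small dispatch. `pack_tangents` is a hypothetical helper that only models the tuple length; the actual chain-rule work is assumed to have produced `v_out`, `v2_out` and `vh_out` already.

```python
def pack_tangents(v_out, v2_out, vh_out, *, v2=None, vh=None):
    """Return only the tangent dicts the caller asked for."""
    if vh is not None:
        return (v_out, v2_out, vh_out)  # vh requested (v2 may be None)
    if v2 is not None:
        return (v_out, v2_out)          # second order requested, no vh
    return (v_out,)                     # first order only


first_only = pack_tangents({"a": 1}, {}, {})
with_v2 = pack_tangents({"a": 1}, {"b": 2}, {}, v2={})
with_vh = pack_tangents({"a": 1}, {"b": 2}, {"c": 3}, vh={})

# Star-unpacking at the call site keeps the leaf's return statement uniform.
result = ("out", *with_v2)
```

Note that an empty dict still counts as "set" here because the checks use `is not None`, mirroring the wording of the list above.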

property provided_items: frozenset[str]

register_typed_buffer(name, value, persistent=True)[source]

Register a typed tensor buffer (no autograd; baked as a constant by AOTI export).

Parameters:
Return type:

None

register_typed_parameter(name, value)[source]

Register a typed tensor as a calibration-tracked nn.Parameter.

Mirrors register_typed_buffer() but stores via nn.Module.register_parameter(), so the value appears in model.parameters() and PyTorch autograd flows through it in eager mode. AOTI export converts these back to constants before tracing (see aoti_export._freeze_parameters_to_buffers); the forward-only AOTI graph is unchanged.

Parameters:
Return type:

None

class neml2.models.model.NLParam(input_name, tail_index, provider=None, provider_output=None)[source]

Bases: object

Marker for a parameter resolved to a runtime input (modes 3 + 4).

Records the input variable name added to the host’s input_spec, the parameter’s position within forward’s *nl_params pack, and — for mode 3 — the provider model + its output variable name so the parent ComposedModel can auto-pull the provider into the dependency graph (mirroring the C++ _nl_params bookkeeping in ParameterStore.cxx::resolve_tensor_name).

tail_index is the zero-based slot of this parameter inside the *nl_params pack passed to Model._get_param(). Promoted parameters are appended to input_spec in declaration order, immediately after the fixed structural inputs, so this index is simply the number of parameters already promoted when this one was declared.

For mode 4 (no provider — pure input promotion), provider is None.
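The tail_index bookkeeping can be sketched with a reduced record. `NLParamSketch` and `promote` are hypothetical names, the provider fields are left out, and `float` stands in for a typed wrapper class.

```python
from dataclasses import dataclass


@dataclass
class NLParamSketch:
    """Reduced stand-in for NLParam: name plus slot in the *nl_params pack."""
    input_name: str
    tail_index: int


def promote(input_spec, promoted, input_name, type_cls):
    """Append a promoted input and record its zero-based slot."""
    marker = NLParamSketch(input_name, tail_index=len(promoted))
    input_spec[input_name] = type_cls  # lands after everything already declared
    promoted.append(marker)
    return marker


input_spec = {"strain": float, "temperature": float}  # fixed structural inputs
promoted = []
first = promote(input_spec, promoted, "yield_stress", float)
second = promote(input_spec, promoted, "hardening_modulus", float)

n_structural = len(input_spec) - len(promoted)
```

A promoted parameter's position in input_spec is therefore `n_structural + tail_index`, while its position in the `*nl_params` pack is `tail_index` alone.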

Parameters:
  • input_name (str)

  • tail_index (int)

  • provider (Model | None)

  • provider_output (str | None)

input_name: str
provider: Model | None = None
provider_output: str | None = None
tail_index: int

neml2.models.model.SecondOrderChainRuleAction

alias of Callable[[…], TensorWrapper]

neml2.models.model.SecondOrderChainRuleDict

alias of dict[str, dict[str, dict[str, TensorWrapper]]]

neml2.models.model.TangentAction

alias of TensorWrapper

neml2.models.model.register_submodule(parent, child, fallback, *, used=None)[source]

Add child to parent under its HIT block name if available.

The factory stamps _hit_name on every object it constructs; preferring that name over an opaque attribute slot keeps named_parameters() readable (elasticity.E instead of _residual_model.E). Falls back to fallback when the HIT name is missing (direct Python construction), is not a valid Python identifier, would collide with an existing attribute on parent, or is already in used (when a parent registers several children in one pass and must avoid collisions across siblings).

Returns the attribute name the child was registered under.

Parameters:
Return type:

str
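The fallback conditions for register_submodule can be sketched as a name-selection function. `choose_attr_name` is a hypothetical helper that only picks the name; it does not register anything, and `SimpleNamespace` objects stand in for the parent module and its children.

```python
from types import SimpleNamespace


def choose_attr_name(parent, child, fallback, used=None):
    """Prefer the child's HIT block name; otherwise use fallback."""
    name = getattr(child, "_hit_name", None)
    if (
        name is None                               # built directly in Python
        or not name.isidentifier()                 # not a valid attribute name
        or hasattr(parent, name)                   # collides with the parent
        or (used is not None and name in used)     # collides with a sibling
    ):
        return fallback
    return name


parent = SimpleNamespace(existing=1)
named = choose_attr_name(parent, SimpleNamespace(_hit_name="elasticity"), "_m0")
invalid = choose_attr_name(parent, SimpleNamespace(_hit_name="my-model"), "_m1")
clash = choose_attr_name(parent, SimpleNamespace(_hit_name="existing"), "_m2")
unnamed = choose_attr_name(parent, SimpleNamespace(), "_m3")
sibling = choose_attr_name(
    parent, SimpleNamespace(_hit_name="elasticity"), "_m4", used={"elasticity"}
)
```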