neml2.models.chain_rule

Shared type aliases for chain-rule sensitivity propagation.

Tangents passed through the framework are ordinary typed wrappers (Scalar / SR2 / R2 / …) carrying one or more leading K (seed-direction) axes. The tangent of a T is itself a T with k_ndim > 0, whose k_state / k_pairing record the storage layout of each K axis (see neml2.types._base.TensorWrapper).

Per the v2-parity refactor, the declarative EdgeInfo / list_deriv machinery has been removed. The chain rule no longer dispatches on labels. Instead it relies on positional k_pairing metadata to tell whether a K axis is still paired with a sub_batch axis (a cheap broadcast diagonal) or has been exposed to per-site enumeration.

When a leaf’s action would contract a paired sub_batch axis (e.g. a sum / mean / inner that mixes per-site data), it must call fullify from neml2.types.functions first. The exposing reductions (sum_sub_batch, sum / mean over a K-paired axis) handle the common case automatically.
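The reason for calling fullify before a contracting action can be illustrated with plain arrays. The sketch below is one reading of the paragraph above, not the library's implementation: it uses NumPy instead of TensorWrapper, assumes a single K axis paired with a single sub_batch axis of N sites, and the helper fullify_sketch is hypothetical.

```python
import numpy as np


def fullify_sketch(paired):
    """Hypothetical stand-in: expand a paired (diagonal) tangent of shape (N,)
    into full storage of shape (K, N) with K == N."""
    return np.diag(paired)


# Per-site response y_i = x_i**2, so dy_i/dx_k = 2 * x_i when i == k, else 0.
x = np.array([0.5, 1.0, 2.0])
dy_dx_paired = 2.0 * x  # paired storage: only the diagonal is kept, shape (N,)

# Contracting the paired axis directly mixes the seed directions together.
naive = dy_dx_paired.sum()  # a single number, seed information is lost

# Exposing the K axis first keeps one entry per seed direction.
dy_dx_full = fullify_sketch(dy_dx_paired)  # shape (K, N)
tangent_of_sum = dy_dx_full.sum(axis=1)    # d(sum_i y_i)/dx_k, shape (K,)
```

Here tangent_of_sum equals 2 * x, which is the analytic derivative of the summed response with respect to each site value, whereas the naive contraction collapses those directions into one scalar.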

neml2.models.chain_rule.ChainRuleAction

alias of Callable[[…], TensorWrapper]

neml2.models.chain_rule.ChainRuleDict

alias of dict[str, dict[str, TensorWrapper]]

neml2.models.chain_rule.SecondOrderChainRuleAction

alias of Callable[[…], TensorWrapper]

neml2.models.chain_rule.SecondOrderChainRuleDict

alias of dict[str, dict[str, dict[str, TensorWrapper]]]

neml2.models.chain_rule.SecondOrderTangentAction

alias of TensorWrapper

neml2.models.chain_rule.TangentAction

alias of TensorWrapper

neml2.models.chain_rule.equalize_tangent_K(contributions)

Tile compact-K contributions to match the max-K contribution.

Apply-chain-rule accumulates contributions for the same seed leaf across multiple input edges. Different edges can produce different K storage widths: one edge may carry a tangent whose K axes have been exposed (full) by a reducing op, while a parallel edge with no sub_batch interaction carries an un-exposed (broadcast) tangent with the same k_ndim but smaller storage.

Summing tensors of different storage K widths along the leading K axes is shape-incompatible. The fix is to tile the lower-K contributions to the common max along each K axis. A compact contribution represents a response that does not depend on the per-site index (definitionally – it came from a chain with no sub_batch interaction), so tiling preserves semantics.

Returns a list of contributions that all share the same leading K storage shape. This is a no-op when every contribution already matches the max, which is the common case once align_k and combine_k_state have aligned them.
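The tiling step described above can be sketched with plain arrays. This is an illustration under stated assumptions rather than the function's actual code: contributions are NumPy arrays instead of TensorWrapper objects, k_ndim is passed explicitly, and a compact contribution is assumed to store a width along each K axis that divides the max width (width 1 in the example). The name equalize_leading_k_sketch is hypothetical.

```python
import numpy as np


def equalize_leading_k_sketch(contributions, k_ndim):
    """Tile the leading k_ndim axes of each array up to the per-axis maximum."""
    # Common target: the largest storage width along each leading K axis.
    target_k = tuple(
        max(c.shape[axis] for c in contributions) for axis in range(k_ndim)
    )
    equalized = []
    for c in contributions:
        current_k = c.shape[:k_ndim]
        if current_k == target_k:
            equalized.append(c)  # already at the max: returned unchanged
            continue
        if any(t % s != 0 for t, s in zip(target_k, current_k)):
            raise ValueError(f"cannot tile K shape {current_k} to {target_k}")
        # Repeat along the K axes only; trailing axes are left as they are.
        reps = tuple(t // s for t, s in zip(target_k, current_k))
        reps += (1,) * (c.ndim - k_ndim)
        equalized.append(np.tile(c, reps))
    return equalized


full = np.arange(6.0).reshape(3, 2)  # exposed: K storage width 3
compact = np.array([[10.0, 20.0]])   # compact: K storage width 1
a, b = equalize_leading_k_sketch([full, compact], k_ndim=1)
total = a + b                        # shapes now agree along the K axis
```

Because the compact contribution does not depend on the per-site index, repeating it along the K axis leaves the accumulated sum unchanged in meaning. Plain arrays would broadcast a width-1 axis implicitly; the sketch materializes the tiling to mirror the explicit storage widths described above.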

Parameters:

contributions (list[TensorWrapper])

Return type:

list[TensorWrapper]

neml2.models.chain_rule.matvec(M, v)

Fused matrix-vector product without einsum.

M @ v where the trailing dims are (..., n_row, n_col) and (..., n_col) respectively. Implemented as (M @ v.unsqueeze(-1)).squeeze(-1) so Inductor can fuse the trailing pointwise ops around the matmul.
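The shape handling can be checked against an einsum reference. The sketch below uses NumPy as a stand-in for the tensor library so that it runs without extra dependencies: np.expand_dims and np.squeeze play the roles of unsqueeze and squeeze, and the name matvec_sketch is hypothetical. It demonstrates numerical equivalence only and says nothing about compiler fusion.

```python
import numpy as np


def matvec_sketch(M, v):
    """Batched M @ v for M of shape (..., n_row, n_col) and v of shape (..., n_col)."""
    # Promote v to a column (..., n_col, 1), multiply, then drop the unit axis.
    return np.squeeze(M @ np.expand_dims(v, -1), axis=-1)


rng = np.random.default_rng(0)
M = rng.standard_normal((5, 3, 4))  # batch of 5 matrices, n_row=3, n_col=4
v = rng.standard_normal((5, 4))     # batch of 5 vectors, n_col=4

out = matvec_sketch(M, v)                        # shape (5, 3)
reference = np.einsum("...ij,...j->...i", M, v)  # the einsum form being avoided
```

The extra unit axis matters: without it, a batched v of shape (5, 4) would be interpreted as a single 5-by-4 matrix rather than five vectors.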

Parameters:

M (Tensor)

v (Tensor)

Return type:

Tensor