neml2.models.chain_rule
Shared type aliases for chain-rule sensitivity propagation.
Tangents in transit through the framework are ordinary typed wrappers
(Scalar / SR2 / R2 / …) carrying one or more leading K
(seed-direction) axes – a tangent of type T is a T, with
k_ndim > 0 and k_state / k_pairing recording the per-K-axis
storage layout (see neml2.types._base.TensorWrapper).
Per the v2-parity refactor, the declarative EdgeInfo / list_deriv
machinery has been removed. The chain rule no longer dispatches on
labels – it relies on positional k_pairing metadata to know when a
K axis is paired with a sub_batch axis (cheap broadcast diagonal) vs
when it has been exposed to per-site enumeration.
When a leaf’s action would contract a paired sub_batch axis (e.g. a
sum / mean / inner that mixes per-site data), it must call
fullify from neml2.types.functions first. The exposing
reductions (sum_sub_batch, sum / mean over a K-paired axis)
handle the common case automatically.
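The difference between a paired (diagonal) K axis and an exposed one can be illustrated with plain Python lists. This is only a sketch of one reading of the text above, not the neml2 implementation: `expand_diagonal` is a hypothetical stand-in for what `fullify` is described as doing, and `sum_over_sites` is a hypothetical reduction. It assumes one seed direction per site, so the compact form stores just the diagonal entries.

```python
from typing import List


def expand_diagonal(per_site: List[float]) -> List[List[float]]:
    """Hypothetical stand-in for 'fullify': expand a compact diagonal
    into a dense table indexed as [seed direction k][site i]."""
    n = len(per_site)
    return [[per_site[i] if k == i else 0.0 for i in range(n)] for k in range(n)]


def sum_over_sites(full: List[List[float]]) -> List[float]:
    """Contract the site axis, keeping one entry per seed direction."""
    return [sum(row) for row in full]


# Compact (paired) storage: entry i is d(out_i)/d(seed_i).
d = [2.0, 3.0, 5.0]

# Exposing the K axis first, then reducing over sites, keeps one
# sensitivity per seed direction.
dense = expand_diagonal(d)
reduced = sum_over_sites(dense)

# Reducing the compact storage directly would instead mix the seed
# directions into a single number, which is a different quantity.
naive = sum(d)
```

Under these assumptions, `reduced` keeps three entries (one per seed direction), whereas `naive` collapses them into a single value, which is why a contracting action has to expose the paired axis before it reduces.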
- neml2.models.chain_rule.ChainRuleAction
alias of
Callable[[…],TensorWrapper]
- neml2.models.chain_rule.SecondOrderChainRuleAction
alias of
Callable[[…],TensorWrapper]
- neml2.models.chain_rule.SecondOrderChainRuleDict
- neml2.models.chain_rule.SecondOrderTangentAction
alias of
TensorWrapper
- neml2.models.chain_rule.TangentAction
alias of
TensorWrapper
- neml2.models.chain_rule.equalize_tangent_K(contributions)
Tile compact-K contributions to match the max-K contribution.
Apply-chain-rule accumulates contributions for the same seed leaf across multiple input edges. Different edges can produce different K storage widths – one edge may carry a tangent whose K axes have been exposed (full) by a reducing op, while a parallel edge with no sub_batch interaction carries an un-exposed (broadcast) tangent with the same k_ndim but smaller storage.
Summing tensors of different storage K widths along the leading K axes is shape-incompatible. The fix is to tile the lower-K contributions to the common max along each K axis. A compact contribution represents a response that does not depend on the per-site index (definitionally – it came from a chain with no sub_batch interaction), so tiling preserves semantics.
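The tiling step can be sketched in plain Python, with nested lists standing in for tensors. This is an illustration under simplifying assumptions, not the library's code: `tile_to_common_k` and `accumulate` are hypothetical names, there is a single leading K axis, and a compact contribution is assumed to have storage width 1 along it.

```python
from typing import List, Sequence

Rows = List[List[float]]  # leading axis plays the role of K


def tile_to_common_k(contributions: Sequence[Rows]) -> List[Rows]:
    """Tile width-1 contributions along the leading axis up to the max width."""
    k_max = max(len(c) for c in contributions)
    tiled = []
    for c in contributions:
        if len(c) == k_max:
            # Already full width: copy unchanged.
            tiled.append([list(row) for row in c])
        elif len(c) == 1:
            # Compact: the single row does not depend on the K index,
            # so repeating it preserves its meaning.
            tiled.append([list(c[0]) for _ in range(k_max)])
        else:
            raise ValueError("width is neither 1 nor the common maximum")
    return tiled


def accumulate(contributions: Sequence[Rows]) -> Rows:
    """Sum contributions elementwise after equalizing their leading width."""
    tiled = tile_to_common_k(contributions)
    k, n = len(tiled[0]), len(tiled[0][0])
    return [[sum(t[i][j] for t in tiled) for j in range(n)] for i in range(k)]


full = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # exposed, K width 3
compact = [[10.0, 20.0]]                      # un-exposed, K width 1
total = accumulate([full, compact])
```

Here the compact row is added to every K slice of the full contribution. When all inputs already share the maximum width, the tiling step only copies them.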
Returns a list of contributions all with the same leading K storage shape. No-op when every contribution already matches the max (the common case once align_k + combine_k_state have already aligned).
- Parameters:
contributions (list[TensorWrapper])
- Return type: