Tensor types

NEML2 tensors are extremely similar to the tensor type provided by its default backend, ATen. In fact, all NEML2 tensors inherit from at::Tensor, the tensor type from ATen. What sets NEML2 apart is that its tensor library provides the following enhancements:

  • Explicit distinction between batch and base dimensions
  • Primitive tensor types commonly used in traditional scientific computing
  • Commonly used math operators and functions for primitive tensor types

Batching

Given a general, \(n\)-dimensional tensor (or array), its shape is usually denoted by a tuple consisting of the sizes of each of its dimensions, i.e.

\[
(d_0, d_1, d_2, \dots, d_i, \dots, d_{n-2}, d_{n-1}), \quad i = 0, 1, 2, \dots, n-1,
\]

where \(d_i\) is the number of components along dimension \(i\).

In NEML2, we explicitly introduce the concept of batching, such that the shape of an \(n\)-dimensional tensor may now be denoted (with a slight abuse of notation) as

\[
(c_0, c_1, c_2, \dots, c_i, \dots, c_{n_b-2}, c_{n_b-1}; d_0, d_1, d_2, \dots, d_j, \dots, d_{n-n_b-2}, d_{n-n_b-1}),
\]
\[
i = 0, 1, 2, \dots, n_b-1, \quad j = 0, 1, 2, \dots, n-n_b-1,
\]

where \(n_b \in [0, n]\) is the number of batch dimensions, \(c_i\) is the number of components along batch dimension \(i\), and \(d_j\) is the number of components along base dimension \(j\). In NEML2's notation, a separator ; is used to delimit batch and base sizes. Every batched tensor must have one and only one ; delimiter.

For example, given a tensor of shape

\[
(3, 100, 5, 13, 2),
\]

we could choose \(n_b = 3\) batch dimensions, and the resulting shape of the batched tensor becomes

\[
(3, 100, 5; 13, 2).
\]

Similarly, if the number of batch dimensions is 0 or 5, the resulting shapes would become

\[
(; 3, 100, 5, 13, 2), \qquad (3, 100, 5, 13, 2;).
\]

It is valid for the ; delimiter to appear at the beginning or the end of the shape tuple. If it appears at the beginning, the tensor has no batch dimensions; if it appears at the end, the tensor has no base dimensions, i.e., it is a batched scalar.
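To make the notation concrete, here is a minimal sketch of the example above in code. It assumes the neml2::Tensor constructor that wraps a raw at::Tensor together with a batch dimension count, and the batch_sizes()/base_sizes() accessors; tensor creation is covered in more detail in the next tutorial.

```cpp
#include "neml2/tensors/Tensor.h"

using namespace neml2;

int
main()
{
  // A raw ATen tensor of shape (3, 100, 5, 13, 2) -- no batch/base distinction yet
  const auto raw = at::rand({3, 100, 5, 13, 2});

  // Wrap it as a NEML2 tensor with 3 batch dimensions, i.e. shape (3, 100, 5; 13, 2)
  const auto a = Tensor(raw, 3);
  // a.batch_sizes() -> (3, 100, 5)
  // a.base_sizes()  -> (13, 2)

  // The two limiting cases are also valid:
  const auto b = Tensor(raw, 0); // shape (; 3, 100, 5, 13, 2), no batch dimensions
  const auto c = Tensor(raw, 5); // shape (3, 100, 5, 13, 2; ), a batched scalar

  return 0;
}
```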

Why batching?

The definition of batching is fairly straightforward. However, a fair question to ask is: why add another layer of abstraction on top of the regular tensor shape definition? Let's briefly consider the high-level motivation behind this design choice.

Shape ambiguity: Consider a tensor of shape (55,100,3,3,3,3). It can be interpreted as

  • a scalar with batch shape (55,100,3,3,3,3),
  • a 3-vector with batch shape (55,100,3,3,3),
  • a second order tensor with batch shape (55,100,3,3),
  • a third order tensor with batch shape (55,100,3), or
  • a fourth order tensor with batch shape (55,100).

Such ambiguity can be avoided if the user consistently keeps track of batching throughout the lifetime of all related tensor operations, which is manageable in simple models but difficult to scale as models grow in complexity. Therefore, in NEML2, we incorporate the management and propagation of this information directly into the tensor library, removing such overhead from user implementations.
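Under the same assumptions as the earlier sketch, resolving the ambiguity amounts to choosing the batch dimension count once, when the raw data is wrapped:

```cpp
#include "neml2/tensors/Tensor.h"

using namespace neml2;

int
main()
{
  // The same raw data of shape (55, 100, 3, 3, 3, 3)...
  const auto raw = at::rand({55, 100, 3, 3, 3, 3});

  // ...is disambiguated once the batch dimension count is fixed:
  const auto a = Tensor(raw, 6); // a scalar with batch shape (55, 100, 3, 3, 3, 3)
  const auto b = Tensor(raw, 4); // a second order tensor with batch shape (55, 100, 3, 3)
  const auto c = Tensor(raw, 2); // a fourth order tensor with batch shape (55, 100)

  return 0;
}
```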

Scalar broadcasting: Explicit batching enables general broadcasting rules between Scalar and the other primitive tensor types; the details are covered in a later tutorial on broadcasting.
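As a preview, the following minimal sketch illustrates the idea, assuming creation helpers like SR2::zeros and Scalar::full (see the following tutorials for the actual creation API):

```cpp
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/SR2.h"

using namespace neml2;

int
main()
{
  // A batch of 5 symmetric rank-2 tensors, shape (5; 6)
  const auto stress = SR2::zeros({5});

  // An unbatched scalar, shape (;)
  const auto two = Scalar::full(2.0);

  // Explicit batching lets the scalar broadcast against both the batch and the
  // base dimensions of the SR2, yielding another SR2 of shape (5; 6)
  const auto scaled = two * stress;

  return 0;
}
```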

Shape generalization: As mentioned in the tutorial on JIT compilation, NEML2 supports tracing tensor operations into graphs. The PyTorch tracing engine does not support tracing operations on tensor shapes, and so users are left with two options:

  • Trace the graph with example input variables having the same shape as real input variables to be used later on.
  • Re-trace the graph whenever any input variable has a different shape.

The first option implies a potential memory bottleneck at the JIT compilation stage, and the same-shape requirement is difficult to guarantee in practice. The second option imposes a significant runtime overhead whenever re-tracing is necessary.

With explicit batching, NEML2 is able to generate batch-generalizable graphs. As long as the number of batch dimensions of the input variables remains unchanged, the same traced graph can be reused without re-tracing, regardless of the batch sizes.

Tensor types

neml2::Tensor is the general-purpose tensor type whose base shape can be modified at runtime.

In addition, NEML2 offers a rich collection of primitive tensor types whose base shapes are static, i.e., remain fixed at runtime. Currently implemented tensor types are summarized in the following table.

Tensor type | Base shape        | Description
------------|-------------------|------------------------------------------------------------------
Tensor      | Dynamic           | General-purpose tensor with dynamic base shape
Scalar      | ()                | Rank-0 tensor, i.e. scalar
Vec         | (3)               | Rank-1 tensor, i.e. vector
R2          | (3,3)             | Rank-2 tensor
SR2         | (6)               | Symmetric rank-2 tensor
WR2         | (3)               | Skew-symmetric rank-2 tensor
R3          | (3,3,3)           | Rank-3 tensor
SFR3        | (6,3)             | Rank-3 tensor with symmetry on base dimensions 0 and 1
R4          | (3,3,3,3)         | Rank-4 tensor
SFR4        | (6,3,3)           | Rank-4 tensor with symmetry on base dimensions 0 and 1
WFR4        | (3,3,3)           | Rank-4 tensor with skew symmetry on base dimensions 0 and 1
SSR4        | (6,6)             | Rank-4 tensor with minor symmetry
SWR4        | (6,3)             | Rank-4 tensor with minor symmetry then skew symmetry
WSR4        | (3,6)             | Rank-4 tensor with skew symmetry then minor symmetry
WWR4        | (3,3)             | Rank-4 tensor with skew symmetry
R5          | (3,3,3,3,3)       | Rank-5 tensor
SSFR5       | (6,6,3)           | Rank-5 tensor with minor symmetry on base dimensions 0-3
R8          | (3,3,3,3,3,3,3,3) | Rank-8 tensor
SSSSR8      | (6,6,6,6)         | Rank-8 tensor with minor symmetry
Rot         | (3)               | Rotation tensor represented in the Rodrigues form
Quaternion  | (4)               | Quaternion
MillerIndex | (3)               | Crystal direction or lattice plane represented as Miller indices
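The difference between a dynamic and a static base shape can be illustrated with a short sketch (under the same assumptions as the earlier sketches):

```cpp
#include "neml2/tensors/Tensor.h"
#include "neml2/tensors/SR2.h"

using namespace neml2;

int
main()
{
  // The general-purpose Tensor can take on any base shape at runtime
  const auto a = Tensor(at::rand({5, 3, 3}), 1);  // shape (5; 3, 3)
  const auto b = Tensor(at::rand({5, 13, 2}), 1); // shape (5; 13, 2)

  // A primitive type like SR2 always has the static base shape (6),
  // regardless of its batch shape
  const auto s = SR2::zeros({5, 10}); // shape (5, 10; 6)

  return 0;
}
```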

All primitive tensor types can be declared as variables, parameters, and buffers in a model.
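As a rough illustration, the sketch below shows what such declarations might look like. It assumes the declaration helpers declare_input_variable, declare_output_variable, and declare_parameter from NEML2's model interface; the model itself is hypothetical, and the registration and option-definition boilerplate, covered in the model development tutorials, is elided:

```cpp
#include "neml2/models/Model.h"
#include "neml2/tensors/SR2.h"
#include "neml2/tensors/Scalar.h"

namespace neml2
{
// A hypothetical model with an SR2 input variable, an SR2 output variable,
// and a Scalar parameter (registration, options, and forward evaluation omitted)
class MyModel : public Model
{
public:
  MyModel(const OptionSet & options)
    : Model(options),
      _strain(declare_input_variable<SR2>(VariableName("state", "strain"))),
      _stress(declare_output_variable<SR2>(VariableName("state", "stress"))),
      _mu(declare_parameter<Scalar>("mu", Scalar::full(78000.0)))
  {
  }

protected:
  const Variable<SR2> & _strain; // input variable of type SR2
  Variable<SR2> & _stress;       // output variable of type SR2
  const Scalar & _mu;            // parameter of type Scalar
};
}
```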
