NEML2 2.0.0
NEML2 tensors are extremely similar to the tensor type provided by its default backend, ATen. In fact, all NEML2 tensors inherit from at::Tensor, the tensor type from ATen. What's different from ATen is that the NEML2 tensor library provides two key enhancements: an explicit notion of batching, and a collection of primitive tensor types with statically-known base shapes. Both are described below.
Given a general, n-dimensional tensor (or multi-dimensional array) of shape

(d_1, d_2, ..., d_n),

where d_i denotes the size of the i-th dimension, the shape alone gives no indication of which dimensions carry structural meaning (e.g., the components of a vector or matrix) and which merely enumerate independent copies of the data.

In NEML2, we explicitly introduce the concept of batching such that the shape of a batched tensor is written as

(b_1, b_2, ..., b_B; d_1, d_2, ..., d_D),

where the semicolon separates the B leading batch dimensions from the D trailing base dimensions.

For example, given a tensor of shape (2, 3, 4, 5, 6), we could assign it a batch dimension of 3, and the resulting shape of the batched tensor becomes (2, 3, 4; 5, 6). Similarly, if the batch dimension is 0 or 5, the resulting shape would become (; 2, 3, 4, 5, 6) or (2, 3, 4, 5, 6;), respectively. It is valid for the batch shape, the base shape, or both to be empty.
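The bookkeeping above can be sketched in a few lines of plain Python. This is a conceptual illustration only, not the NEML2 API; the helper names are made up:

```python
def split_shape(shape, batch_dim):
    """Split a full tensor shape into (batch shape, base shape),
    given the number of leading batch dimensions."""
    assert 0 <= batch_dim <= len(shape)
    return tuple(shape[:batch_dim]), tuple(shape[batch_dim:])

def notation(shape, batch_dim):
    """Render the (batch; base) shape notation used in this tutorial."""
    batch, base = split_shape(shape, batch_dim)
    fmt = lambda dims: ", ".join(str(d) for d in dims)
    s = fmt(batch) + ";"
    if base:
        s += " " + fmt(base)
    return "(" + s + ")"

print(notation((2, 3, 4, 5, 6), 3))  # (2, 3, 4; 5, 6)
print(notation((2, 3, 4, 5, 6), 0))  # (; 2, 3, 4, 5, 6)
print(notation((2, 3, 4, 5, 6), 5))  # (2, 3, 4, 5, 6;)
```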
The definition of batching is fairly straightforward. However, a fair question to ask is: why introduce another layer of abstraction on top of the regular tensor shape definition? Let's briefly consider the high-level motivation behind this design choice.
Shape ambiguity: Consider a tensor of shape (3, 3). On its own, this shape is ambiguous: it could represent a single rank-2 tensor (a 3×3 matrix), a batch of 3 vectors each with 3 components, or even a 3×3 batch of scalars.
Such ambiguity can be avoided if the user consistently keeps track of batching throughout the lifetime of all related tensor operations, which is manageable in simple models but difficult to scale if the model grows in complexity. Therefore, in NEML2, we incorporate the management and propagation of this information directly in the tensor library, removing such overhead from user implementation.
Scalar broadcasting: With more details to be covered in a later tutorial on broadcasting, explicit batching allows general broadcasting rules between Scalar and the other primitive tensor types.
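A minimal sketch of such a batch-aware broadcasting rule, in plain Python rather than the actual NEML2 implementation (the function names and the exact rule shown are illustrative assumptions): a scalar, i.e., a tensor with empty base shape, broadcasts against any base shape, while batch shapes broadcast following the usual NumPy-style right-aligned rules.

```python
from itertools import zip_longest

def broadcast_batch(a, b):
    """NumPy-style right-aligned broadcast of two batch shapes."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != 1 and y != 1 and x != y:
            raise ValueError(f"incompatible batch shapes {a} and {b}")
        out.append(max(x, y))
    return tuple(reversed(out))

def broadcast(batch_a, base_a, batch_b, base_b):
    """Broadcast two batched shapes. A scalar (empty base shape)
    broadcasts against any base shape; otherwise base shapes must match."""
    if base_a == ():
        base = base_b
    elif base_b == () or base_a == base_b:
        base = base_a
    else:
        raise ValueError(f"incompatible base shapes {base_a} and {base_b}")
    return broadcast_batch(batch_a, batch_b), base

# a batch of 100 scalars combined with a single (3, 3) rank-2 tensor
print(broadcast((100,), (), (), (3, 3)))  # ((100,), (3, 3))
```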
Shape generalization: As mentioned in the tutorial on JIT compilation, NEML2 supports tracing tensor operations into graphs. The PyTorch tracing engine does not support tracing operations on tensor shapes, and so users are left with two options:
- assemble all inputs into a single, predetermined shape, so that one traced graph can serve every evaluation; or
- re-trace the graph whenever the input shapes change.
The first option implies a potential memory bound at the stage of JIT compilation, and is difficult to guarantee in practice. The second option imposes a significant runtime overhead whenever re-tracing is necessary.
With explicit batching, NEML2 is able to generate batch-generalizable graphs. As long as the batch dimensions of the input variables remain unchanged, the same traced graph can be reused without re-tracing.
neml2::Tensor is the general-purpose tensor type whose base shape can be modified at runtime.
In addition, NEML2 offers a rich collection of primitive tensor types whose base shapes are static, i.e., remain fixed at runtime. Currently implemented tensor types are summarized in the following table.
Tensor type | Base shape | Description |
---|---|---|
Tensor | Dynamic | General-purpose tensor with dynamic base shape |
Scalar | () | Rank-0 tensor, i.e., scalar |
Vec | (3) | Rank-1 tensor, i.e., vector |
R2 | (3, 3) | Rank-2 tensor |
SR2 | (6) | Symmetric rank-2 tensor |
WR2 | (3) | Skew-symmetric rank-2 tensor |
R3 | (3, 3, 3) | Rank-3 tensor |
SFR3 | (6, 3) | Rank-3 tensor with symmetry on base dimensions 0 and 1 |
R4 | (3, 3, 3, 3) | Rank-4 tensor |
SFR4 | (6, 3, 3) | Rank-4 tensor with symmetry on base dimensions 0 and 1 |
WFR4 | (3, 3, 3) | Rank-4 tensor with skew symmetry on base dimensions 0 and 1 |
SSR4 | (6, 6) | Rank-4 tensor with minor symmetry |
SWR4 | (6, 3) | Rank-4 tensor with minor symmetry then skew symmetry |
WSR4 | (3, 6) | Rank-4 tensor with skew symmetry then minor symmetry |
WWR4 | (3, 3) | Rank-4 tensor with skew symmetry |
R5 | (3, 3, 3, 3, 3) | Rank-5 tensor |
SSFR5 | (6, 6, 3) | Rank-5 tensor with minor symmetry on base dimensions 0-3 |
R8 | (3, 3, 3, 3, 3, 3, 3, 3) | Rank-8 tensor |
SSSSR8 | (6, 6, 6, 6) | Rank-8 tensor with minor symmetry |
Rot | (3) | Rotation tensor represented in the Rodrigues form |
Quaternion | (4) | Quaternion |
MillerIndex | (3) | Crystal direction or lattice plane represented as Miller indices |
All primitive tensor types can be declared as variables, parameters, and buffers in a model.
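A symmetric rank-2 tensor has only 6 independent components, which is why a type like SR2 can be stored more compactly than a full 3×3 matrix. NEML2 uses the Mandel notation for this packed storage. The following plain-Python sketch (not the NEML2 API; the exact component ordering shown is an assumption for illustration) converts a symmetric 3×3 matrix to and from a Mandel 6-vector:

```python
import math

SQRT2 = math.sqrt(2.0)

def full_to_mandel(A):
    """Pack a symmetric 3x3 matrix (nested lists) into a 6-vector.
    Off-diagonal entries are scaled by sqrt(2) so that the dot product
    of two Mandel vectors equals the double contraction A : B.
    The ordering (00, 11, 22, 12, 02, 01) is assumed for illustration."""
    return [A[0][0], A[1][1], A[2][2],
            SQRT2 * A[1][2], SQRT2 * A[0][2], SQRT2 * A[0][1]]

def mandel_to_full(v):
    """Unpack a Mandel 6-vector back into a symmetric 3x3 matrix."""
    a, b, c = v[3] / SQRT2, v[4] / SQRT2, v[5] / SQRT2
    return [[v[0], c, b],
            [c, v[1], a],
            [b, a, v[2]]]

A = [[1.0, 4.0, 5.0],
     [4.0, 2.0, 6.0],
     [5.0, 6.0, 3.0]]
v = full_to_mandel(A)

# round trip recovers the matrix (up to floating-point rounding)
rt = mandel_to_full(v)
print(all(abs(rt[i][j] - A[i][j]) < 1e-12 for i in range(3) for j in range(3)))

# dot(v, v) equals the double contraction sum_ij A_ij * A_ij
print(abs(sum(x * x for x in v)
          - sum(A[i][j] ** 2 for i in range(3) for j in range(3))) < 1e-9)
```

The sqrt(2) scaling is what makes the packed representation energetically consistent: inner products of Mandel vectors agree with double contractions of the corresponding full tensors.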