template<class Derived>
class neml2::TensorBase< Derived >
NEML2's enhanced tensor type.
neml2::TensorBase derives from ATensor and clearly distinguishes "batched" dimensions from other dimensions.
A tensor shape can be pictorially labeled as below
(b1, b2, b3, b4, b5, b6, b7, ...; d1, d2, d3, d4, d5, ...)
 |_____________________________|  |_____________________|
              batch                        base
 |____________| |_____________|
    dynamic      intermediate
                 |______________________________________|
                                  static
- batch: The dimensions over which tensor operations should broadcast, i.e., the same operation should be applied across all batches.
- base: The dimensions that are statically (at compile-time) determined by the tensor type.
- dynamic: The dimensions whose sizes are dynamic in a traced graph of tensor operations, i.e., the traced graph can be generalized to tensors with different dynamic sizes.
- static: The dimensions whose sizes are fixed in a traced graph of tensor operations, i.e., the sizes of these dimensions are hard-coded in the traced graph.
- intermediate: Batch dimensions that are static, i.e., their sizes are fixed in the traced graph but tensor operations are batched over them.
- Note: By default, and in most cases, the number of intermediate dimensions is zero, meaning that "batch" == "dynamic" and "base" == "static".
|
| | TensorBase ()=default |
| | Default constructor.
|
| |
| | TensorBase (const ATensor &tensor, Size dynamic_dim, Size intmd_dim) |
| | Construct from an ATensor with given dynamic dimension.
|
| |
| | TensorBase (const ATensor &tensor, TraceableTensorShape dynamic_shape, Size intmd_dim) |
| | Construct from an ATensor with given dynamic shape.
|
| |
| template<class Derived2 > |
| | TensorBase (const TensorBase< Derived2 > &tensor) |
| | Converting copy constructor: construct from a TensorBase of another derived type.
|
| |
| | TensorBase (double)=delete |
| |
| | TensorBase (float)=delete |
| |
| | TensorBase (int)=delete |
| |
| Derived | variable_data () const |
| | Variable data without function graph.
|
| |
|
| Derived | contiguous () const |
| |
| Derived | clone () const |
| | Clone (take ownership)
|
| |
| Derived | detach () const |
| | Discard function graph.
|
| |
| Derived | to (const TensorOptions &options) const |
| | Change tensor options.
|
| |
| Derived | operator- () const |
| | Negation.
|
| |
|
| Size | batch_dim () const |
| |
| Size | base_dim () const |
| |
| Size | dynamic_dim () const |
| |
| Size | static_dim () const |
| |
| Size | intmd_dim () const |
| |
|
| TraceableTensorShape | batch_sizes () const |
| |
| TensorShapeRef | base_sizes () const |
| |
| const TraceableTensorShape & | dynamic_sizes () const |
| |
| TensorShapeRef | static_sizes () const |
| |
| TensorShapeRef | intmd_sizes () const |
| |
|
| TraceableSize | batch_size (Size i) const |
| |
| Size | base_size (Size i) const |
| |
| const TraceableSize & | dynamic_size (Size i) const |
| |
| Size | static_size (Size i) const |
| |
| Size | intmd_size (Size i) const |
| |
|
| Derived | dynamic_index (indexing::TensorIndicesRef indices) const |
| |
| Derived | intmd_index (indexing::TensorIndicesRef indices) const |
| |
| neml2::Tensor | base_index (indexing::TensorIndicesRef indices) const |
| |
|
| Derived | dynamic_slice (Size d, const indexing::Slice &index) const |
| |
| Derived | intmd_slice (Size d, const indexing::Slice &index) const |
| |
| neml2::Tensor | base_slice (Size d, const indexing::Slice &index) const |
| |
|
| void | dynamic_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other) |
| |
| void | dynamic_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v) |
| |
| void | intmd_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other) |
| |
| void | intmd_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v) |
| |
| void | base_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other) |
| |
| void | base_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v) |
| |
|
| Derived | dynamic_expand (const TraceableTensorShape &shape) const |
| |
| Derived | intmd_expand (TensorShapeRef shape) const |
| |
| neml2::Tensor | base_expand (TensorShapeRef shape) const |
| |
| Derived | batch_expand (const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const |
| |
| neml2::Tensor | static_expand (TensorShapeRef intmd_shape, TensorShapeRef base_shape) const |
| |
|
| Derived | dynamic_expand (const TraceableSize &size, Size d) const |
| |
| Derived | intmd_expand (Size size, Size d) const |
| |
| neml2::Tensor | base_expand (Size size, Size d) const |
| |
|
| Derived | dynamic_expand_as (const neml2::Tensor &other) const |
| |
| Derived | intmd_expand_as (const neml2::Tensor &other) const |
| |
| neml2::Tensor | base_expand_as (const neml2::Tensor &other) const |
| |
| Derived | batch_expand_as (const neml2::Tensor &other) const |
| |
| neml2::Tensor | static_expand_as (const neml2::Tensor &other) const |
| |
|
| Derived | dynamic_reshape (const TraceableTensorShape &shape) const |
| |
| Derived | intmd_reshape (TensorShapeRef shape) const |
| |
| neml2::Tensor | base_reshape (TensorShapeRef shape) const |
| |
| Derived | batch_reshape (const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const |
| |
| neml2::Tensor | static_reshape (TensorShapeRef intmd_shape, TensorShapeRef base_shape) const |
| |
|
| Derived | dynamic_squeeze (Size d) const |
| |
| Derived | intmd_squeeze (Size d) const |
| |
| neml2::Tensor | base_squeeze (Size d) const |
| |
|
| Derived | dynamic_unsqueeze (Size d, Size n=1) const |
| |
| Derived | intmd_unsqueeze (Size d, Size n=1) const |
| |
| neml2::Tensor | base_unsqueeze (Size d, Size n=1) const |
| |
|
| Derived | dynamic_transpose (Size d1, Size d2) const |
| |
| Derived | intmd_transpose (Size d1, Size d2) const |
| |
| neml2::Tensor | base_transpose (Size d1, Size d2) const |
| |
|
| Derived | dynamic_movedim (Size old_dim, Size new_dim) const |
| |
| Derived | intmd_movedim (Size old_dim, Size new_dim) const |
| |
| neml2::Tensor | base_movedim (Size old_dim, Size new_dim) const |
| |
|
| Derived | dynamic_flatten () const |
| |
| Derived | intmd_flatten () const |
| |
| neml2::Tensor | base_flatten () const |
| |
| Derived | batch_flatten () const |
| | Flatten batch dimensions.
|
| |
| neml2::Tensor | static_flatten () const |
| | Flatten static dimensions.
|
| |
|
| static Derived | empty_like (const Derived &other) |
| |
| static Derived | zeros_like (const Derived &other) |
| | Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
|
| |
| static Derived | ones_like (const Derived &other) |
| | Tensor filled with ones like another, i.e. same batch and base shapes, same tensor options, etc.
|
| |
| static Derived | full_like (const Derived &other, const CScalar &init) |
| |
| static Derived | rand_like (const Derived &other) |
| |