NEML2 2.0.0
Loading...
Searching...
No Matches
TensorBase< Derived > Class Template Reference

NEML2's enhanced tensor type. More...

Detailed Description

template<class Derived>
class neml2::TensorBase< Derived >

NEML2's enhanced tensor type.

neml2::TensorBase derives from ATensor and clearly distinguishes "batched" dimensions from other dimensions.

A tensor shape can be labeled pictorially as follows (static = intermediate + base):

(b1, b2, b3, b4, b5, b6, b7, ...; d1, d2, d3, d4, d5, ...)
|______________________________|  |______________________|
             batch                          base
|______________| |_____________|
    dynamic       intermediate
                 |_______________________________________|
                                  static
  • batch: The dimensions over which tensor operations should broadcast, i.e., the same operation should be applied across all batches.
  • base: The dimensions that are statically (at compile-time) determined by the tensor type.
  • dynamic: The dimensions whose sizes are dynamic in a traced graph of tensor operations, i.e., the traced graph can be generalized to tensors with different dynamic sizes.
  • static: The dimensions whose sizes are fixed in a traced graph of tensor operations, i.e., the sizes of these dimensions are hard-coded in the traced graph.
  • intermediate: Batch dimensions that are static, i.e., their sizes are fixed in the traced graph but tensor operations are batched over them.
Note
By default, and in most cases, the number of intermediate dimensions is zero, meaning that "batch" == "dynamic" and "base" == "static".

#include <TensorBase.h>

Inheritance diagram for TensorBase< Derived >:

Public Member Functions

 TensorBase ()=default
 Default constructor.
 
 TensorBase (const ATensor &tensor, Size dynamic_dim, Size intmd_dim)
 Construct from an ATensor with the given dynamic dimension and intermediate dimension.
 
 TensorBase (const ATensor &tensor, TraceableTensorShape dynamic_shape, Size intmd_dim)
 Construct from an ATensor with the given dynamic shape and intermediate dimension.
 
template<class Derived2 >
 TensorBase (const TensorBase< Derived2 > &tensor)
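 Converting constructor from a TensorBase of any derived type Derived2 (a constructor template, so not the implicit copy constructor).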
 
 TensorBase (double)=delete
 
 TensorBase (float)=delete
 
 TensorBase (int)=delete
 
Derived variable_data () const
 Variable data without function graph.
 
Meta operations
Derived contiguous () const
 
Derived clone () const
 Clone (take ownership)
 
Derived detach () const
 Discard function graph.
 
Derived to (const TensorOptions &options) const
 Change tensor options.
 
Derived operator- () const
 Negation.
 
Size batch_dim () const
 
Size base_dim () const
 
Size dynamic_dim () const
 
Size static_dim () const
 
Size intmd_dim () const
 
TraceableTensorShape batch_sizes () const
 
TensorShapeRef base_sizes () const
 
const TraceableTensorShape & dynamic_sizes () const
 
TensorShapeRef static_sizes () const
 
TensorShapeRef intmd_sizes () const
 
TraceableSize batch_size (Size i) const
 
Size base_size (Size i) const
 
const TraceableSize & dynamic_size (Size i) const
 
Size static_size (Size i) const
 
Size intmd_size (Size i) const
 
Derived dynamic_index (indexing::TensorIndicesRef indices) const
 
Derived intmd_index (indexing::TensorIndicesRef indices) const
 
neml2::Tensor base_index (indexing::TensorIndicesRef indices) const
 
Derived dynamic_slice (Size d, const indexing::Slice &index) const
 
Derived intmd_slice (Size d, const indexing::Slice &index) const
 
neml2::Tensor base_slice (Size d, const indexing::Slice &index) const
 
void dynamic_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other)
 
void dynamic_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v)
 
void intmd_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other)
 
void intmd_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v)
 
void base_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other)
 
void base_index_put_ (indexing::TensorIndicesRef indices, const CScalar &v)
 
Derived dynamic_expand (const TraceableTensorShape &shape) const
 
Derived intmd_expand (TensorShapeRef shape) const
 
neml2::Tensor base_expand (TensorShapeRef shape) const
 
Derived batch_expand (const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
 
neml2::Tensor static_expand (TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
 
Derived dynamic_expand (const TraceableSize &size, Size d) const
 
Derived intmd_expand (Size size, Size d) const
 
neml2::Tensor base_expand (Size size, Size d) const
 
Derived dynamic_expand_as (const neml2::Tensor &other) const
 
Derived intmd_expand_as (const neml2::Tensor &other) const
 
neml2::Tensor base_expand_as (const neml2::Tensor &other) const
 
Derived batch_expand_as (const neml2::Tensor &other) const
 
neml2::Tensor static_expand_as (const neml2::Tensor &other) const
 
Derived dynamic_reshape (const TraceableTensorShape &shape) const
 
Derived intmd_reshape (TensorShapeRef shape) const
 
neml2::Tensor base_reshape (TensorShapeRef shape) const
 
Derived batch_reshape (const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
 
neml2::Tensor static_reshape (TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
 
Derived dynamic_squeeze (Size d) const
 
Derived intmd_squeeze (Size d) const
 
neml2::Tensor base_squeeze (Size d) const
 
Derived dynamic_unsqueeze (Size d, Size n=1) const
 
Derived intmd_unsqueeze (Size d, Size n=1) const
 
neml2::Tensor base_unsqueeze (Size d, Size n=1) const
 
Derived dynamic_transpose (Size d1, Size d2) const
 
Derived intmd_transpose (Size d1, Size d2) const
 
neml2::Tensor base_transpose (Size d1, Size d2) const
 
Derived dynamic_movedim (Size old_dim, Size new_dim) const
 
Derived intmd_movedim (Size old_dim, Size new_dim) const
 
neml2::Tensor base_movedim (Size old_dim, Size new_dim) const
 
Derived dynamic_flatten () const
 
Derived intmd_flatten () const
 
neml2::Tensor base_flatten () const
 
Derived batch_flatten () const
 Flatten batch dimensions.
 
neml2::Tensor static_flatten () const
 Flatten static dimensions.
 

Static Public Member Functions

static Derived empty_like (const Derived &other)
 
static Derived zeros_like (const Derived &other)
 Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived ones_like (const Derived &other)
 Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived full_like (const Derived &other, const CScalar &init)
 
static Derived rand_like (const Derived &other)
 

Protected Member Functions

void validate_shapes_and_dims () const
 Validate shapes and dimensions.
 

Constructor & Destructor Documentation

◆ TensorBase() [1/7]

template<class Derived >
TensorBase ( )
default

Default constructor.

◆ TensorBase() [2/7]

template<class Derived >
TensorBase ( const ATensor & tensor,
Size dynamic_dim,
Size intmd_dim )

Construct from an ATensor with the given dynamic dimension and intermediate dimension.

◆ TensorBase() [3/7]

template<class Derived >
TensorBase ( const ATensor & tensor,
TraceableTensorShape dynamic_shape,
Size intmd_dim )

Construct from an ATensor with the given dynamic shape and intermediate dimension.

◆ TensorBase() [4/7]

template<class Derived >
template<class Derived2 >
TensorBase ( const TensorBase< Derived2 > & tensor)
inline

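Converting constructor: constructs from a TensorBase of any derived type Derived2. Because it is a constructor template, C++ does not treat it as the copy constructor; the implicitly declared copy constructor is still used when copying from the same Derived type.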

◆ TensorBase() [5/7]

template<class Derived >
TensorBase ( double )
delete

◆ TensorBase() [6/7]

template<class Derived >
TensorBase ( float )
delete

◆ TensorBase() [7/7]

template<class Derived >
TensorBase ( int )
delete

Member Function Documentation

◆ base_dim()

template<class Derived >
Size base_dim ( ) const

◆ base_expand() [1/2]

template<class Derived >
neml2::Tensor base_expand ( Size size,
Size d ) const

◆ base_expand() [2/2]

template<class Derived >
neml2::Tensor base_expand ( TensorShapeRef shape) const

◆ base_expand_as()

template<class Derived >
neml2::Tensor base_expand_as ( const neml2::Tensor & other) const

◆ base_flatten()

template<class Derived >
neml2::Tensor base_flatten ( ) const

◆ base_index()

template<class Derived >
neml2::Tensor base_index ( indexing::TensorIndicesRef indices) const

◆ base_index_put_() [1/2]

template<class Derived >
void base_index_put_ ( indexing::TensorIndicesRef indices,
const ATensor & other )

◆ base_index_put_() [2/2]

template<class Derived >
void base_index_put_ ( indexing::TensorIndicesRef indices,
const CScalar & v )

◆ base_movedim()

template<class Derived >
neml2::Tensor base_movedim ( Size old_dim,
Size new_dim ) const

◆ base_reshape()

template<class Derived >
neml2::Tensor base_reshape ( TensorShapeRef shape) const

◆ base_size()

template<class Derived >
Size base_size ( Size i) const

◆ base_sizes()

template<class Derived >
TensorShapeRef base_sizes ( ) const

◆ base_slice()

template<class Derived >
neml2::Tensor base_slice ( Size d,
const indexing::Slice & index ) const

◆ base_squeeze()

template<class Derived >
neml2::Tensor base_squeeze ( Size d) const

◆ base_transpose()

template<class Derived >
neml2::Tensor base_transpose ( Size d1,
Size d2 ) const

◆ base_unsqueeze()

template<class Derived >
neml2::Tensor base_unsqueeze ( Size d,
Size n = 1 ) const

◆ batch_dim()

template<class Derived >
Size batch_dim ( ) const

◆ batch_expand()

template<class Derived >
Derived batch_expand ( const TraceableTensorShape & dynamic_shape,
TensorShapeRef intmd_shape ) const

◆ batch_expand_as()

template<class Derived >
Derived batch_expand_as ( const neml2::Tensor & other) const

◆ batch_flatten()

template<class Derived >
Derived batch_flatten ( ) const

Flatten batch dimensions.

Note
All dynamic and intermediate dimensions are flattened into a single batch dimension. In other words, the resulting tensor has ZERO intermediate dimensions.

◆ batch_reshape()

template<class Derived >
Derived batch_reshape ( const TraceableTensorShape & dynamic_shape,
TensorShapeRef intmd_shape ) const

◆ batch_size()

template<class Derived >
TraceableSize batch_size ( Size i) const

◆ batch_sizes()

template<class Derived >
TraceableTensorShape batch_sizes ( ) const

◆ clone()

template<class Derived >
Derived clone ( ) const

Clone (take ownership)

◆ contiguous()

template<class Derived >
Derived contiguous ( ) const

Make contiguous

◆ detach()

template<class Derived >
Derived detach ( ) const

Discard function graph.

◆ dynamic_dim()

template<class Derived >
Size dynamic_dim ( ) const

◆ dynamic_expand() [1/2]

template<class Derived >
Derived dynamic_expand ( const TraceableSize & size,
Size d ) const

Return a view of the tensor with values broadcast along the given dimension.

◆ dynamic_expand() [2/2]

template<class Derived >
Derived dynamic_expand ( const TraceableTensorShape & shape) const

Return a view of the tensor with values broadcast along multiple dimensions.

◆ dynamic_expand_as()

template<class Derived >
Derived dynamic_expand_as ( const neml2::Tensor & other) const

Expand the dimensions to have the same shape as those of another tensor.

◆ dynamic_flatten()

template<class Derived >
Derived dynamic_flatten ( ) const

Flatten a dimension group

◆ dynamic_index()

template<class Derived >
Derived dynamic_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing along multiple dimensions

◆ dynamic_index_put_() [1/2]

template<class Derived >
void dynamic_index_put_ ( indexing::TensorIndicesRef indices,
const ATensor & other )

Set values by slicing along multiple dimensions

◆ dynamic_index_put_() [2/2]

template<class Derived >
void dynamic_index_put_ ( indexing::TensorIndicesRef indices,
const CScalar & v )

◆ dynamic_movedim()

template<class Derived >
Derived dynamic_movedim ( Size old_dim,
Size new_dim ) const

Move a dimension to a new position.

◆ dynamic_reshape()

template<class Derived >
Derived dynamic_reshape ( const TraceableTensorShape & shape) const

Reshape the dimension group

◆ dynamic_size()

template<class Derived >
const TraceableSize & dynamic_size ( Size i) const

◆ dynamic_sizes()

template<class Derived >
const TraceableTensorShape & dynamic_sizes ( ) const

◆ dynamic_slice()

template<class Derived >
Derived dynamic_slice ( Size d,
const indexing::Slice & index ) const

Get a tensor by slicing along a dimension

◆ dynamic_squeeze()

template<class Derived >
Derived dynamic_squeeze ( Size d) const

Squeeze a dimension

◆ dynamic_transpose()

template<class Derived >
Derived dynamic_transpose ( Size d1,
Size d2 ) const

Transpose two dimensions

◆ dynamic_unsqueeze()

template<class Derived >
Derived dynamic_unsqueeze ( Size d,
Size n = 1 ) const

Unsqueeze n dimensions at d

◆ intmd_dim()

template<class Derived >
Size intmd_dim ( ) const

◆ intmd_expand() [1/2]

template<class Derived >
Derived intmd_expand ( Size size,
Size d ) const

◆ intmd_expand() [2/2]

template<class Derived >
Derived intmd_expand ( TensorShapeRef shape) const

◆ intmd_expand_as()

template<class Derived >
Derived intmd_expand_as ( const neml2::Tensor & other) const

◆ intmd_flatten()

template<class Derived >
Derived intmd_flatten ( ) const

◆ intmd_index()

template<class Derived >
Derived intmd_index ( indexing::TensorIndicesRef indices) const

◆ intmd_index_put_() [1/2]

template<class Derived >
void intmd_index_put_ ( indexing::TensorIndicesRef indices,
const ATensor & other )

◆ intmd_index_put_() [2/2]

template<class Derived >
void intmd_index_put_ ( indexing::TensorIndicesRef indices,
const CScalar & v )

◆ intmd_movedim()

template<class Derived >
Derived intmd_movedim ( Size old_dim,
Size new_dim ) const

◆ intmd_reshape()

template<class Derived >
Derived intmd_reshape ( TensorShapeRef shape) const

◆ intmd_size()

template<class Derived >
Size intmd_size ( Size i) const

◆ intmd_sizes()

template<class Derived >
TensorShapeRef intmd_sizes ( ) const

◆ intmd_slice()

template<class Derived >
Derived intmd_slice ( Size d,
const indexing::Slice & index ) const

◆ intmd_squeeze()

template<class Derived >
Derived intmd_squeeze ( Size d) const

◆ intmd_transpose()

template<class Derived >
Derived intmd_transpose ( Size d1,
Size d2 ) const

◆ intmd_unsqueeze()

template<class Derived >
Derived intmd_unsqueeze ( Size d,
Size n = 1 ) const

◆ operator-()

template<class Derived >
Derived operator- ( ) const

Negation.

◆ static_dim()

template<class Derived >
Size static_dim ( ) const

◆ static_expand()

template<class Derived >
neml2::Tensor static_expand ( TensorShapeRef intmd_shape,
TensorShapeRef base_shape ) const

◆ static_expand_as()

template<class Derived >
neml2::Tensor static_expand_as ( const neml2::Tensor & other) const

◆ static_flatten()

template<class Derived >
neml2::Tensor static_flatten ( ) const

Flatten static dimensions.

Note
All intermediate and base dimensions are flattened into a single base dimension. In other words, the resulting tensor has ZERO intermediate dimensions.

◆ static_reshape()

template<class Derived >
neml2::Tensor static_reshape ( TensorShapeRef intmd_shape,
TensorShapeRef base_shape ) const

◆ static_size()

template<class Derived >
Size static_size ( Size i) const

◆ static_sizes()

template<class Derived >
TensorShapeRef static_sizes ( ) const

◆ to()

template<class Derived >
Derived to ( const TensorOptions & options) const

Change tensor options.

◆ validate_shapes_and_dims()

template<class Derived >
void validate_shapes_and_dims ( ) const
protected

Validate shapes and dimensions.

◆ variable_data()

template<class Derived >
Derived variable_data ( ) const

Variable data without function graph.