NEML2 2.0.0
TensorBase< Derived > Class Template Reference

NEML2's enhanced tensor type.

Detailed Description

template<class Derived>
class neml2::TensorBase< Derived >

NEML2's enhanced tensor type.

neml2::TensorBase derives from ATensor and clearly distinguishes "batched" dimensions from the other dimensions. The shape of the batched dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.

#include <TensorBase.h>
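The batch/base split described above is pure shape bookkeeping, so it can be illustrated without libtorch. The sketch below is hypothetical: `Shape`, `SplitShape` and `split_shape` are illustrative names, not NEML2 API, and it assumes the batch axes are the leading axes of the underlying tensor.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical shape type for illustration; not NEML2's TensorShape.
using Shape = std::vector<std::int64_t>;

// Hypothetical result of splitting a full tensor shape.
struct SplitShape
{
  Shape batch; // leading axes: the batch size
  Shape base;  // remaining trailing axes: the base size
};

// Split `full` so that the first `batch_dim` axes form the batch size and the
// rest form the base size (assumes the batch axes are leading).
SplitShape split_shape(const Shape & full, std::int64_t batch_dim)
{
  assert(batch_dim >= 0 && batch_dim <= static_cast<std::int64_t>(full.size()));
  const auto mid = full.begin() + batch_dim;
  return {Shape(full.begin(), mid), Shape(mid, full.end())};
}
```

Under this assumption, a tensor of shape (5, 2, 3, 3) with batch dimension 2 has batch size (5, 2) and base size (3, 3), i.e. a 5×2 batch of 3×3 tensors.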


Public Member Functions

 TensorBase ()=default
 Special member functions.
 
 TensorBase (const ATensor &tensor, Size batch_dim)
 Construct from another ATensor with given batch dimension.
 
 TensorBase (const ATensor &tensor, const TraceableTensorShape &batch_shape)
 Construct from another ATensor with given batch shape.
 
 TensorBase (const neml2::Tensor &tensor)
 Copy constructor.
 
 TensorBase (Real)=delete
 
Derived variable_data () const
 Variable data without function graph.
 
Meta operations
Derived clone () const
 Clone (take ownership).
 
Derived detach () const
 Discard function graph.
 
Derived to (const TensorOptions &options) const
 Change tensor options.
 
Derived operator- () const
 Negation.
 
Tensor information
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
Size base_dim () const
 Return the number of base dimensions.
 
const TraceableTensorShape & batch_sizes () const
 Return the batch size.
 
TraceableSize batch_size (Size index) const
 Return the size of a batch axis.
 
TensorShapeRef base_sizes () const
 Return the base size.
 
Size base_size (Size index) const
 Return the size of a base axis.
 
Size base_storage () const
 Return the flattened storage needed just for the base indices.
 
Getter and setter
Derived batch_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the batch dimensions.
 
neml2::Tensor base_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the base dimensions.
 
Derived batch_slice (Size dim, const indexing::Slice &index) const
 Get a tensor by slicing along a batch dimension.
 
neml2::Tensor base_slice (Size dim, const indexing::Slice &index) const
 Get a tensor by slicing along a base dimension.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other)
 Set values by slicing on the batch dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, Real v)
 
void base_index_put_ (indexing::TensorIndicesRef indices, const ATensor &other)
 Set values by slicing on the base dimensions.
 
void base_index_put_ (indexing::TensorIndicesRef indices, Real v)
 
Modifiers
Derived batch_expand (const TraceableTensorShape &batch_shape) const
 Return a new view of the tensor with values broadcast along the batch dimensions.
 
Derived batch_expand (const TraceableSize &batch_size, Size dim) const
 Return a new view of the tensor with values broadcast along a given batch dimension.
 
neml2::Tensor base_expand (TensorShapeRef base_shape) const
 Return a new view of the tensor with values broadcast along the base dimensions.
 
neml2::Tensor base_expand (Size base_size, Size dim) const
 Return a new view of the tensor with values broadcast along a given base dimension.
 
Derived batch_expand_as (const neml2::Tensor &other) const
 Expand the batch to have the same shape as another tensor.
 
neml2::Tensor base_expand_as (const neml2::Tensor &other) const
 Expand the base to have the same shape as another tensor.
 
Derived batch_expand_copy (const TraceableTensorShape &batch_shape) const
 Return a new tensor with values broadcast along the batch dimensions.
 
neml2::Tensor base_expand_copy (TensorShapeRef base_shape) const
 Return a new tensor with values broadcast along the base dimensions.
 
Derived batch_reshape (const TraceableTensorShape &batch_shape) const
 Reshape batch dimensions.
 
neml2::Tensor base_reshape (TensorShapeRef base_shape) const
 Reshape base dimensions.
 
Derived batch_unsqueeze (Size d) const
 Unsqueeze a batch dimension.
 
neml2::Tensor base_unsqueeze (Size d) const
 Unsqueeze a base dimension.
 
Derived batch_transpose (Size d1, Size d2) const
 Transpose two batch dimensions.
 
neml2::Tensor base_transpose (Size d1, Size d2) const
 Transpose two base dimensions.
 
neml2::Tensor base_flatten () const
 Flatten base dimensions.
 

Static Public Member Functions

static Derived empty_like (const Derived &other)
 
static Derived zeros_like (const Derived &other)
 Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived ones_like (const Derived &other)
 Tensor of ones like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived full_like (const Derived &other, Real init)
 
static Derived linspace (const Derived &start, const Derived &end, Size nstep, Size dim=0)
 Create a new tensor by adding a new batch dimension with linear spacing between start and end.
 
static Derived logspace (const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
 Log-space equivalent of the linspace named constructor.
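The arithmetic behind linspace and logspace can be shown on plain scalars. This is a hypothetical sketch, not NEML2's implementation: the free functions `linspace` and `logspace` below return a `std::vector` instead of stacking the points along a new batch dimension `dim`, and `logspace` assumes the torch.logspace convention in which `start` and `end` are exponents of `base`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar analogue of linspace(start, end, nstep): nstep evenly
// spaced points from start to end, both endpoints included.
std::vector<double> linspace(double start, double end, std::int64_t nstep)
{
  assert(nstep >= 2 && "need at least two points");
  std::vector<double> out;
  out.reserve(static_cast<std::size_t>(nstep));
  for (std::int64_t i = 0; i < nstep; ++i)
    out.push_back(start + (end - start) * static_cast<double>(i) /
                              static_cast<double>(nstep - 1));
  return out;
}

// Hypothetical scalar analogue of logspace: `base` raised to each point of
// linspace(start, end, nstep), i.e. start and end are treated as exponents.
std::vector<double> logspace(double start, double end, std::int64_t nstep, double base = 10)
{
  std::vector<double> out = linspace(start, end, nstep);
  for (auto & x : out)
    x = std::pow(base, x);
  return out;
}
```

For example, `logspace(0, 2, 3)` yields approximately 1, 10 and 100.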
 

Constructor & Destructor Documentation

◆ TensorBase() [1/5]

template<class Derived>
TensorBase ( )
default

Special member functions.

◆ TensorBase() [2/5]

template<class Derived>
TensorBase ( const ATensor & tensor,
Size batch_dim )

Construct from another ATensor with given batch dimension.

◆ TensorBase() [3/5]

template<class Derived>
TensorBase ( const ATensor & tensor,
const TraceableTensorShape & batch_shape )

Construct from another ATensor with given batch shape.

◆ TensorBase() [4/5]

template<class Derived>
TensorBase ( const neml2::Tensor & tensor)

Copy constructor.

◆ TensorBase() [5/5]

template<class Derived>
TensorBase ( Real )
delete

Member Function Documentation

◆ base_dim()

template<class Derived>
Size base_dim ( ) const

Return the number of base dimensions.

◆ base_expand() [1/2]

template<class Derived>
neml2::Tensor base_expand ( Size base_size,
Size dim ) const

Return a new view of the tensor with values broadcast along a given base dimension.

◆ base_expand() [2/2]

template<class Derived>
neml2::Tensor base_expand ( TensorShapeRef base_shape) const

Return a new view of the tensor with values broadcast along the base dimensions.

◆ base_expand_as()

template<class Derived>
neml2::Tensor base_expand_as ( const neml2::Tensor & other) const

Expand the base to have the same shape as another tensor.

◆ base_expand_copy()

template<class Derived>
neml2::Tensor base_expand_copy ( TensorShapeRef base_shape) const

Return a new tensor with values broadcast along the base dimensions.

◆ base_flatten()

template<class Derived>
neml2::Tensor base_flatten ( ) const

Flatten base dimensions.
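As a shape rule, flattening the base dimensions keeps the batch axes and collapses all base axes into one. The sketch below models only the resulting shape; `Shape` and `base_flatten_shape` are hypothetical names, and the batch-leading layout is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Hypothetical shape rule for base_flatten(): batch axes are kept and all base
// axes collapse into a single axis whose length is the product of the base sizes.
Shape base_flatten_shape(const Shape & batch, const Shape & base)
{
  std::int64_t n = 1;
  for (const auto s : base)
  {
    assert(s >= 0 && "axis sizes must be non-negative");
    n *= s;
  }
  Shape out = batch;
  out.push_back(n);
  return out;
}
```

In this model a batch of 5 tensors with base size (3, 3) flattens to shape (5, 9).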

◆ base_index()

template<class Derived>
neml2::Tensor base_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the base dimensions.

◆ base_index_put_() [1/2]

template<class Derived>
void base_index_put_ ( indexing::TensorIndicesRef indices,
const ATensor & other )

Set values by slicing on the base dimensions.

◆ base_index_put_() [2/2]

template<class Derived>
void base_index_put_ ( indexing::TensorIndicesRef indices,
Real v )

◆ base_reshape()

template<class Derived>
neml2::Tensor base_reshape ( TensorShapeRef base_shape) const

Reshape base dimensions.

◆ base_size()

template<class Derived>
Size base_size ( Size index) const

Return the size of a base axis.

◆ base_sizes()

template<class Derived>
TensorShapeRef base_sizes ( ) const

Return the base size.

◆ base_slice()

template<class Derived>
neml2::Tensor base_slice ( Size dim,
const indexing::Slice & index ) const

Get a tensor by slicing along a base dimension.

◆ base_storage()

template<class Derived>
Size base_storage ( ) const

Return the flattened storage needed just for the base indices.
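The flattened storage for the base indices is the number of scalar entries one batch entry holds, i.e. the product of the base sizes. A minimal hypothetical sketch (`Shape` and this free `base_storage` function are illustrative, not NEML2 API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Hypothetical analogue of base_storage(): product of the base sizes.
// In this sketch an empty base shape (one scalar per batch entry) gives 1.
std::int64_t base_storage(const Shape & base_sizes)
{
  std::int64_t n = 1;
  for (const auto s : base_sizes)
  {
    assert(s >= 0 && "axis sizes must be non-negative");
    n *= s;
  }
  return n;
}
```

For a 3×3 base, for example, this gives 9 regardless of the batch size.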

◆ base_transpose()

template<class Derived>
neml2::Tensor base_transpose ( Size d1,
Size d2 ) const

Transpose two base dimensions.

◆ base_unsqueeze()

template<class Derived>
neml2::Tensor base_unsqueeze ( Size d) const

Unsqueeze a base dimension.

◆ batch_dim()

template<class Derived>
Size batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_expand() [1/2]

template<class Derived>
Derived batch_expand ( const TraceableSize & batch_size,
Size dim ) const

Return a new view of the tensor with values broadcast along a given batch dimension.

◆ batch_expand() [2/2]

template<class Derived>
Derived batch_expand ( const TraceableTensorShape & batch_shape) const

Return a new view of the tensor with values broadcast along the batch dimensions.
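The shape side of batch broadcasting can be sketched as follows. Per the documentation above, batch_expand returns a view while batch_expand_copy materializes a new tensor; this hypothetical `batch_expand_shape` only computes the resulting full shape. It assumes torch-style broadcasting in which a batch axis of size 1 may be expanded, and for simplicity it requires equal batch rank and ignores -1 entries.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Hypothetical shape rule for batch_expand(batch_shape): every batch axis must
// already equal the target size or have size 1 (then it is broadcast).
// Base axes are appended unchanged.
Shape batch_expand_shape(const Shape & batch, const Shape & base, const Shape & target)
{
  assert(batch.size() == target.size() && "sketch requires equal batch rank");
  for (std::size_t i = 0; i < batch.size(); ++i)
    assert((batch[i] == target[i] || batch[i] == 1) && "axis is not broadcastable");
  Shape out = target;
  out.insert(out.end(), base.begin(), base.end());
  return out;
}
```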

◆ batch_expand_as()

template<class Derived>
Derived batch_expand_as ( const neml2::Tensor & other) const

Expand the batch to have the same shape as another tensor.

◆ batch_expand_copy()

template<class Derived>
Derived batch_expand_copy ( const TraceableTensorShape & batch_shape) const

Return a new tensor with values broadcast along the batch dimensions.

◆ batch_index()

template<class Derived>
Derived batch_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the batch dimensions.

◆ batch_index_put_() [1/2]

template<class Derived>
void batch_index_put_ ( indexing::TensorIndicesRef indices,
const ATensor & other )

Set values by slicing on the batch dimensions.

◆ batch_index_put_() [2/2]

template<class Derived>
void batch_index_put_ ( indexing::TensorIndicesRef indices,
Real v )

◆ batch_reshape()

template<class Derived>
Derived batch_reshape ( const TraceableTensorShape & batch_shape) const

Reshape batch dimensions.
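Reshaping only the batch dimensions must preserve the number of batch entries, while the base sizes are carried over untouched. A hypothetical shape-only sketch (`numel` and `batch_reshape_shape` are illustrative names, and the batch-leading layout is assumed):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Number of entries spanned by a shape (1 for an empty shape).
std::int64_t numel(const Shape & s)
{
  std::int64_t n = 1;
  for (const auto v : s)
    n *= v;
  return n;
}

// Hypothetical shape rule for batch_reshape(batch_shape): only the batch axes
// are reshaped, so the batch entry count must not change; base axes follow.
Shape batch_reshape_shape(const Shape & batch, const Shape & base, const Shape & new_batch)
{
  assert(numel(batch) == numel(new_batch) && "batch entry count must not change");
  Shape out = new_batch;
  out.insert(out.end(), base.begin(), base.end());
  return out;
}
```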

◆ batch_size()

template<class Derived>
TraceableSize batch_size ( Size index) const

Return the size of a batch axis.

◆ batch_sizes()

template<class Derived>
const TraceableTensorShape & batch_sizes ( ) const

Return the batch size.

◆ batch_slice()

template<class Derived>
Derived batch_slice ( Size dim,
const indexing::Slice & index ) const

Get a tensor by slicing along a batch dimension.
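Slicing along one batch dimension changes only that axis's length; the number of batch dimensions and the base sizes stay the same. The sketch below computes the new length, assuming `indexing::Slice` follows the usual torch/Python half-open `start:stop:step` semantics. `sliced_size` is a hypothetical helper that handles only non-negative bounds and a positive step.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical: length of an axis of size `n` after slicing with
// start:stop:step (half-open range, bounds clamped to [0, n]).
std::int64_t sliced_size(std::int64_t n, std::int64_t start, std::int64_t stop, std::int64_t step)
{
  assert(n >= 0 && start >= 0 && stop >= 0 && step > 0);
  start = std::min(start, n);
  stop = std::min(stop, n);
  if (stop <= start)
    return 0;
  // Ceiling division: number of indices start, start + step, ... below stop.
  return (stop - start + step - 1) / step;
}
```

For example, slicing an axis of size 10 with 2:8:3 keeps indices 2 and 5, so the new length is 2.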

◆ batch_transpose()

template<class Derived>
Derived batch_transpose ( Size d1,
Size d2 ) const

Transpose two batch dimensions.
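Transposing two batch dimensions swaps their sizes and leaves the base sizes in place. A hypothetical shape-only sketch, assuming non-negative indices counted within the batch axes (negative indices are not handled here):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Hypothetical shape rule for batch_transpose(d1, d2): swap two batch axes,
// then append the unchanged base axes.
Shape batch_transpose_shape(Shape batch, const Shape & base, std::int64_t d1, std::int64_t d2)
{
  const auto nb = static_cast<std::int64_t>(batch.size());
  assert(d1 >= 0 && d1 < nb && d2 >= 0 && d2 < nb && "batch axis out of range");
  (void)nb; // only used in the assertion
  std::swap(batch[static_cast<std::size_t>(d1)], batch[static_cast<std::size_t>(d2)]);
  batch.insert(batch.end(), base.begin(), base.end());
  return batch;
}
```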

◆ batch_unsqueeze()

template<class Derived>
Derived batch_unsqueeze ( Size d) const

Unsqueeze a batch dimension.
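Unsqueezing a batch dimension inserts a new axis of size 1 among the batch axes, so the number of batch dimensions grows by one while the base sizes are unchanged; base_unsqueeze is the analogous operation on the base axes. A hypothetical shape-only sketch, assuming a non-negative position `d` counted within the batch axes:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>; // hypothetical shape type

// Hypothetical shape rule for batch_unsqueeze(d): insert a size-1 axis at
// position d (0 <= d <= number of batch axes), then append the base axes.
Shape batch_unsqueeze_shape(Shape batch, const Shape & base, std::int64_t d)
{
  assert(d >= 0 && d <= static_cast<std::int64_t>(batch.size()));
  batch.insert(batch.begin() + d, std::int64_t{1});
  batch.insert(batch.end(), base.begin(), base.end());
  return batch;
}
```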

◆ batched()

template<class Derived>
bool batched ( ) const

Whether the tensor is batched.

◆ clone()

template<class Derived>
Derived clone ( ) const

Clone (take ownership).

◆ detach()

template<class Derived>
Derived detach ( ) const

Discard function graph.

◆ operator-()

template<class Derived>
Derived operator- ( ) const

Negation.

◆ to()

template<class Derived>
Derived to ( const TensorOptions & options) const

Change tensor options.

◆ variable_data()

template<class Derived>
Derived variable_data ( ) const

Variable data without function graph.