NEML2 2.0.0
TensorBase< Derived > Class Template Reference

NEML2's enhanced tensor type.

Detailed Description

template<class Derived>
class neml2::TensorBase< Derived >

NEML2's enhanced tensor type.

neml2::TensorBase derives from torch::Tensor and clearly distinguishes the "batched" dimensions from the other dimensions. The batch dimensions always come first. The shape of the batched dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.

#include <TensorBase.h>
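The batch/base split is pure shape bookkeeping, so it can be illustrated without NEML2 or libtorch. The standalone sketch below is a toy model, not NEML2's implementation: `Size`, `Shape`, `SplitShape`, `split_shape` and `base_storage` are hypothetical names. It assumes, as with the `TensorBase(tensor, batch_dim)` constructor, that the first `batch_dim` axes are the batch axes and everything after them is base.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Size = std::int64_t;       // hypothetical stand-in for neml2::Size
using Shape = std::vector<Size>; // hypothetical stand-in for a tensor shape

// Toy model of a tensor shape split into a batch part and a base part.
struct SplitShape
{
  Shape batch; // leading "batched" axes (the batch size)
  Shape base;  // remaining trailing axes (the base size)
};

// Split a full shape, treating the first batch_dim axes as batch axes.
SplitShape split_shape(const Shape & full, Size batch_dim)
{
  assert(batch_dim >= 0 && static_cast<std::size_t>(batch_dim) <= full.size());
  SplitShape s;
  s.batch.assign(full.begin(), full.begin() + batch_dim);
  s.base.assign(full.begin() + batch_dim, full.end());
  return s;
}

// Flattened storage of a single base entry: the product of the base sizes
// (1 for an empty, i.e. scalar, base).
Size base_storage(const SplitShape & s)
{
  Size n = 1;
  for (Size d : s.base)
    n *= d;
  return n;
}
```

For example, a tensor of full shape (100, 3, 3) with batch dimension 1 splits into batch size (100) and base size (3, 3), and each base entry needs a storage of 9 values.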


Public Member Functions

 TensorBase ()=default
 Default constructor.
 
 TensorBase (const torch::Tensor &tensor, Size batch_dim)
 Construct from another torch::Tensor with given batch dimension.
 
 TensorBase (const torch::Tensor &tensor, const TraceableTensorShape &batch_shape)
 Construct from another torch::Tensor with given batch shape.
 
template<class Derived2 >
 TensorBase (const TensorBase< Derived2 > &tensor)
 Copy constructor.
 
 TensorBase (Real)=delete
 
Derived variable_data () const
 
Meta operations

Clone (take ownership)

Derived clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 
Derived detach () const
 Discard function graph.
 
Derived to (const torch::TensorOptions &options) const
 Change tensor options.
 
Derived operator- () const
 Negation.
 
Tensor information

Tensor options

bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
Size base_dim () const
 Return the number of base dimensions.
 
const TraceableTensorShape & batch_sizes () const
 Return the batch size.
 
TraceableSize batch_size (Size index) const
 Return the size of a batch axis.
 
TensorShapeRef base_sizes () const
 Return the base size.
 
Size base_size (Size index) const
 Return the size of a base axis.
 
Size base_storage () const
 Return the flattened storage needed just for the base indices.
 
Getter and setter

Regular tensor indexing

Derived batch_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the batch dimensions.
 
neml2::Tensor base_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the base dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 
void batch_index_put_ (indexing::TensorIndicesRef indices, Real v)
 
void base_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 
void base_index_put_ (indexing::TensorIndicesRef indices, Real v)
 
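batch_index and base_index differ only in which axes the indices address: batch indices count from the first (leading) batch axis, base indices count from the first base axis. The signatures also suggest why the return types differ: slicing the batch leaves the base shape intact, so the result can stay a Derived, while slicing the base may not leave a valid Derived and is returned as a plain neml2::Tensor. The toy helpers below (hypothetical names, not NEML2 API) show only the axis bookkeeping for non-negative indices.

```cpp
#include <cassert>
#include <cstdint>

using Size = std::int64_t; // hypothetical stand-in for neml2::Size

// Position in the full shape of batch axis `batch_axis`
// (batch axes are the leading ones, so the index is unchanged).
Size full_axis_of_batch_axis(Size batch_axis, Size batch_dim)
{
  assert(batch_axis >= 0 && batch_axis < batch_dim);
  return batch_axis;
}

// Position in the full shape of base axis `base_axis`
// (base axes start right after the batch_dim batch axes).
Size full_axis_of_base_axis(Size base_axis, Size batch_dim, Size base_dim)
{
  assert(batch_dim >= 0);
  assert(base_axis >= 0 && base_axis < base_dim);
  return batch_dim + base_axis;
}
```

With two batch axes and two base axes, base axis 0 is full axis 2 and base axis 1 is full axis 3.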
Modifiers

Derived batch_expand (const TraceableTensorShape &batch_shape) const
 Return a new view of the tensor with values broadcast along the batch dimensions.
 
Derived batch_expand (const TraceableSize &batch_size, Size dim) const
 Return a new view of the tensor with values broadcast along a given batch dimension.
 
neml2::Tensor base_expand (TensorShapeRef base_shape) const
 Return a new view of the tensor with values broadcast along the base dimensions.
 
neml2::Tensor base_expand (Size base_size, Size dim) const
 Return a new view of the tensor with values broadcast along a given base dimension.
 
template<class Derived2 >
Derived batch_expand_as (const Derived2 &other) const
 Expand the batch to have the same shape as another tensor.
 
template<class Derived2 >
Derived2 base_expand_as (const Derived2 &other) const
 Expand the base to have the same shape as another tensor.
 
Derived batch_expand_copy (const TraceableTensorShape &batch_shape) const
 Return a new tensor with values broadcast along the batch dimensions.
 
neml2::Tensor base_expand_copy (TensorShapeRef base_shape) const
 Return a new tensor with values broadcast along the base dimensions.
 
Derived batch_reshape (const TraceableTensorShape &batch_shape) const
 Reshape batch dimensions.
 
neml2::Tensor base_reshape (TensorShapeRef base_shape) const
 Reshape base dimensions.
 
Derived batch_unsqueeze (Size d) const
 Unsqueeze a batch dimension.
 
neml2::Tensor base_unsqueeze (Size d) const
 Unsqueeze a base dimension.
 
Derived batch_transpose (Size d1, Size d2) const
 Transpose two batch dimensions.
 
neml2::Tensor base_transpose (Size d1, Size d2) const
 Transpose two base dimensions.
 
neml2::Tensor base_flatten () const
 Flatten base dimensions.
 
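The batch_* modifiers touch only the batch part of the shape and the base_* modifiers only the base part, which is consistent with the batch_* variants returning Derived and most base_* variants returning neml2::Tensor. The sketch below is a toy shape model with hypothetical names, not NEML2 code; it covers unsqueeze, transpose and flatten for non-negative axis arguments only (negative-index handling is not modeled).

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Size = std::int64_t;       // hypothetical stand-in for neml2::Size
using Shape = std::vector<Size>; // hypothetical stand-in for a tensor shape

// Toy shape model: batch sizes followed by base sizes.
struct SplitShape
{
  Shape batch;
  Shape base;
};

// Like batch_unsqueeze(d): insert a size-1 axis at batch position d.
SplitShape batch_unsqueeze(SplitShape s, Size d)
{
  assert(d >= 0 && d <= static_cast<Size>(s.batch.size()));
  s.batch.insert(s.batch.begin() + d, 1);
  return s;
}

// Like base_unsqueeze(d): insert a size-1 axis at base position d.
SplitShape base_unsqueeze(SplitShape s, Size d)
{
  assert(d >= 0 && d <= static_cast<Size>(s.base.size()));
  s.base.insert(s.base.begin() + d, 1);
  return s;
}

// Like batch_transpose(d1, d2): swap two batch axes; base axes are untouched.
SplitShape batch_transpose(SplitShape s, Size d1, Size d2)
{
  assert(d1 >= 0 && d1 < static_cast<Size>(s.batch.size()));
  assert(d2 >= 0 && d2 < static_cast<Size>(s.batch.size()));
  std::swap(s.batch[d1], s.batch[d2]);
  return s;
}

// Like base_flatten(): merge all base axes into one axis whose size is the
// base storage (for an empty base this toy yields a single axis of size 1).
SplitShape base_flatten(SplitShape s)
{
  Size n = 1;
  for (Size x : s.base)
    n *= x;
  s.base = {n};
  return s;
}
```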

Static Public Member Functions

static Derived empty_like (const Derived &other)
 Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived zeros_like (const Derived &other)
 Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived ones_like (const Derived &other)
 Tensor filled with ones, like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived full_like (const Derived &other, Real init)
 
static Derived linspace (const Derived &start, const Derived &end, Size nstep, Size dim=0)
 Create a new tensor by adding a new batch dimension with linear spacing between start and end.
 
static Derived logspace (const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
 Log-space equivalent of the linspace named constructor.
 

Constructor & Destructor Documentation

◆ TensorBase() [1/5]

template<class Derived >
TensorBase ( )
default

Default constructor.

◆ TensorBase() [2/5]

template<class Derived >
TensorBase ( const torch::Tensor & tensor,
Size batch_dim )

Construct from another torch::Tensor with given batch dimension.

◆ TensorBase() [3/5]

template<class Derived >
TensorBase ( const torch::Tensor & tensor,
const TraceableTensorShape & batch_shape )

Construct from another torch::Tensor with given batch shape.
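One plausible reading of this constructor is that the given batch shape describes the leading axes of the wrapped tensor, so it must match those sizes and its length becomes the batch dimension. The sketch below encodes that reading as an assumption; it is not taken from NEML2's source, `batch_dim_from_shape` is a hypothetical helper, and symbolic (traceable) sizes are not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Size = std::int64_t;       // hypothetical stand-in for neml2::Size
using Shape = std::vector<Size>; // hypothetical stand-in for a tensor shape

// Assumed rule: the batch shape must be a prefix of the full tensor shape,
// and the implied batch dimension is the length of that prefix.
Size batch_dim_from_shape(const Shape & full, const Shape & batch_shape)
{
  assert(batch_shape.size() <= full.size());
  assert(std::equal(batch_shape.begin(), batch_shape.end(), full.begin()));
  return static_cast<Size>(batch_shape.size());
}
```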

◆ TensorBase() [4/5]

template<class Derived >
template<class Derived2 >
TensorBase ( const TensorBase< Derived2 > & tensor)

Copy constructor.

◆ TensorBase() [5/5]

template<class Derived >
TensorBase ( Real )
delete

Construction directly from a Real scalar is explicitly deleted.

Member Function Documentation

◆ base_dim()

template<class Derived >
Size base_dim ( ) const

Return the number of base dimensions.

◆ base_expand() [1/2]

template<class Derived >
neml2::Tensor base_expand ( Size base_size,
Size dim ) const

Return a new view of the tensor with values broadcast along a given base dimension.

◆ base_expand() [2/2]

template<class Derived >
neml2::Tensor base_expand ( TensorShapeRef base_shape) const

Return a new view of the tensor with values broadcast along the base dimensions.

◆ base_expand_as()

template<class Derived >
template<class Derived2 >
Derived2 base_expand_as ( const Derived2 & other) const

Expand the base to have the same shape as another tensor.

◆ base_expand_copy()

template<class Derived >
neml2::Tensor base_expand_copy ( TensorShapeRef base_shape) const

Return a new tensor with values broadcast along the base dimensions.

◆ base_flatten()

template<class Derived >
neml2::Tensor base_flatten ( ) const

Flatten base dimensions.

◆ base_index()

template<class Derived >
neml2::Tensor base_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the base dimensions.

◆ base_index_put_() [1/2]

template<class Derived >
void base_index_put_ ( indexing::TensorIndicesRef indices,
const torch::Tensor & other )

Set values by slicing on the base dimensions.

◆ base_index_put_() [2/2]

template<class Derived >
void base_index_put_ ( indexing::TensorIndicesRef indices,
Real v )

◆ base_reshape()

template<class Derived >
neml2::Tensor base_reshape ( TensorShapeRef base_shape) const

Reshape base dimensions.

◆ base_size()

template<class Derived >
Size base_size ( Size index) const

Return the size of a base axis.

◆ base_sizes()

template<class Derived >
TensorShapeRef base_sizes ( ) const

Return the base size.

◆ base_storage()

template<class Derived >
Size base_storage ( ) const

Return the flattened storage needed just for the base indices.

◆ base_transpose()

template<class Derived >
neml2::Tensor base_transpose ( Size d1,
Size d2 ) const

Transpose two base dimensions.

◆ base_unsqueeze()

template<class Derived >
neml2::Tensor base_unsqueeze ( Size d) const

Unsqueeze a base dimension.

◆ batch_dim()

template<class Derived >
Size batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_expand() [1/2]

template<class Derived >
Derived batch_expand ( const TraceableSize & batch_size,
Size dim ) const

Return a new view of the tensor with values broadcast along a given batch dimension.

◆ batch_expand() [2/2]

template<class Derived >
Derived batch_expand ( const TraceableTensorShape & batch_shape) const

Return a new view of the tensor with values broadcast along the batch dimensions.

◆ batch_expand_as()

template<class Derived >
template<class Derived2 >
Derived batch_expand_as ( const Derived2 & other) const

Expand the batch to have the same shape as another tensor.

◆ batch_expand_copy()

template<class Derived >
Derived batch_expand_copy ( const TraceableTensorShape & batch_shape) const

Return a new tensor with values broadcast along the batch dimensions.
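batch_expand returns a view while batch_expand_copy materializes a new tensor, but both should produce the same shape: the target batch shape followed by the unchanged base shape. The sketch below computes that shape under the assumption that the usual torch-style expand rules apply (existing batch axes are right-aligned with the target and must match it or have size 1; extra leading target axes are new). `batch_expand_shape` is a hypothetical helper, and -1 ("keep this size") entries are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Size = std::int64_t;       // hypothetical stand-in for neml2::Size
using Shape = std::vector<Size>; // hypothetical stand-in for a tensor shape

// Full shape after expanding the batch part to `target`, keeping the base part.
Shape batch_expand_shape(const Shape & batch, const Shape & base, const Shape & target)
{
  assert(target.size() >= batch.size());
  const std::size_t offset = target.size() - batch.size();
  for (std::size_t i = 0; i < batch.size(); ++i)
    assert(batch[i] == target[offset + i] || batch[i] == 1); // broadcastable?

  Shape full = target;                               // new batch sizes
  full.insert(full.end(), base.begin(), base.end()); // base sizes unchanged
  return full;
}
```

For example, a batch size of (1, 3) with base size (3, 3) expanded to the batch shape (10, 4, 3) gives the full shape (10, 4, 3, 3, 3).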

◆ batch_index()

template<class Derived >
Derived batch_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the batch dimensions.

◆ batch_index_put_() [1/2]

template<class Derived >
void batch_index_put_ ( indexing::TensorIndicesRef indices,
const torch::Tensor & other )

Set values by slicing on the batch dimensions.

◆ batch_index_put_() [2/2]

template<class Derived >
void batch_index_put_ ( indexing::TensorIndicesRef indices,
Real v )

◆ batch_reshape()

template<class Derived >
Derived batch_reshape ( const TraceableTensorShape & batch_shape) const

Reshape batch dimensions.

◆ batch_size()

template<class Derived >
TraceableSize batch_size ( Size index) const

Return the size of a batch axis.

◆ batch_sizes()

template<class Derived >
const TraceableTensorShape & batch_sizes ( ) const

Return the batch size.

◆ batch_transpose()

template<class Derived >
Derived batch_transpose ( Size d1,
Size d2 ) const

Transpose two batch dimensions.

◆ batch_unsqueeze()

template<class Derived >
Derived batch_unsqueeze ( Size d) const

Unsqueeze a batch dimension.

◆ batched()

template<class Derived >
bool batched ( ) const

Whether the tensor is batched.

◆ clone()

template<class Derived >
Derived clone ( torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const

◆ detach()

template<class Derived >
Derived detach ( ) const

Discard function graph.

◆ empty_like()

template<class Derived >
Derived empty_like ( const Derived & other)
static

Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.

◆ full_like()

template<class Derived >
Derived full_like ( const Derived & other,
Real init )
static

Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the value init.

◆ linspace()

template<class Derived >
Derived linspace ( const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0 )
static

Create a new tensor by adding a new batch dimension with linear spacing between start and end.

start and end must be broadcastable. The new batch dimension is inserted at the user-specified dimension dim, which defaults to 0.

For example, if start has shape (3, 2; 5, 5) and end has shape (3, 1; 5, 5), where the semicolon separates the batch sizes from the base sizes, then

linspace(start, end, 100, 1);

will have shape (3, 100, 2; 5, 5). Note where the new dimension is inserted, and that the batch sizes 2 and 1 are broadcast to 2.

Parameters
start: The starting tensor
end: The ending tensor
nstep: The number of steps with even spacing along the new dimension
dim: Where to insert the new dimension
Returns
Tensor Linearly spaced tensor
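The shape rule in the example above can be checked with a small standalone sketch: broadcast the two batch shapes, then insert nstep at position dim. Only the batch part is modeled (the base shapes in the example are already equal). The value helper assumes the usual endpoint-inclusive linspace formula, start + i (end - start) / (nstep - 1); all names here are hypothetical and not NEML2 API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Size = std::int64_t;       // hypothetical stand-in for neml2::Size
using Shape = std::vector<Size>; // hypothetical stand-in for a tensor shape

// Standard right-aligned broadcasting: missing axes count as size 1.
Shape broadcast_shapes(const Shape & a, const Shape & b)
{
  const std::size_t n = std::max(a.size(), b.size());
  Shape out(n, 1);
  for (std::size_t i = 0; i < n; ++i)
  {
    const Size x = i < a.size() ? a[a.size() - 1 - i] : 1;
    const Size y = i < b.size() ? b[b.size() - 1 - i] : 1;
    assert(x == y || x == 1 || y == 1);
    out[n - 1 - i] = (x == 1) ? y : x;
  }
  return out;
}

// Batch shape of linspace(start, end, nstep, dim).
Shape linspace_batch_shape(const Shape & start_batch, const Shape & end_batch, Size nstep, Size dim)
{
  Shape out = broadcast_shapes(start_batch, end_batch);
  assert(dim >= 0 && dim <= static_cast<Size>(out.size()));
  out.insert(out.begin() + dim, nstep);
  return out;
}

// Value at step i for scalar endpoints (assumed endpoint-inclusive formula).
double linspace_value(double start, double end, Size nstep, Size i)
{
  assert(nstep >= 2 && i >= 0 && i < nstep);
  return start + (end - start) * static_cast<double>(i) / static_cast<double>(nstep - 1);
}
```

With start batch (3, 2), end batch (3, 1), nstep = 100 and dim = 1, this reproduces the batch size (3, 100, 2) from the example.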

◆ logspace()

template<class Derived >
Derived logspace ( const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0,
Real base = 10 )
static

Log-space equivalent of the linspace named constructor.
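A common convention for logspace (used, for instance, by torch::logspace) is to raise base to exponents that are linearly spaced from start to end, endpoints included. The scalar sketch below assumes NEML2 follows the same convention; it is an illustration with a hypothetical helper name, not NEML2's implementation, and the shape handling is the same as for linspace.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

using Size = std::int64_t; // hypothetical stand-in for neml2::Size

// base^e for nstep exponents e spaced linearly from start to end (inclusive).
std::vector<double> logspace_values(double start, double end, Size nstep, double base = 10)
{
  assert(nstep >= 2);
  std::vector<double> out;
  out.reserve(static_cast<std::size_t>(nstep));
  for (Size i = 0; i < nstep; ++i)
  {
    const double exponent =
        start + (end - start) * static_cast<double>(i) / static_cast<double>(nstep - 1);
    out.push_back(std::pow(base, exponent));
  }
  return out;
}
```

Under this convention, exponents from 0 to 2 in 3 steps with the default base 10 give 1, 10 and 100.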

◆ ones_like()

template<class Derived >
Derived ones_like ( const Derived & other)
static

Tensor filled with ones, like another, i.e. same batch and base shapes, same tensor options, etc.

◆ operator-()

template<class Derived >
Derived operator- ( ) const

Negation.

◆ to()

template<class Derived >
Derived to ( const torch::TensorOptions & options) const

Change tensor options.

◆ variable_data()

template<class Derived >
Derived variable_data ( ) const

Variable data without function graph.

◆ zeros_like()

template<class Derived >
Derived zeros_like ( const Derived & other)
static

Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.