NEML2 2.0.0
TraceableTensorShape Struct Reference

Traceable tensor shape.

Detailed Description

Traceable tensor shape.

A tensor shape can be either a concrete shape or a traceable tensor. This is useful when tracing a function graph that must generalize to batch shapes other than the one used during tracing.
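To illustrate a shape whose entries are either concrete sizes or traced values, here is a minimal, self-contained C++ sketch. It does not use NEML2 or libtorch. `ToyTraceableShape`, `TracedDim`, and `Dim` are hypothetical names, and `TracedDim` (a symbolic label plus the size observed during tracing) only stands in for the traceable `torch::Tensor` that the real type can hold. The `concrete()` member mimics the documented idea of dropping all tracing information.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Mirrors the documented typedef `using Size = int64_t`.
using Size = std::int64_t;

// Hypothetical stand-in for a traced size; NEML2 uses a (traceable) torch::Tensor instead.
struct TracedDim
{
  std::string name; // symbolic label recorded while tracing
  Size value;       // size observed while the graph was traced
};

// Each entry of the shape is either a concrete size or a traced one.
using Dim = std::variant<Size, TracedDim>;

struct ToyTraceableShape
{
  std::vector<Dim> dims;

  // Count the entries that carry tracing information.
  std::size_t num_traced() const
  {
    std::size_t n = 0;
    for (const auto & d : dims)
      if (std::holds_alternative<TracedDim>(d))
        ++n;
    return n;
  }

  // Discard tracing information and keep only the observed sizes.
  std::vector<Size> concrete() const
  {
    std::vector<Size> out;
    out.reserve(dims.size());
    for (const auto & d : dims)
      out.push_back(std::holds_alternative<Size>(d) ? std::get<Size>(d)
                                                    : std::get<TracedDim>(d).value);
    return out;
  }
};
```

For example, a toy shape (B, 3, 3) whose batch size B was traced at 8 has one traced entry and concretizes to {8, 3, 3}.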

#include <types.h>


Public Types

using Size = int64_t

Public Member Functions

TraceableTensorShape (const TensorShape &shape)

TraceableTensorShape (TensorShapeRef shape)

TraceableTensorShape (Size shape)

TraceableTensorShape (const torch::Tensor &shape)

TraceableTensorShape slice (Size start, Size end) const
  Slice the shape, semantically the same as ArrayRef::slice, but traceable.

TraceableTensorShape slice (Size N) const
  Chop off the first N elements of the shape, semantically the same as ArrayRef::slice, but traceable.

TensorShape concrete () const

torch::Tensor as_tensor () const

Member Typedef Documentation

◆ Size

using Size = int64_t

Constructor & Destructor Documentation

◆ TraceableTensorShape() [1/4]

◆ TraceableTensorShape() [2/4]

◆ TraceableTensorShape() [3/4]

◆ TraceableTensorShape() [4/4]

TraceableTensorShape (const torch::Tensor &shape)

Member Function Documentation

◆ as_tensor()

torch::Tensor as_tensor () const
Returns
the shape represented as a scalar tensor (possibly traceable)

◆ concrete()

TensorShape concrete () const
Returns
the concrete shape (without any traceable information)

◆ slice() [1/2]

TraceableTensorShape slice (Size N) const

Chop off the first N elements of the shape, semantically the same as ArrayRef::slice, but traceable.

◆ slice() [2/2]

TraceableTensorShape slice (Size start, Size end) const

Slice the shape, semantically the same as ArrayRef::slice, but traceable.
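The two `slice` overloads can be sketched on a toy shape whose entries are either plain integers or traced sizes. This is a hypothetical, self-contained C++ illustration, not NEML2's implementation: `ToyShape`, `TracedDim`, `Dim`, and `is_traced` are made-up names. It assumes `slice(start, end)` selects the half-open range [start, end), as the parameter names suggest. Note that `c10::ArrayRef::slice(N, M)` instead takes `M` elements starting at index `N`, so the two readings agree only when `start == 0`; check the NEML2 source for the convention it actually uses. The point illustrated is that slicing copies the entries themselves, so a traced entry stays traced in the result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

using Size = std::int64_t;

// Hypothetical stand-in for a traced size (NEML2 stores a torch::Tensor).
struct TracedDim
{
  std::string name; // symbolic label recorded while tracing
  Size value;       // size observed while tracing
};

using Dim = std::variant<Size, TracedDim>;

struct ToyShape
{
  std::vector<Dim> dims;

  // Assumed semantics: keep the half-open range [start, end).
  ToyShape slice(Size start, Size end) const
  {
    if (start < 0 || end < start || end > static_cast<Size>(dims.size()))
      throw std::out_of_range("slice: invalid range");
    // Copying the variants preserves any tracing information they carry.
    return ToyShape{std::vector<Dim>(dims.begin() + start, dims.begin() + end)};
  }

  // Chop off the first N entries, like ArrayRef::slice(N).
  ToyShape slice(Size N) const { return slice(N, static_cast<Size>(dims.size())); }

  // True if entry i holds a traced size rather than a plain integer.
  bool is_traced(std::size_t i) const { return std::holds_alternative<TracedDim>(dims.at(i)); }
};
```

For example, for a toy shape (B, 4, 3, 3) with a traced batch size B, `slice(1)` yields (4, 3, 3) and `slice(0, 2)` yields (B, 4), with the first entry still traced.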