NEML2 2.0.0
Traceable tensor shape.
A tensor shape can be either a concrete shape (a list of integer sizes) or a traceable tensor holding the sizes. The traceable form is useful when tracing a function graph that should generalize to batch shapes other than the ones seen during tracing.
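To illustrate the "either concrete or traceable" idea, here is a minimal sketch of one way to model it with `std::variant`. It is not NEML2's implementation: `ShapeSketch`, `ShapeTensor`, `ConcreteShape`, `is_traceable`, and `from_size` are hypothetical names, and `ShapeTensor` is a plain struct standing in for a 1-D integer `torch::Tensor` so the example compiles without libtorch. `from_size` assumes that the `Size` constructor builds a one-dimensional shape.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

using Size = int64_t;
using ConcreteShape = std::vector<Size>;

// Hypothetical stand-in for a 1-D integer torch::Tensor holding the sizes.
// In a traced graph these values would come from recorded operations
// rather than being baked in as constants.
struct ShapeTensor
{
  std::vector<Size> values;
};

// Sketch: a shape is either concrete or tensor-backed (traceable).
using ShapeSketch = std::variant<ConcreteShape, ShapeTensor>;

// True if the shape is held in its tensor (traceable) form.
bool is_traceable(const ShapeSketch & s)
{
  return std::holds_alternative<ShapeTensor>(s);
}

// Assumed analogue of the Size constructor: one size becomes a 1-element shape.
ShapeSketch from_size(Size n)
{
  return ConcreteShape{n};
}
```

A concrete shape such as `ConcreteShape{2, 3}` and a tensor-backed shape such as `ShapeTensor{{2, 3}}` describe the same sizes; only the second would stay symbolic under tracing.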
#include <types.h>
Public Types
  using Size = int64_t

Public Member Functions
  TraceableTensorShape (const TensorShape &shape)
  TraceableTensorShape (TensorShapeRef shape)
  TraceableTensorShape (Size shape)
  TraceableTensorShape (const torch::Tensor &shape)
  TraceableTensorShape slice (Size start, Size end) const
      Slice the shape, semantically the same as ArrayRef::slice, but traceable.
  TraceableTensorShape slice (Size N) const
      Chop off the first N elements of the shape, semantically the same as ArrayRef::slice, but traceable.
  TensorShape concrete () const
  torch::Tensor as_tensor () const
TraceableTensorShape (const TensorShape &shape)
TraceableTensorShape (TensorShapeRef shape)
TraceableTensorShape (Size shape)
TraceableTensorShape (const torch::Tensor &shape)
torch::Tensor as_tensor () const
Return the shape as a torch::Tensor.
TensorShape concrete () const
Return the shape as a concrete TensorShape.
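The two accessors convert between the representations. The sketch below is a hypothetical analogue built on the same stand-in types as before (`ShapeSketch`, `ShapeTensor`, `ConcreteShape` are illustrative names, not NEML2 API; `ShapeTensor` replaces `torch::Tensor` so the code builds without libtorch). It is one plausible reading of what `concrete()` and `as_tensor()` return, not the library's code.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

using Size = int64_t;
using ConcreteShape = std::vector<Size>;

// Hypothetical stand-in for a 1-D integer torch::Tensor.
struct ShapeTensor
{
  std::vector<Size> values;
};

using ShapeSketch = std::variant<ConcreteShape, ShapeTensor>;

// Analogue of concrete(): plain integers regardless of representation.
ConcreteShape concrete(const ShapeSketch & s)
{
  if (const auto * c = std::get_if<ConcreteShape>(&s))
    return *c;
  // Real code would read the values out of the tensor here.
  return std::get<ShapeTensor>(s).values;
}

// Analogue of as_tensor(): the tensor form, wrapping concrete sizes if needed.
ShapeTensor as_tensor(const ShapeSketch & s)
{
  if (const auto * t = std::get_if<ShapeTensor>(&s))
    return *t;
  return ShapeTensor{std::get<ConcreteShape>(s)};
}
```

One design consequence worth noting: under tracing, reading sizes out of a tensor as plain integers generally freezes them into the graph as constants, so code that must generalize across batch shapes presumably prefers `as_tensor()` over `concrete()`.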
TraceableTensorShape slice (Size N) const
Chop off the first N elements of the shape, semantically the same as ArrayRef::slice, but traceable.
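What `slice(N)` computes can be sketched on both representations as follows. This is a hypothetical analogue, not NEML2's implementation; `ShapeSketch`, `ShapeTensor`, and `ConcreteShape` are illustrative names, and `ShapeTensor` is a plain struct standing in for `torch::Tensor`. The sketch only covers the single-argument overload, whose meaning (drop the first N sizes) is unambiguous.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

using Size = int64_t;
using ConcreteShape = std::vector<Size>;

// Hypothetical stand-in for a 1-D integer torch::Tensor.
struct ShapeTensor
{
  std::vector<Size> values;
};

using ShapeSketch = std::variant<ConcreteShape, ShapeTensor>;

// Analogue of slice(N): drop the first N sizes, keeping the representation.
ShapeSketch slice(const ShapeSketch & s, Size N)
{
  if (const auto * c = std::get_if<ConcreteShape>(&s))
  {
    assert(N >= 0 && N <= static_cast<Size>(c->size()));
    return ConcreteShape(c->begin() + N, c->end());
  }
  const auto & v = std::get<ShapeTensor>(s).values;
  assert(N >= 0 && N <= static_cast<Size>(v.size()));
  // With libtorch this branch would be a tensor slice, which a tracer can record.
  return ShapeTensor{std::vector<Size>(v.begin() + N, v.end())};
}
```

For example, slicing `{10, 3, 3}` with `N = 1` leaves `{3, 3}`, such as when dropping a leading batch dimension. Keeping the tensor branch as a tensor operation is what would let the traced graph follow a different leading size instead of hard-coding 10.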
TraceableTensorShape slice (Size start, Size end) const
Slice the shape, semantically the same as ArrayRef::slice, but traceable.