NEML2 2.0.0
|
NEML2's enhanced tensor type. More...
NEML2's enhanced tensor type.
neml2::TensorBase derives from torch::Tensor and clearly distinguishes "batched" dimensions from other dimensions. The shape of the "batched" dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.
#include <TensorBase.h>
Public Member Functions | |
TensorBase ()=default | |
Default constructor. | |
TensorBase (const torch::Tensor &tensor, Size batch_dim) | |
Construct from another torch::Tensor with given batch dimension. | |
TensorBase (const torch::Tensor &tensor, const TraceableTensorShape &batch_shape) | |
Construct from another torch::Tensor with given batch shape. | |
template<class Derived2 > | |
TensorBase (const TensorBase< Derived2 > &tensor) | |
Copy constructor. | |
TensorBase (Real)=delete | |
Derived | variable_data () const |
Meta operations | |
Clone (take ownership) | |
Derived | clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const |
Derived | detach () const |
Discard function graph. | |
Derived | to (const torch::TensorOptions &options) const |
Change tensor options. | |
Derived | operator- () const |
Negation. | |
Tensor information | |
Tensor options | |
bool | batched () const |
Whether the tensor is batched. | |
Size | batch_dim () const |
Return the number of batch dimensions. | |
Size | base_dim () const |
Return the number of base dimensions. | |
const TraceableTensorShape & | batch_sizes () const |
Return the batch size. | |
TraceableSize | batch_size (Size index) const |
Return the size of a batch axis. | |
TensorShapeRef | base_sizes () const |
Return the base size. | |
Size | base_size (Size index) const |
Return the size of a base axis. | |
Size | base_storage () const |
Return the flattened storage needed just for the base indices. | |
Getter and setter | |
Regular tensor indexing | |
Derived | batch_index (indexing::TensorIndicesRef indices) const |
Get a tensor by slicing on the batch dimensions. | |
neml2::Tensor | base_index (indexing::TensorIndicesRef indices) const |
Get a tensor by slicing on the base dimensions. | |
void | batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other) |
void | batch_index_put_ (indexing::TensorIndicesRef indices, Real v) |
void | base_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other) |
void | base_index_put_ (indexing::TensorIndicesRef indices, Real v) |
Modifiers | |
Return a new view of the tensor with values broadcast along the batch dimensions. | |
Derived | batch_expand (const TraceableTensorShape &batch_shape) const |
Derived | batch_expand (const TraceableSize &batch_size, Size dim) const |
Return a new view of the tensor with values broadcast along a given batch dimension. | |
neml2::Tensor | base_expand (TensorShapeRef base_shape) const |
Return a new view of the tensor with values broadcast along the base dimensions. | |
neml2::Tensor | base_expand (Size base_size, Size dim) const |
Return a new view of the tensor with values broadcast along a given base dimension. | |
template<class Derived2 > | |
Derived | batch_expand_as (const Derived2 &other) const |
Expand the batch to have the same shape as another tensor. | |
template<class Derived2 > | |
Derived2 | base_expand_as (const Derived2 &other) const |
Expand the base to have the same shape as another tensor. | |
Derived | batch_expand_copy (const TraceableTensorShape &batch_shape) const |
Return a new tensor with values broadcast along the batch dimensions. | |
neml2::Tensor | base_expand_copy (TensorShapeRef base_shape) const |
Return a new tensor with values broadcast along the base dimensions. | |
Derived | batch_reshape (const TraceableTensorShape &batch_shape) const |
Reshape batch dimensions. | |
neml2::Tensor | base_reshape (TensorShapeRef base_shape) const |
Reshape base dimensions. | |
Derived | batch_unsqueeze (Size d) const |
Unsqueeze a batch dimension. | |
neml2::Tensor | base_unsqueeze (Size d) const |
Unsqueeze a base dimension. | |
Derived | batch_transpose (Size d1, Size d2) const |
Transpose two batch dimensions. | |
neml2::Tensor | base_transpose (Size d1, Size d2) const |
Transpose two base dimensions. | |
neml2::Tensor | base_flatten () const |
Flatten base dimensions. | |
Static Public Member Functions | |
static Derived | empty_like (const Derived &other) |
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | zeros_like (const Derived &other) |
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | ones_like (const Derived &other) |
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | full_like (const Derived &other, Real init) |
static Derived | linspace (const Derived &start, const Derived &end, Size nstep, Size dim=0) |
Create a new tensor by adding a new batch dimension with linear spacing between start and end. | |
static Derived | logspace (const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10) |
Log-space equivalent of the linspace named constructor. | |
|
default |
Default constructor.
TensorBase | ( | const torch::Tensor & | tensor, |
Size | batch_dim ) |
Construct from another torch::Tensor with given batch dimension.
TensorBase | ( | const torch::Tensor & | tensor, |
const TraceableTensorShape & | batch_shape ) |
Construct from another torch::Tensor with given batch shape.
TensorBase | ( | const TensorBase< Derived2 > & | tensor | ) |
Copy constructor.
|
delete |
neml2::Tensor base_expand | ( | Size | base_size, |
Size | dim ) const |
Return a new view of the tensor with values broadcast along a given base dimension.
neml2::Tensor base_expand | ( | TensorShapeRef | base_shape | ) | const |
Return a new view of the tensor with values broadcast along the base dimensions.
Expand the base to have the same shape as another tensor.
neml2::Tensor base_expand_copy | ( | TensorShapeRef | base_shape | ) | const |
Return a new tensor with values broadcast along the base dimensions.
neml2::Tensor base_flatten | ( | ) | const |
Flatten base dimensions.
neml2::Tensor base_index | ( | indexing::TensorIndicesRef | indices | ) | const |
Get a tensor by slicing on the base dimensions.
void base_index_put_ | ( | indexing::TensorIndicesRef | indices, |
const torch::Tensor & | other ) |
Set values by slicing on the base dimensions.
void base_index_put_ | ( | indexing::TensorIndicesRef | indices, |
Real | v ) |
neml2::Tensor base_reshape | ( | TensorShapeRef | base_shape | ) | const |
Reshape base dimensions.
TensorShapeRef base_sizes | ( | ) | const |
Return the base size.
Return the flattened storage needed just for the base indices.
neml2::Tensor base_transpose | ( | Size | d1, |
Size | d2 ) const |
Transpose two base dimensions.
neml2::Tensor base_unsqueeze | ( | Size | d | ) | const |
Unsqueeze a base dimension.
Derived batch_expand | ( | const TraceableSize & | batch_size, |
Size | dim ) const |
Return a new view of the tensor with values broadcast along a given batch dimension.
Derived batch_expand | ( | const TraceableTensorShape & | batch_shape | ) | const |
Return a new view of the tensor with values broadcast along the batch dimensions.
Expand the batch to have the same shape as another tensor.
Derived batch_expand_copy | ( | const TraceableTensorShape & | batch_shape | ) | const |
Return a new tensor with values broadcast along the batch dimensions.
Derived batch_index | ( | indexing::TensorIndicesRef | indices | ) | const |
Get a tensor by slicing on the batch dimensions.
void batch_index_put_ | ( | indexing::TensorIndicesRef | indices, |
const torch::Tensor & | other ) |
Set values by slicing on the batch dimensions.
void batch_index_put_ | ( | indexing::TensorIndicesRef | indices, |
Real | v ) |
Derived batch_reshape | ( | const TraceableTensorShape & | batch_shape | ) | const |
Reshape batch dimensions.
TraceableSize batch_size | ( | Size | index | ) | const |
Return the size of a batch axis.
const TraceableTensorShape & batch_sizes | ( | ) | const |
Return the batch size.
Transpose two batch dimensions.
Derived clone | ( | torch::MemoryFormat | memory_format = torch::MemoryFormat::Contiguous | ) | const |
Derived detach | ( | ) | const |
Discard function graph.
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the given value init.
|
static |
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
start and end must be broadcastable. The new batch dimension will be added at the user-specified dimension dim, which defaults to 0.
For example, if start has shape (3, 2; 5, 5) and end has shape (3, 1; 5, 5), then linspace(start, end, 100, 1) will have shape (3, 100, 2; 5, 5). Note the location of the new dimension and the broadcasting.
start | The starting tensor |
end | The ending tensor |
nstep | The number of steps with even spacing along the new dimension |
dim | Where to insert the new dimension |
|
static |
Log-space equivalent of the linspace named constructor.
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Change tensor options.