27#include <ATen/core/Tensor.h>
29#include "neml2/tensors/TraceableSize.h"
30#include "neml2/tensors/TraceableTensorShape.h"
31#include "neml2/tensors/functions/operators.h"
32#include "neml2/tensors/functions/logical.h"
33#include "neml2/tensors/indexing.h"
34#include "neml2/tensors/macros.h"
39template <
class Derived>
75template <
class Derived>
89 template <
class Derived2>
102 [[nodiscard]]
static Derived
empty_like(
const Derived & other);
104 [[nodiscard]]
static Derived
zeros_like(
const Derived & other);
106 [[nodiscard]]
static Derived
ones_like(
const Derived & other);
109 [[nodiscard]]
static Derived
full_like(
const Derived & other,
const CScalar & init);
112 [[nodiscard]]
static Derived
rand_like(
const Derived & other);
120 Derived
clone()
const;
124 using ATensor::detach_;
128 using ATensor::copy_;
130 using ATensor::zero_;
132 using ATensor::requires_grad;
134 using ATensor::requires_grad_;
142 using ATensor::defined;
144 using ATensor::options;
146 using ATensor::scalar_type;
148 using ATensor::device;
163 using ATensor::sizes;
182 using ATensor::index;
183 using ATensor::index_put_;
311 TraceableTensorShape _dynamic_sizes;
319#define EXPORT_TENSORBASE(T) extern template class TensorBase<T>
320FOR_ALL_TENSORBASE(EXPORT_TENSORBASE);
321#undef EXPORT_TENSORBASE
NEML2's enhanced tensor type.
Definition TensorBase.h:77
Derived batch_squeeze(Size d) const
Definition TensorBaseImpl.h:724
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:909
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:315
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:891
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:858
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:849
Derived batch_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:296
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:760
TensorBase(double)=delete
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:804
Derived batch_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:335
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:867
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:305
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:840
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:601
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:608
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:734
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:918
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:509
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived batch_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:813
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:443
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:594
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:587
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:426
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:464
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:882
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:748
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:822
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:622
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:692
Derived batch_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:772
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:383
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:795
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:401
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:573
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:900
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:325
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:674
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
TensorBase(const TensorBase< Derived2 > &tensor)
Copy constructor.
Definition TensorBase.h:90
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:708
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:345
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:363
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:782
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:716
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:580
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:640
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:419
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:484
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
Definition DiagnosticsInterface.cxx:30
at::Tensor ATensor
Definition types.h:41
int64_t Size
Definition types.h:68
c10::Scalar CScalar
Definition types.h:42
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:70
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38