27#include <ATen/core/Tensor.h>
29#include "neml2/tensors/TraceableSize.h"
30#include "neml2/tensors/TraceableTensorShape.h"
31#include "neml2/tensors/functions/operators.h"
32#include "neml2/tensors/functions/logical.h"
33#include "neml2/tensors/indexing.h"
34#include "neml2/tensors/macros.h"
39template <
class Derived>
75template <
class Derived>
89 template <
class Derived2>
102 [[nodiscard]]
static Derived
empty_like(
const Derived & other);
104 [[nodiscard]]
static Derived
zeros_like(
const Derived & other);
106 [[nodiscard]]
static Derived
ones_like(
const Derived & other);
109 [[nodiscard]]
static Derived
full_like(
const Derived & other,
const CScalar & init);
112 [[nodiscard]]
static Derived
rand_like(
const Derived & other);
120 Derived
clone()
const;
124 using ATensor::detach_;
128 using ATensor::copy_;
130 using ATensor::zero_;
132 using ATensor::requires_grad;
134 using ATensor::requires_grad_;
142 using ATensor::defined;
144 using ATensor::options;
146 using ATensor::scalar_type;
148 using ATensor::device;
163 using ATensor::sizes;
182 using ATensor::index;
183 using ATensor::index_put_;
311#define EXPORT_TENSORBASE(T) extern template class TensorBase<T>
312FOR_ALL_TENSORBASE(EXPORT_TENSORBASE);
313#undef EXPORT_TENSORBASE
NEML2's enhanced tensor type.
Definition TensorBase.h:77
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:834
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:306
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:816
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:447
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:783
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:713
TensorBase(double)=delete
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:747
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:792
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:296
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:774
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:564
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:571
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:687
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:843
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:472
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:406
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:557
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:550
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:389
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:427
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:807
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:701
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:756
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:585
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:364
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:738
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:536
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:825
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:316
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:637
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
TensorBase(const TensorBase< Derived2 > &tensor)
Copy constructor.
Definition TensorBase.h:90
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:671
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:326
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:618
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:344
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:725
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:679
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:543
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:382
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
Definition DiagnosticsInterface.cxx:30
at::Tensor ATensor
Definition types.h:38
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38