27#include <ATen/core/Tensor.h>
29#include "neml2/jit/TraceableTensorShape.h"
30#include "neml2/tensors/functions/operators.h"
31#include "neml2/tensors/indexing.h"
36template <
class Derived>
48template <
class Derived>
71 [[nodiscard]]
static Derived
empty_like(
const Derived & other);
73 [[nodiscard]]
static Derived
zeros_like(
const Derived & other);
75 [[nodiscard]]
static Derived
ones_like(
const Derived & other);
78 [[nodiscard]]
static Derived
full_like(
const Derived & other,
const CScalar & init);
99 [[nodiscard]]
static Derived
100 linspace(
const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0);
102 [[nodiscard]]
static Derived
logspace(
const Derived & start,
112 Derived
clone()
const;
116 using ATensor::detach_;
120 using ATensor::copy_;
122 using ATensor::zero_;
124 using ATensor::requires_grad;
126 using ATensor::requires_grad_;
134 using ATensor::options;
136 using ATensor::scalar_type;
138 using ATensor::device;
142 using ATensor::sizes;
166 using ATensor::index;
167 using ATensor::index_put_;
NEML2's enhanced tensor type.
Definition TensorBase.h:50
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:451
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:242
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:389
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:223
TensorBase(double)=delete
TraceableTensorShape _batch_sizes
Traceable batch sizes.
Definition TensorBase.h:226
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:442
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:252
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:189
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:216
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:161
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:168
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:147
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:461
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:434
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:361
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:368
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:382
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:426
Derived clone() const
Definition TensorBaseImpl.h:140
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:209
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:154
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:182
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:280
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:262
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:418
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:375
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:202
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:402
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:305
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:334
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:233
Derived variable_data() const
Definition TensorBaseImpl.h:298
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:99
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:85
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:78
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:106
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, const CScalar &base=10)
log-space equivalent of the linspace named constructor
Definition TensorBaseImpl.h:131
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:92
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
Definition DiagnosticsInterface.cxx:30
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38