27 #include <ATen/core/Tensor.h>
28 #include "neml2/jit/TraceableTensorShape.h"
29 #include "neml2/tensors/shape_utils.h"
30 #include "neml2/tensors/functions/operators.h"
35 template <
class Derived>
47 template <
class Derived>
68 [[nodiscard]]
static Derived
empty_like(
const Derived & other);
70 [[nodiscard]]
static Derived
zeros_like(
const Derived & other);
72 [[nodiscard]]
static Derived
ones_like(
const Derived & other);
75 [[nodiscard]]
static Derived
full_like(
const Derived & other,
Real init);
96 [[nodiscard]]
static Derived
97 linspace(
const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0);
99 [[nodiscard]]
static Derived
logspace(
const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0,
Real base = 10);
106 Derived
clone()
const;
110 using ATensor::detach_;
114 using ATensor::copy_;
116 using ATensor::zero_;
118 using ATensor::requires_grad;
120 using ATensor::requires_grad_;
128 using ATensor::options;
130 using ATensor::scalar_type;
132 using ATensor::device;
136 using ATensor::sizes;
160 using ATensor::index;
161 using ATensor::index_put_;
220 Size _batch_dim = {};
TensorBase
NEML2's enhanced tensor type.
Definition TensorBase.h:49
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:447
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:238
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:385
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:219
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:438
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:248
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:185
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:212
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:157
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:164
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:143
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:457
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:430
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:357
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:364
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:378
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:422
Derived clone() const
Definition TensorBaseImpl.h:136
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:205
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:150
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:178
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:276
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:258
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:414
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:371
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:171
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:198
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:398
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:301
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:330
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:229
Derived variable_data() const
Definition TensorBaseImpl.h:294
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
Log-space equivalent of the linspace named constructor.
Definition TensorBaseImpl.h:127
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:81
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:74
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:102
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:95
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:88
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:78
Definition DiagnosticsInterface.cxx:30
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72
TraceableSize
Traceable size.
Definition TraceableSize.h:40
TraceableTensorShape
Traceable tensor shape.
Definition TraceableTensorShape.h:38