32#include "neml2/tensors/TensorBase.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/tensors/assertions.h"
35#include "neml2/jit/types.h"
36#include "neml2/jit/utils.h"
40template <
class Derived>
49 " is smaller than the requested number of batch dimensions ",
53template <
class Derived>
56 _batch_dim(
Size(batch_shape.size())),
57 _batch_sizes(batch_shape)
62 " cannot be constructed with batch shape ",
66template <
class Derived>
72template <
class Derived>
76 return Derived(at::empty_like(other), other.batch_sizes());
79template <
class Derived>
83 return Derived(at::zeros_like(other), other.batch_sizes());
86template <
class Derived>
90 return Derived(at::ones_like(other), other.batch_sizes());
93template <
class Derived>
97 return Derived(at::full_like(other, init), other.batch_sizes());
100template <
class Derived>
107 auto res = start.batch_unsqueeze(dim);
115 net.push_back(indexing::Ellipsis);
116 net.insert(net.end(), Bd - dim, indexing::None);
117 Scalar steps(at::arange(nstep, diff.options()).index(net) / (nstep - 1));
119 res = res + steps * diff;
125template <
class Derived>
128 const Derived & start,
const Derived & end,
Size nstep,
Size dim,
Real base)
131 return Derived(at::pow(base, exponent), exponent.batch_sizes());
134template <
class Derived>
141template <
class Derived>
148template <
class Derived>
155template <
class Derived>
162template <
class Derived>
169template <
class Derived>
176template <
class Derived>
183template <
class Derived>
187 const auto i = index >= 0 ? index : index +
batch_dim();
190 if (jit::tracer::isTracing())
191 return jit::tracer::getSizeOf(*
this, i);
196template <
class Derived>
200 return sizes().slice(_batch_dim);
203template <
class Derived>
210template <
class Derived>
217template <
class Derived>
222 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
223 auto res = this->index(indices_vec);
224 return Derived(res, res.dim() -
base_dim());
227template <
class Derived>
232 indices2.insert(indices2.end(), indices.begin(), indices.end());
236template <
class Derived>
240 auto i = dim >= 0 ? dim : this->dim() + dim -
base_dim();
241 auto res = this->slice(
242 i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
243 return Derived(res, res.dim() -
base_dim());
246template <
class Derived>
250 auto i = dim < 0 ? this->dim() + dim : dim +
batch_dim();
251 auto res = this->slice(
252 i, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
256template <
class Derived>
261 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
262 this->index_put_(indices_vec, other);
265template <
class Derived>
270 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
271 this->index_put_(indices_vec, v);
274template <
class Derived>
279 indices2.insert(indices2.end(), indices.begin(), indices.end());
280 this->index_put_(indices2, other);
283template <
class Derived>
288 indices2.insert(indices2.end(), indices.begin(), indices.end());
289 this->index_put_(indices2, v);
292template <
class Derived>
296 return Derived(ATensor::variable_data(),
batch_sizes());
299template <
class Derived>
305 net.insert(net.end(),
base_dim(), -1);
308 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
309 if (
const auto *
const si = batch_shape[i].traceable())
310 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
312 return Derived(expand(net), batch_shape);
315template <
class Derived>
319 auto i = dim >= 0 ? dim : this->dim() + dim -
base_dim();
322 return Derived(*
this);
328template <
class Derived>
336 auto net = base_shape.vec();
337 net.insert(net.begin(),
batch_dim(), -1);
341template <
class Derived>
349 auto net = std::vector<Size>(this->dim(), -1);
350 auto i = dim < 0 ? this->dim() + dim : dim +
batch_dim();
355template <
class Derived>
362template <
class Derived>
369template <
class Derived>
373 return Derived(
batch_expand(batch_shape).contiguous(), batch_shape);
376template <
class Derived>
383template <
class Derived>
388 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
389 if (
const auto *
const si = batch_shape[i].traceable())
390 jit::tracer::ArgumentStash::stashIntArrayRefElem(
391 "shape", batch_shape.size() +
base_dim(), i, *si);
396template <
class Derived>
403 for (
Size i = 0; i < (
Size)batch_shape.size(); ++i)
404 if (
const auto *
const si = batch_shape[i].traceable())
405 jit::tracer::ArgumentStash::stashIntArrayRefElem(
406 "shape", batch_shape.size() + base_shape.size(), i, *si);
412template <
class Derived>
416 auto d2 = d >= 0 ? d : d -
base_dim();
417 return Derived(unsqueeze(d2), _batch_dim + 1);
420template <
class Derived>
428template <
class Derived>
432 return Derived(ATensor::transpose(d1 < 0 ? d1 -
base_dim() : d1, d2 < 0 ? d2 -
base_dim() : d2),
436template <
class Derived>
441 ATensor::transpose(d1 < 0 ? d1 : _batch_dim + d1, d2 < 0 ? d2 : _batch_dim + d2),
445template <
class Derived>
455template <
class Derived>
203template <
class Derived> {
…}
191 return jit::tracer::getSizeOf(*
this, i); {
…}
169template <
class Derived> {
…}
148template <
class Derived> {
…}
100template <
class Derived> {
…}
97 return Derived(at::full_like(other, init), other.batch_sizes()); {
…}
72template <
class Derived> {
…}
Scalar.
Definition Scalar.h:38
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:447
Derived batch_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a batch dimension.
Definition TensorBaseImpl.h:238
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:385
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:219
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:438
neml2::Tensor base_slice(Size dim, const indexing::Slice &index) const
Get a tensor by slicing along a base dimension.
Definition TensorBaseImpl.h:248
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:185
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:212
TensorBase()=default
Special member functions.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:157
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:164
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:143
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:457
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:430
Derived batch_expand_as(const neml2::Tensor &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBaseImpl.h:357
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBaseImpl.h:364
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:378
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:422
Derived clone() const
Definition TensorBaseImpl.h:136
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:205
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:150
const TraceableTensorShape & batch_sizes() const
Definition TensorBaseImpl.h:178
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:276
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:258
MillerIndex batch_unsqueeze(Size d) const
Definition TensorBaseImpl.h:414
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:371
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:171
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:198
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:398
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:301
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:330
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:229
Derived variable_data() const
Definition TensorBaseImpl.h:294
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
log-space equivalent of the linspace named constructor
Definition TensorBaseImpl.h:127
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:81
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:74
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:102
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:95
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:88
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:78
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition types.h:77
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:30
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
TensorShape add_shapes(const S &...)
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
at::Tensor ATensor
Definition types.h:42
double Real
Definition types.h:68
int64_t Size
Definition types.h:69
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const
Definition TraceableTensorShape.cxx:78