32#include <torch/csrc/jit/frontend/tracer.h>
34#include "neml2/tensors/TraceableTensorShape.h"
35#include "neml2/tensors/Tensor.h"
36#include "neml2/tensors/TensorBase.h"
37#include "neml2/tensors/Scalar.h"
38#include "neml2/tensors/jit.h"
39#include "neml2/tensors/shape_utils.h"
40#include "neml2/misc/assertions.h"
44using namespace torch::jit;
49template <
class Derived>
58template <
class Derived>
63 _dynamic_sizes(std::move(dynamic_shape)),
69template <
class Derived>
76 " is not sufficient for the requested number of dynamic dimensions (",
78 ") and intermediate dimensions (",
84 " is incompatible with dynamic shape ",
86 ". The leading dimensions must match.");
89template <
class Derived>
93 return Derived(at::empty_like(other), other.dynamic_sizes(), other.intmd_dim());
96template <
class Derived>
100 return Derived(at::zeros_like(other), other.dynamic_sizes(), other.intmd_dim());
103template <
class Derived>
107 return Derived(at::ones_like(other), other.dynamic_sizes(), other.intmd_dim());
110template <
class Derived>
114 return Derived(at::full_like(other, init), other.dynamic_sizes(), other.intmd_dim());
117template <
class Derived>
121 return Derived(at::rand_like(other), other.dynamic_sizes(), other.intmd_dim());
124template <
class Derived>
131template <
class Derived>
138template <
class Derived>
145template <
class Derived>
152template <
class Derived>
159template <
class Derived>
166template <
class Derived>
170 return static_cast<Size>(_dynamic_sizes.size());
173template <
class Derived>
180template <
class Derived>
187template <
class Derived>
194template <
class Derived>
201template <
class Derived>
205 return _dynamic_sizes;
208template <
class Derived>
215template <
class Derived>
222template <
class Derived>
228 return _dynamic_sizes[i];
232template <
class Derived>
240template <
class Derived>
245 return _dynamic_sizes[i];
248template <
class Derived>
256template <
class Derived>
264template <
class Derived>
269 indices_vec.insert(indices_vec.end(),
static_dim(), indexing::Slice());
270 auto res = this->index(indices_vec);
274template <
class Derived>
279 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
280 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
281 auto res = this->index(indices_vec);
285template <
class Derived>
290 indices2.insert(indices2.end(), indices.begin(), indices.end());
294template <
class Derived>
299 "batch_index is only supported when there are no intermediate dimensions.");
303template <
class Derived>
308 auto res = this->slice(
309 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
313template <
class Derived>
318 auto res = this->slice(
319 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
323template <
class Derived>
328 auto res = this->slice(
329 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
333template <
class Derived>
343template <
class Derived>
348 indices_vec.insert(indices_vec.end(),
static_dim(), indexing::Slice());
349 this->index_put_(indices_vec, other);
352template <
class Derived>
357 indices_vec.insert(indices_vec.end(),
static_dim(), indexing::Slice());
358 this->index_put_(indices_vec, v);
361template <
class Derived>
366 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
367 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
368 this->index_put_(indices_vec, other);
371template <
class Derived>
376 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
377 indices_vec.insert(indices_vec.end(),
base_dim(), indexing::Slice());
378 this->index_put_(indices_vec, v);
381template <
class Derived>
386 indices2.insert(indices2.end(), indices.begin(), indices.end());
387 this->index_put_(indices2, other);
390template <
class Derived>
395 indices2.insert(indices2.end(), indices.begin(), indices.end());
396 this->index_put_(indices2, v);
399template <
class Derived>
404 "batch_index_put_ is only supported when there are no intermediate dimensions.");
408template <
class Derived>
413 "batch_index_put_ is only supported when there are no intermediate dimensions.");
417template <
class Derived>
424template <
class Derived>
433 if (jit::tracer::isTracing())
434 for (std::size_t i = 0; i < shape.size(); ++i)
435 if (
const auto *
const si = shape[i].traceable())
436 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
438 return Derived(expand(net), shape,
intmd_dim());
441template <
class Derived>
450 "Invalid intermediate shape to expand. Expected at least ",
457 net.insert(net.end(), shape.begin(), shape.end());
458 net.insert(net.end(),
base_dim(), -1);
462template <
class Derived>
471 "Invalid base shape to expand. Expected at least ",
478 net.insert(net.end(), shape.begin(), shape.end());
482template <
class Derived>
489 "Invalid intermediate shape to expand. Expected at least ",
496 net.insert(net.end(),
base_dim(), -1);
499 if (jit::tracer::isTracing())
500 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
501 if (
const auto *
const si = dynamic_shape[i].traceable())
502 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
504 return Derived(tmp.expand(net), dynamic_shape, tmp.intmd_dim());
507template <
class Derived>
517 "Invalid intermediate shape to expand. Expected at least ",
521 "Invalid base shape to expand. Expected at least ",
525 tmp = tmp.base_unsqueeze(0, base_shape.size() -
base_dim());
532template <
class Derived>
537 return Derived(*
this);
545template <
class Derived>
558template <
class Derived>
571template <
class Derived>
578template <
class Derived>
585template <
class Derived>
592template <
class Derived>
599template <
class Derived>
606template <
class Derived>
611 if (jit::tracer::isTracing())
612 for (std::size_t i = 0; i < shape.size(); ++i)
613 if (
const auto *
const si = shape[i].traceable())
614 jit::tracer::ArgumentStash::stashIntArrayRefElem(
615 "shape", shape.size() +
static_dim(), i, *si);
620template <
class Derived>
627 if (jit::tracer::isTracing())
629 if (
const auto *
const si =
dynamic_size(i).traceable())
630 jit::tracer::ArgumentStash::stashIntArrayRefElem(
638template <
class Derived>
643 if (jit::tracer::isTracing())
645 if (
const auto *
const si =
dynamic_size(i).traceable())
646 jit::tracer::ArgumentStash::stashIntArrayRefElem(
647 "shape",
batch_dim() + shape.size(), i, *si);
653template <
class Derived>
661 if (jit::tracer::isTracing())
662 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
663 if (
const auto *
const si = dynamic_shape[i].traceable())
664 jit::tracer::ArgumentStash::stashIntArrayRefElem(
672template <
class Derived>
679 if (jit::tracer::isTracing())
681 if (
const auto *
const si =
dynamic_size(i).traceable())
682 jit::tracer::ArgumentStash::stashIntArrayRefElem(
690template <
class Derived>
696 "Cannot squeeze dynamic dimension ",
700 ". Only dimensions of size 1 can be squeezed.");
702 sizes.erase(sizes.begin() + d);
703 return Derived(squeeze(d), sizes,
intmd_dim());
706template <
class Derived>
714template <
class Derived>
722template <
class Derived>
732template <
class Derived>
736 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
737 at::Tensor t = *
this;
739 for (
Size i = 0; i < n; ++i)
742 B.insert(B.begin() + d, n, 1);
746template <
class Derived>
750 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
751 at::Tensor t = *
this;
753 for (
Size i = 0; i < n; ++i)
758template <
class Derived>
762 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
763 at::Tensor t = *
this;
765 for (
Size i = 0; i < n; ++i)
770template <
class Derived>
780template <
class Derived>
788 std::swap(sizes[d1], sizes[d2]);
790 return Derived(transpose(d1, d2), sizes,
intmd_dim());
793template <
class Derived>
802template <
class Derived>
811template <
class Derived>
816 "batch_transpose is only supported when there are no intermediate dimensions.");
820template <
class Derived>
828 auto from = sizes.begin() + old_dim;
829 auto to = sizes.begin() + new_dim;
831 std::rotate(from, from + 1,
to + 1);
833 std::rotate(
to, from, from + 1);
835 return Derived(movedim(old_dim, new_dim), sizes,
intmd_dim());
838template <
class Derived>
847template <
class Derived>
856template <
class Derived>
861 "batch_movedim is only supported when there are no intermediate dimensions.");
865template <
class Derived>
879 return Derived(flatten(start_dim, end_dim), {n},
intmd_dim());
882template <
class Derived>
895 return Derived(flatten(start_dim, end_dim),
dynamic_sizes(), 1);
898template <
class Derived>
914template <
class Derived>
925 return Derived(flatten(0,
batch_dim() - 1), {n}, 0);
928template <
class Derived>
941template <
class Derived>
Derived batch_squeeze(Size d) const
Definition TensorBaseImpl.h:724
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:930
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:315
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:858
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:849
Derived batch_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:296
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:760
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:804
Derived batch_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:335
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:305
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:840
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:601
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:608
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:734
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:943
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:509
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived batch_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:813
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:443
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:594
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:587
Derived intmd_flatten(Size start_dim=0, Size end_dim=-1) const
Definition TensorBaseImpl.h:884
neml2::Tensor base_flatten(Size start_dim=0, Size end_dim=-1) const
Definition TensorBaseImpl.h:900
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership).
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:426
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:464
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:748
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:822
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:622
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:692
Derived dynamic_flatten(Size start_dim=0, Size end_dim=-1) const
Definition TensorBaseImpl.h:867
Derived batch_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:772
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:383
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:795
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:401
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:573
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:916
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:325
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:674
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:708
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:345
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:363
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:782
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:716
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:580
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:640
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:419
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:484
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition indexing.h:38
Definition BufferStore.h:43
TensorShape add_shapes(const S &...)
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize an iterator-like index to be non-negative given the lower- and upper-boun...
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition DiagnosticsInterface.h:31
void neml_assert_dbg(bool assertion, Args &&... args)
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
at::Tensor ATensor
Definition types.h:42
int64_t Size
Definition types.h:71
c10::Scalar CScalar
Definition types.h:43
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
void neml_assert(bool assertion, Args &&... args)
Definition assertions.h:47
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const