32#include <torch/csrc/jit/frontend/tracer.h>
34#include "neml2/tensors/TraceableTensorShape.h"
35#include "neml2/tensors/Tensor.h"
36#include "neml2/tensors/TensorBase.h"
37#include "neml2/tensors/Scalar.h"
38#include "neml2/tensors/jit.h"
39#include "neml2/tensors/shape_utils.h"
40#include "neml2/misc/assertions.h"
44using namespace torch::jit;
49template <
class Derived>
52 _dynamic_sizes(utils::extract_traceable_sizes(tensor, 0, dynamic_dim)),
58template <
class Derived>
63 _dynamic_sizes(std::move(dynamic_shape)),
69template <
class Derived>
76 " is not sufficient for the requested number of dynamic dimensions (",
78 ") and intermediate dimensions (",
81 neml_assert(dynamic_sizes() == sizes().slice(0, dynamic_dim()),
84 " is incompatible with dynamic shape ",
86 ". The leading dimensions must match.");
89template <
class Derived>
93 return Derived(at::empty_like(other), other.dynamic_sizes(), other.intmd_dim());
96template <
class Derived>
100 return Derived(at::zeros_like(other), other.dynamic_sizes(), other.intmd_dim());
103template <
class Derived>
107 return Derived(at::ones_like(other), other.dynamic_sizes(), other.intmd_dim());
110template <
class Derived>
114 return Derived(at::full_like(other, init), other.dynamic_sizes(), other.intmd_dim());
117template <
class Derived>
121 return Derived(at::rand_like(other), other.dynamic_sizes(), other.intmd_dim());
124template <
class Derived>
128 return Derived(ATensor::contiguous(), dynamic_sizes(), intmd_dim());
131template <
class Derived>
135 return Derived(ATensor::clone(), dynamic_sizes(), intmd_dim());
138template <
class Derived>
142 return Derived(ATensor::detach(), dynamic_sizes(), intmd_dim());
145template <
class Derived>
149 return Derived(ATensor::to(options), dynamic_sizes(), intmd_dim());
152template <
class Derived>
159template <
class Derived>
163 return dim() - batch_dim();
166template <
class Derived>
170 return static_cast<Size>(_dynamic_sizes.size());
173template <
class Derived>
177 return dim() - dynamic_dim();
180template <
class Derived>
187template <
class Derived>
194template <
class Derived>
198 return sizes().slice(batch_dim());
201template <
class Derived>
205 return _dynamic_sizes;
208template <
class Derived>
212 return sizes().
slice(dynamic_dim());
215template <
class Derived>
219 return sizes().
slice(dynamic_dim(), intmd_dim());
222template <
class Derived>
227 if (i < dynamic_dim())
228 return _dynamic_sizes[i];
232template <
class Derived>
240template <
class Derived>
245 return _dynamic_sizes[i];
248template <
class Derived>
256template <
class Derived>
264template <
class Derived>
269 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
270 auto res = this->index(indices_vec);
271 return Derived(res, res.dim() - static_dim(), intmd_dim());
274template <
class Derived>
279 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
280 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
281 auto res = this->index(indices_vec);
282 return Derived(res, dynamic_sizes(), res.dim() - dynamic_dim() - base_dim());
285template <
class Derived>
290 indices2.insert(indices2.end(), indices.begin(), indices.end());
291 return neml2::Tensor(this->index(indices2), dynamic_sizes(), intmd_dim());
294template <
class Derived>
299 "batch_index is only supported when there are no intermediate dimensions.");
300 return dynamic_index(indices);
303template <
class Derived>
308 auto res = this->slice(
309 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
310 return Derived(res, res.dim() - static_dim(), intmd_dim());
313template <
class Derived>
318 auto res = this->slice(
319 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
320 return Derived(res, dynamic_sizes(), intmd_dim());
323template <
class Derived>
328 auto res = this->slice(
329 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
333template <
class Derived>
338 if (d < dynamic_dim())
339 return dynamic_slice(d, index);
340 return intmd_slice(d - dynamic_dim(), index);
343template <
class Derived>
348 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
349 this->index_put_(indices_vec, other);
352template <
class Derived>
357 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
358 this->index_put_(indices_vec, v);
361template <
class Derived>
366 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
367 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
368 this->index_put_(indices_vec, other);
371template <
class Derived>
376 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
377 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
378 this->index_put_(indices_vec, v);
381template <
class Derived>
386 indices2.insert(indices2.end(), indices.begin(), indices.end());
387 this->index_put_(indices2, other);
390template <
class Derived>
395 indices2.insert(indices2.end(), indices.begin(), indices.end());
396 this->index_put_(indices2, v);
399template <
class Derived>
404 "batch_index_put_ is only supported when there are no intermediate dimensions.");
405 dynamic_index_put_(indices, other);
408template <
class Derived>
413 "batch_index_put_ is only supported when there are no intermediate dimensions.");
414 dynamic_index_put_(indices, v);
417template <
class Derived>
421 return Derived(ATensor::variable_data(), dynamic_sizes(), intmd_dim());
424template <
class Derived>
430 net.insert(net.end(), static_dim(), -1);
433 if (jit::tracer::isTracing())
434 for (std::size_t i = 0; i < shape.size(); ++i)
435 if (
const auto *
const si = shape[i].traceable())
436 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
438 return Derived(expand(net), shape, intmd_dim());
441template <
class Derived>
445 if (intmd_sizes() == shape)
450 "Invalid intermediate shape to expand. Expected at least ",
453 auto tmp = intmd_unsqueeze(0, shape.size() - intmd_dim());
457 net.insert(net.end(), shape.begin(), shape.end());
458 net.insert(net.end(), base_dim(), -1);
459 return Derived(tmp.expand(net), dynamic_sizes(),
Size(shape.size()));
462template <
class Derived>
466 if (base_sizes() == shape)
471 "Invalid base shape to expand. Expected at least ",
474 auto tmp = base_unsqueeze(0, shape.size() - base_dim());
478 net.insert(net.end(), shape.begin(), shape.end());
479 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), intmd_dim());
482template <
class Derived>
489 "Invalid intermediate shape to expand. Expected at least ",
492 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
496 net.insert(net.end(), base_dim(), -1);
499 if (jit::tracer::isTracing())
500 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
501 if (
const auto *
const si = dynamic_shape[i].traceable())
502 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
504 return Derived(tmp.expand(net), dynamic_shape, tmp.intmd_dim());
507template <
class Derived>
512 if (static_sizes() == net)
517 "Invalid intermediate shape to expand. Expected at least ",
521 "Invalid base shape to expand. Expected at least ",
524 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
525 tmp = tmp.base_unsqueeze(0, base_shape.size() - base_dim());
528 net.insert(net.begin(), dynamic_dim(), -1);
529 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), tmp.intmd_dim());
532template <
class Derived>
536 if (dynamic_size(d) == size)
537 return Derived(*
this);
540 auto shape = dynamic_sizes();
542 return dynamic_expand(shape);
545template <
class Derived>
549 if (intmd_size(d) == size)
555 return Derived(expand(net), dynamic_sizes(), intmd_dim());
558template <
class Derived>
562 if (base_size(d) == size)
568 return neml2::Tensor(expand(net), dynamic_sizes(), intmd_dim());
571template <
class Derived>
578template <
class Derived>
585template <
class Derived>
592template <
class Derived>
599template <
class Derived>
606template <
class Derived>
611 if (jit::tracer::isTracing())
612 for (std::size_t i = 0; i < shape.size(); ++i)
613 if (
const auto *
const si = shape[i].traceable())
614 jit::tracer::ArgumentStash::stashIntArrayRefElem(
615 "shape", shape.size() + static_dim(), i, *si);
620template <
class Derived>
624 auto intmd_dim =
Size(shape.size());
627 if (jit::tracer::isTracing())
628 for (
Size i = 0; i < dynamic_dim(); ++i)
629 if (
const auto *
const si = dynamic_size(i).traceable())
630 jit::tracer::ArgumentStash::stashIntArrayRefElem(
631 "shape", dynamic_dim() + intmd_dim + base_dim(), i, *si);
638template <
class Derived>
643 if (jit::tracer::isTracing())
644 for (
Size i = 0; i < dynamic_dim(); ++i)
645 if (
const auto *
const si = dynamic_size(i).traceable())
646 jit::tracer::ArgumentStash::stashIntArrayRefElem(
647 "shape", batch_dim() + shape.size(), i, *si);
650 reshape(
utils::add_shapes(batch_sizes().concrete(), shape)), dynamic_sizes(), intmd_dim());
653template <
class Derived>
658 auto intmd_dim =
Size(intmd_shape.size());
661 if (jit::tracer::isTracing())
662 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
663 if (
const auto *
const si = dynamic_shape[i].traceable())
664 jit::tracer::ArgumentStash::stashIntArrayRefElem(
665 "shape", dynamic_shape.size() + intmd_dim + base_dim(), i, *si);
672template <
class Derived>
676 auto intmd_dim =
Size(intmd_shape.size());
679 if (jit::tracer::isTracing())
680 for (
Size i = 0; i < dynamic_dim(); ++i)
681 if (
const auto *
const si = dynamic_size(i).traceable())
682 jit::tracer::ArgumentStash::stashIntArrayRefElem(
683 "shape", dynamic_dim() + intmd_dim + base_shape.size(), i, *si);
685 return Derived(reshape(
utils::add_shapes(dynamic_sizes().concrete(), intmd_shape, base_shape)),
690template <
class Derived>
696 "Cannot squeeze dynamic dimension ",
700 ". Only dimensions of size 1 can be squeezed.");
701 auto sizes = dynamic_sizes();
702 sizes.erase(sizes.begin() + d);
703 return Derived(squeeze(d), sizes, intmd_dim());
706template <
class Derived>
711 return Derived(squeeze(d), dynamic_sizes(), intmd_dim() - 1);
714template <
class Derived>
719 return neml2::Tensor(squeeze(d), dynamic_sizes(), intmd_dim());
722template <
class Derived>
727 if (d < dynamic_dim())
728 return dynamic_squeeze(d);
729 return intmd_squeeze(d - dynamic_dim());
732template <
class Derived>
736 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
737 at::Tensor t = *
this;
739 for (
Size i = 0; i < n; ++i)
741 auto B = dynamic_sizes();
742 B.insert(B.begin() + d, n, 1);
743 return Derived(t, B, intmd_dim());
746template <
class Derived>
750 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
751 at::Tensor t = *
this;
753 for (
Size i = 0; i < n; ++i)
755 return Derived(t, dynamic_sizes(), intmd_dim() + n);
758template <
class Derived>
762 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
763 at::Tensor t = *
this;
765 for (
Size i = 0; i < n; ++i)
770template <
class Derived>
775 if (d <= dynamic_dim())
776 return dynamic_unsqueeze(d, n);
777 return intmd_unsqueeze(d - dynamic_dim(), n);
780template <
class Derived>
787 auto sizes = dynamic_sizes();
788 std::swap(sizes[d1], sizes[d2]);
790 return Derived(transpose(d1, d2), sizes, intmd_dim());
793template <
class Derived>
799 return Derived(transpose(d1, d2), dynamic_sizes(), intmd_dim());
802template <
class Derived>
808 return neml2::Tensor(transpose(d1, d2), dynamic_sizes(), intmd_dim());
811template <
class Derived>
816 "batch_transpose is only supported when there are no intermediate dimensions.");
817 return dynamic_transpose(d1, d2);
820template <
class Derived>
827 auto sizes = dynamic_sizes();
828 auto from = sizes.begin() + old_dim;
829 auto to = sizes.begin() + new_dim;
831 std::rotate(from, from + 1, to + 1);
833 std::rotate(to, from, from + 1);
835 return Derived(movedim(old_dim, new_dim), sizes, intmd_dim());
838template <
class Derived>
844 return Derived(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
847template <
class Derived>
853 return neml2::Tensor(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
856template <
class Derived>
861 "batch_movedim is only supported when there are no intermediate dimensions.");
862 return dynamic_movedim(old_dim, new_dim);
865template <
class Derived>
869 if (dynamic_dim() == 1)
873 if (
const auto *
const nt = n.traceable())
874 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"shape", 1 + static_dim(), 0, *nt);
877 return Derived(reshape(sizes), {n}, intmd_dim());
880template <
class Derived>
884 if (intmd_dim() == 1)
889template <
class Derived>
898template <
class Derived>
902 if (intmd_dim() == 0 && dynamic_dim() == 1)
907template <
class Derived>
911 if (intmd_dim() == 0 && base_dim() == 1)
913 return static_reshape({},
utils::numel(static_sizes()));
916template <
class Derived>
920 return Derived(-
ATensor(*
this), dynamic_sizes(), intmd_dim());
Derived batch_squeeze(Size d) const
Definition TensorBaseImpl.h:724
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:909
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:315
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:891
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:858
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:849
Derived batch_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:296
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:760
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:804
Derived batch_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:335
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:867
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:305
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:840
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:601
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:608
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:734
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:918
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:509
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived batch_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:813
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:443
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:594
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:587
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:426
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:464
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:882
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:748
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:822
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:622
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:692
Derived batch_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:772
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:383
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:795
void batch_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:401
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:573
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:900
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:325
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:674
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:708
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:345
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:363
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:782
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:716
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:580
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:640
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:419
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}) const
Definition TensorBaseImpl.h:484
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition indexing.h:38
Definition BufferStore.h:43
TensorShape add_shapes(const S &...)
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize a iterator-like index to be non-negative given the lower- and upper-boun...
Definition shape_utils.cxx:49
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Definition shape_utils.cxx:34
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Definition shape_utils.cxx:64
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition jit.cxx:61
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:69
at::Tensor ATensor
Definition types.h:41
int64_t Size
Definition types.h:68
c10::Scalar CScalar
Definition types.h:42
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:70
void neml_assert(bool assertion, Args &&... args)
Definition assertions.h:47
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const
Definition TraceableTensorShape.cxx:71
TraceableTensorShape slice(std::size_t N, std::size_t M) const
Slice the shape.
Definition TraceableTensorShape.cxx:59