32#include <torch/csrc/jit/frontend/tracer.h>
34#include "neml2/tensors/TraceableTensorShape.h"
35#include "neml2/tensors/Tensor.h"
36#include "neml2/tensors/TensorBase.h"
37#include "neml2/tensors/Scalar.h"
38#include "neml2/tensors/jit.h"
39#include "neml2/tensors/shape_utils.h"
40#include "neml2/misc/assertions.h"
44using namespace torch::jit;
49template <
class Derived>
52 _dynamic_sizes(utils::extract_traceable_sizes(tensor, 0, dynamic_dim)),
58template <
class Derived>
63 _dynamic_sizes(std::move(dynamic_shape)),
69template <
class Derived>
76 " is not sufficient for the requested number of dynamic dimensions (",
78 ") and intermediate dimensions (",
81 neml_assert(dynamic_sizes() == sizes().slice(0, dynamic_dim()),
84 " is incompatible with dynamic shape ",
86 ". The leading dimensions must match.");
89template <
class Derived>
93 return Derived(at::empty_like(other), other.dynamic_sizes(), other.intmd_dim());
96template <
class Derived>
100 return Derived(at::zeros_like(other), other.dynamic_sizes(), other.intmd_dim());
103template <
class Derived>
107 return Derived(at::ones_like(other), other.dynamic_sizes(), other.intmd_dim());
110template <
class Derived>
114 return Derived(at::full_like(other, init), other.dynamic_sizes(), other.intmd_dim());
117template <
class Derived>
121 return Derived(at::rand_like(other), other.dynamic_sizes(), other.intmd_dim());
124template <
class Derived>
128 return Derived(ATensor::contiguous(), dynamic_sizes(), intmd_dim());
131template <
class Derived>
135 return Derived(ATensor::clone(), dynamic_sizes(), intmd_dim());
138template <
class Derived>
142 return Derived(ATensor::detach(), dynamic_sizes(), intmd_dim());
145template <
class Derived>
149 return Derived(ATensor::to(options), dynamic_sizes(), intmd_dim());
152template <
class Derived>
159template <
class Derived>
163 return dim() - batch_dim();
166template <
class Derived>
170 return static_cast<Size>(_dynamic_sizes.size());
173template <
class Derived>
177 return dim() - dynamic_dim();
180template <
class Derived>
187template <
class Derived>
194template <
class Derived>
198 return sizes().slice(batch_dim());
201template <
class Derived>
205 return _dynamic_sizes;
208template <
class Derived>
212 return sizes().
slice(dynamic_dim());
215template <
class Derived>
219 return sizes().
slice(dynamic_dim(), intmd_dim());
222template <
class Derived>
227 if (i < dynamic_dim())
228 return _dynamic_sizes[i];
232template <
class Derived>
240template <
class Derived>
245 return _dynamic_sizes[i];
248template <
class Derived>
256template <
class Derived>
264template <
class Derived>
269 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
270 auto res = this->index(indices_vec);
271 return Derived(res, res.dim() - static_dim(), intmd_dim());
274template <
class Derived>
279 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
280 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
281 auto res = this->index(indices_vec);
282 return Derived(res, dynamic_sizes(), res.dim() - dynamic_dim() - base_dim());
285template <
class Derived>
290 indices2.insert(indices2.end(), indices.begin(), indices.end());
291 return neml2::Tensor(this->index(indices2), dynamic_sizes(), intmd_dim());
294template <
class Derived>
299 auto res = this->slice(
300 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
301 return Derived(res, res.dim() - static_dim(), intmd_dim());
304template <
class Derived>
309 auto res = this->slice(
310 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
311 return Derived(res, dynamic_sizes(), intmd_dim());
314template <
class Derived>
319 auto res = this->slice(
320 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
324template <
class Derived>
329 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
330 this->index_put_(indices_vec, other);
333template <
class Derived>
338 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
339 this->index_put_(indices_vec, v);
342template <
class Derived>
347 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
348 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
349 this->index_put_(indices_vec, other);
352template <
class Derived>
357 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
358 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
359 this->index_put_(indices_vec, v);
362template <
class Derived>
367 indices2.insert(indices2.end(), indices.begin(), indices.end());
368 this->index_put_(indices2, other);
371template <
class Derived>
376 indices2.insert(indices2.end(), indices.begin(), indices.end());
377 this->index_put_(indices2, v);
380template <
class Derived>
384 return Derived(ATensor::variable_data(), dynamic_sizes(), intmd_dim());
387template <
class Derived>
393 net.insert(net.end(), static_dim(), -1);
396 if (jit::tracer::isTracing())
397 for (std::size_t i = 0; i < shape.size(); ++i)
398 if (
const auto *
const si = shape[i].traceable())
399 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
401 return Derived(expand(net), shape, intmd_dim());
404template <
class Derived>
408 if (intmd_sizes() == shape)
413 "Invalid intermediate shape to expand. Expected at least ",
416 auto tmp = intmd_unsqueeze(0, shape.size() - intmd_dim());
420 net.insert(net.end(), shape.begin(), shape.end());
421 net.insert(net.end(), base_dim(), -1);
422 return Derived(tmp.expand(net), dynamic_sizes(),
Size(shape.size()));
425template <
class Derived>
429 if (base_sizes() == shape)
434 "Invalid base shape to expand. Expected at least ",
437 auto tmp = base_unsqueeze(0, shape.size() - base_dim());
441 net.insert(net.end(), shape.begin(), shape.end());
442 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), intmd_dim());
445template <
class Derived>
452 "Invalid intermediate shape to expand. Expected at least ",
455 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
459 net.insert(net.end(), base_dim(), -1);
462 if (jit::tracer::isTracing())
463 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
464 if (
const auto *
const si = dynamic_shape[i].traceable())
465 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"size", net.size(), i, *si);
467 return Derived(tmp.expand(net), dynamic_shape, tmp.intmd_dim());
470template <
class Derived>
475 if (static_sizes() == net)
480 "Invalid intermediate shape to expand. Expected at least ",
484 "Invalid base shape to expand. Expected at least ",
487 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
488 tmp = tmp.base_unsqueeze(0, base_shape.size() - base_dim());
491 net.insert(net.begin(), dynamic_dim(), -1);
492 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), tmp.intmd_dim());
495template <
class Derived>
499 if (dynamic_size(d) == size)
500 return Derived(*
this);
503 auto shape = dynamic_sizes();
505 return dynamic_expand(shape);
508template <
class Derived>
512 if (intmd_size(d) == size)
518 return Derived(expand(net), dynamic_sizes(), intmd_dim());
521template <
class Derived>
525 if (base_size(d) == size)
531 return neml2::Tensor(expand(net), dynamic_sizes(), intmd_dim());
534template <
class Derived>
541template <
class Derived>
548template <
class Derived>
555template <
class Derived>
562template <
class Derived>
569template <
class Derived>
574 if (jit::tracer::isTracing())
575 for (std::size_t i = 0; i < shape.size(); ++i)
576 if (
const auto *
const si = shape[i].traceable())
577 jit::tracer::ArgumentStash::stashIntArrayRefElem(
578 "shape", shape.size() + static_dim(), i, *si);
583template <
class Derived>
587 auto intmd_dim =
Size(shape.size());
590 if (jit::tracer::isTracing())
591 for (
Size i = 0; i < dynamic_dim(); ++i)
592 if (
const auto *
const si = dynamic_size(i).traceable())
593 jit::tracer::ArgumentStash::stashIntArrayRefElem(
594 "shape", dynamic_dim() + intmd_dim + base_dim(), i, *si);
601template <
class Derived>
606 if (jit::tracer::isTracing())
607 for (
Size i = 0; i < dynamic_dim(); ++i)
608 if (
const auto *
const si = dynamic_size(i).traceable())
609 jit::tracer::ArgumentStash::stashIntArrayRefElem(
610 "shape", batch_dim() + shape.size(), i, *si);
613 reshape(
utils::add_shapes(batch_sizes().concrete(), shape)), dynamic_sizes(), intmd_dim());
616template <
class Derived>
621 auto intmd_dim =
Size(intmd_shape.size());
624 if (jit::tracer::isTracing())
625 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
626 if (
const auto *
const si = dynamic_shape[i].traceable())
627 jit::tracer::ArgumentStash::stashIntArrayRefElem(
628 "shape", dynamic_shape.size() + intmd_dim + base_dim(), i, *si);
635template <
class Derived>
639 auto intmd_dim =
Size(intmd_shape.size());
642 if (jit::tracer::isTracing())
643 for (
Size i = 0; i < dynamic_dim(); ++i)
644 if (
const auto *
const si = dynamic_size(i).traceable())
645 jit::tracer::ArgumentStash::stashIntArrayRefElem(
646 "shape", dynamic_dim() + intmd_dim + base_shape.size(), i, *si);
648 return Derived(reshape(
utils::add_shapes(dynamic_sizes().concrete(), intmd_shape, base_shape)),
653template <
class Derived>
659 "Cannot squeeze dynamic dimension ",
663 ". Only dimensions of size 1 can be squeezed.");
664 auto sizes = dynamic_sizes();
665 sizes.erase(sizes.begin() + d);
666 return Derived(squeeze(d), sizes, intmd_dim());
669template <
class Derived>
674 return Derived(squeeze(d), dynamic_sizes(), intmd_dim() - 1);
677template <
class Derived>
682 return neml2::Tensor(squeeze(d), dynamic_sizes(), intmd_dim());
685template <
class Derived>
689 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
690 at::Tensor t = *
this;
692 for (
Size i = 0; i < n; ++i)
694 auto B = dynamic_sizes();
695 B.insert(B.begin() + d, n, 1);
696 return Derived(t, B, intmd_dim());
699template <
class Derived>
703 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
704 at::Tensor t = *
this;
706 for (
Size i = 0; i < n; ++i)
708 return Derived(t, dynamic_sizes(), intmd_dim() + n);
711template <
class Derived>
715 neml_assert(n >= 0,
"Number of dimensions to unsqueeze must be non-negative.");
716 at::Tensor t = *
this;
718 for (
Size i = 0; i < n; ++i)
723template <
class Derived>
730 auto sizes = dynamic_sizes();
731 std::swap(sizes[d1], sizes[d2]);
733 return Derived(transpose(d1, d2), sizes, intmd_dim());
736template <
class Derived>
742 return Derived(transpose(d1, d2), dynamic_sizes(), intmd_dim());
745template <
class Derived>
751 return neml2::Tensor(transpose(d1, d2), dynamic_sizes(), intmd_dim());
754template <
class Derived>
761 auto sizes = dynamic_sizes();
762 auto from = sizes.begin() + old_dim;
763 auto to = sizes.begin() + new_dim;
765 std::rotate(from, from + 1, to + 1);
767 std::rotate(to, from, from + 1);
769 return Derived(movedim(old_dim, new_dim), sizes, intmd_dim());
772template <
class Derived>
778 return Derived(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
781template <
class Derived>
787 return neml2::Tensor(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
790template <
class Derived>
794 if (dynamic_dim() == 1)
798 if (
const auto *
const nt = n.traceable())
799 jit::tracer::ArgumentStash::stashIntArrayRefElem(
"shape", 1 + static_dim(), 0, *nt);
802 return Derived(reshape(sizes), {n}, intmd_dim());
805template <
class Derived>
809 if (intmd_dim() == 1)
814template <
class Derived>
823template <
class Derived>
827 if (intmd_dim() == 0 && dynamic_dim() == 1)
832template <
class Derived>
836 if (intmd_dim() == 0 && base_dim() == 1)
838 return static_reshape({},
utils::numel(static_sizes()));
841template <
class Derived>
845 return Derived(-
ATensor(*
this), dynamic_sizes(), intmd_dim());
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:834
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:306
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:816
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:447
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:783
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:713
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:747
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:792
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:296
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:774
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:564
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:571
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:687
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:843
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:472
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:406
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:557
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:550
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:389
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:427
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:807
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:701
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:756
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:585
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:364
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:738
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:536
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:825
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:316
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:637
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:671
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:326
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:618
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:344
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:725
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:679
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:543
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:382
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition indexing.h:38
Definition BufferStore.h:43
TensorShape add_shapes(const S &...)
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize a iterator-like index to be non-negative given the lower- and upper-boun...
Definition shape_utils.cxx:49
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Definition shape_utils.cxx:34
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Definition shape_utils.cxx:64
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition jit.cxx:61
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
at::Tensor ATensor
Definition types.h:38
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
void neml_assert(bool assertion, Args &&... args)
Definition assertions.h:47
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const
Definition TraceableTensorShape.cxx:71
TraceableTensorShape slice(std::size_t N, std::size_t M) const
Slice the shape.
Definition TraceableTensorShape.cxx:59