NEML2 2.0.0
TensorBaseImpl.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include <torch/csrc/jit/frontend/tracer.h>

#include "neml2/tensors/TraceableTensorShape.h"
#include "neml2/tensors/Tensor.h"
#include "neml2/tensors/TensorBase.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/jit.h"
#include "neml2/tensors/shape_utils.h"
#include "neml2/misc/assertions.h"

namespace neml2::jit
{
using namespace torch::jit;
}

namespace neml2
{
template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor, Size dynamic_dim, Size intmd_dim)
  : ATensor(tensor),
    _dynamic_sizes(utils::extract_traceable_sizes(tensor, 0, dynamic_dim)),
    _intmd_dim(intmd_dim)
{
  validate_shapes_and_dims();
}

template <class Derived>
TensorBase<Derived>::TensorBase(const ATensor & tensor,
                                TraceableTensorShape dynamic_shape,
                                Size intmd_dim)
  : ATensor(tensor),
    _dynamic_sizes(std::move(dynamic_shape)),
    _intmd_dim(intmd_dim)
{
  validate_shapes_and_dims();
}

template <class Derived>
void
TensorBase<Derived>::validate_shapes_and_dims() const
{
  neml_assert(dim() >= dynamic_dim() + intmd_dim(),
              "Tensor dimension ",
              dim(),
              " is not sufficient for the requested number of dynamic dimensions (",
              dynamic_dim(),
              ") and intermediate dimensions (",
              intmd_dim(),
              ")");
  neml_assert(dynamic_sizes() == sizes().slice(0, dynamic_dim()),
              "Tensor of shape ",
              sizes(),
              " is incompatible with dynamic shape ",
              dynamic_sizes(),
              ". The leading dimensions must match.");
}
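// Note on the dimension layout assumed throughout this file: the leading dynamic
// dimensions (tracked symbolically via _dynamic_sizes) are followed by the intermediate
// ("intmd") dimensions, and the trailing dimensions are the base dimensions. A minimal,
// hedged usage sketch follows; the concrete type, variable name, and sizes are
// illustrative only, assuming a leaf type such as neml2::Tensor exposes the
// (ATensor, dynamic_dim, intmd_dim) constructor defined above:
//
//   auto a = Tensor(at::zeros({5, 3, 3}), /*dynamic_dim=*/1, /*intmd_dim=*/1);
//   // a.dynamic_dim() == 1, a.intmd_dim() == 1, a.base_dim() == 1
//   // a.batch_dim() == 2 (dynamic + intermediate)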
template <class Derived>
Derived
TensorBase<Derived>::empty_like(const Derived & other)
{
  return Derived(at::empty_like(other), other.dynamic_sizes(), other.intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::zeros_like(const Derived & other)
{
  return Derived(at::zeros_like(other), other.dynamic_sizes(), other.intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::ones_like(const Derived & other)
{
  return Derived(at::ones_like(other), other.dynamic_sizes(), other.intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::full_like(const Derived & other, const CScalar & init)
{
  return Derived(at::full_like(other, init), other.dynamic_sizes(), other.intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::rand_like(const Derived & other)
{
  return Derived(at::rand_like(other), other.dynamic_sizes(), other.intmd_dim());
}
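// Hedged usage sketch for the factory helpers above: each *_like helper allocates a new
// tensor with the same dynamic shape, intermediate dimension count, and tensor options
// as its argument. Names are illustrative; "a" stands for an existing tensor of a
// concrete derived type such as neml2::Tensor:
//
//   auto z = Tensor::zeros_like(a);      // all zeros, same layout as a
//   auto f = Tensor::full_like(a, 5.0);  // filled with 5.0, same layout as a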
template <class Derived>
Derived
TensorBase<Derived>::contiguous() const
{
  return Derived(ATensor::contiguous(), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::clone() const
{
  return Derived(ATensor::clone(), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::detach() const
{
  return Derived(ATensor::detach(), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::to(const TensorOptions & options) const
{
  return Derived(ATensor::to(options), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Size
TensorBase<Derived>::batch_dim() const
{
  return dynamic_dim() + intmd_dim();
}

template <class Derived>
Size
TensorBase<Derived>::base_dim() const
{
  return dim() - batch_dim();
}

template <class Derived>
Size
TensorBase<Derived>::dynamic_dim() const
{
  return static_cast<Size>(_dynamic_sizes.size());
}

template <class Derived>
Size
TensorBase<Derived>::static_dim() const
{
  return dim() - dynamic_dim();
}

template <class Derived>
Size
TensorBase<Derived>::intmd_dim() const
{
  return _intmd_dim;
}
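// The accessors above satisfy the following bookkeeping identities, which follow
// directly from the definitions of batch_dim(), base_dim(), and static_dim():
//
//   batch_dim()  == dynamic_dim() + intmd_dim()
//   static_dim() == intmd_dim() + base_dim()
//   dim()        == dynamic_dim() + intmd_dim() + base_dim()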
template <class Derived>
TraceableTensorShape
TensorBase<Derived>::batch_sizes() const
{
  return utils::add_traceable_shapes(dynamic_sizes(), intmd_sizes());
}

template <class Derived>
TensorShapeRef
TensorBase<Derived>::base_sizes() const
{
  return sizes().slice(batch_dim());
}

template <class Derived>
const TraceableTensorShape &
TensorBase<Derived>::dynamic_sizes() const
{
  return _dynamic_sizes;
}

template <class Derived>
TensorShapeRef
TensorBase<Derived>::static_sizes() const
{
  return sizes().slice(dynamic_dim());
}

template <class Derived>
TensorShapeRef
TensorBase<Derived>::intmd_sizes() const
{
  return sizes().slice(dynamic_dim(), intmd_dim());
}

template <class Derived>
TraceableSize
TensorBase<Derived>::batch_size(Size i) const
{
  i = utils::normalize_dim(i, 0, batch_dim());
  if (i < dynamic_dim())
    return _dynamic_sizes[i];
  return size(i);
}

template <class Derived>
Size
TensorBase<Derived>::base_size(Size i) const
{
  i = utils::normalize_dim(i, batch_dim(), dim());
  return size(i);
}

template <class Derived>
const TraceableSize &
TensorBase<Derived>::dynamic_size(Size i) const
{
  i = utils::normalize_dim(i, 0, dynamic_dim());
  return _dynamic_sizes[i];
}

template <class Derived>
Size
TensorBase<Derived>::static_size(Size i) const
{
  i = utils::normalize_dim(i, dynamic_dim(), dim());
  return size(i);
}

template <class Derived>
Size
TensorBase<Derived>::intmd_size(Size i) const
{
  i = utils::normalize_dim(i, dynamic_dim(), batch_dim());
  return size(i);
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
  auto res = this->index(indices_vec);
  return Derived(res, res.dim() - static_dim(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
  indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  auto res = this->index(indices_vec);
  return Derived(res, dynamic_sizes(), res.dim() - dynamic_dim() - base_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_index(indexing::TensorIndicesRef indices) const
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  return neml2::Tensor(this->index(indices2), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_index(indexing::TensorIndicesRef indices) const
{
  neml_assert_dbg(_intmd_dim == 0,
                  "batch_index is only supported when there are no intermediate dimensions.");
  return dynamic_index(indices);
}
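// Hedged sketch of the indexing helpers above: dynamic_index, intmd_index, and
// base_index apply the supplied indices to their own group of dimensions and leave the
// remaining dimensions fully sliced. Index values and names below are illustrative:
//
//   auto first_two = a.dynamic_index({indexing::Slice(0, 2)});  // first two dynamic entries
//   auto component = a.base_index({0, 1});                      // one base component, as a neml2::Tensor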
template <class Derived>
Derived
TensorBase<Derived>::dynamic_slice(Size d, const indexing::Slice & index) const
{
  d = utils::normalize_dim(d, 0, dynamic_dim());
  auto res = this->slice(
      d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, res.dim() - static_dim(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_slice(Size d, const indexing::Slice & index) const
{
  d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
  auto res = this->slice(
      d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return Derived(res, dynamic_sizes(), intmd_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_slice(Size d, const indexing::Slice & index) const
{
  d = utils::normalize_dim(d, batch_dim(), dim());
  auto res = this->slice(
      d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
  return neml2::Tensor(res, dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_slice(Size d, const indexing::Slice & index) const
{
  d = utils::normalize_dim(d, 0, batch_dim());
  if (d < dynamic_dim())
    return dynamic_slice(d, index);
  return intmd_slice(d - dynamic_dim(), index);
}
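// Hedged sketch of the slice helpers above: each one applies a single indexing::Slice
// along one dimension of its group. The dimension index and bounds below are
// illustrative:
//
//   auto s = a.dynamic_slice(0, indexing::Slice(0, 10));  // first 10 entries of dynamic dimension 0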
template <class Derived>
void
TensorBase<Derived>::dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
  this->index_put_(indices_vec, other);
}

template <class Derived>
void
TensorBase<Derived>::dynamic_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  indexing::TensorIndices indices_vec(indices);
  indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
  this->index_put_(indices_vec, v);
}

template <class Derived>
void
TensorBase<Derived>::intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
  indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, other);
}

template <class Derived>
void
TensorBase<Derived>::intmd_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
  indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
  indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
  this->index_put_(indices_vec, v);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, other);
}

template <class Derived>
void
TensorBase<Derived>::base_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
  indices2.insert(indices2.end(), indices.begin(), indices.end());
  this->index_put_(indices2, v);
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, const ATensor & other)
{
  neml_assert_dbg(_intmd_dim == 0,
                  "batch_index_put_ is only supported when there are no intermediate dimensions.");
  dynamic_index_put_(indices, other);
}

template <class Derived>
void
TensorBase<Derived>::batch_index_put_(indexing::TensorIndicesRef indices, const CScalar & v)
{
  neml_assert_dbg(_intmd_dim == 0,
                  "batch_index_put_ is only supported when there are no intermediate dimensions.");
  dynamic_index_put_(indices, v);
}
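// Hedged sketch of the in-place setters above: each *_index_put_ overload writes into
// the selected sub-view, either from another tensor or from a scalar fill value.
// Names and values below are illustrative:
//
//   a.base_index_put_({0}, 1.0);                        // set one base component everywhere
//   a.dynamic_index_put_({indexing::Slice(0, 2)}, b);   // copy b into the first two dynamic entries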
template <class Derived>
Derived
TensorBase<Derived>::variable_data() const
{
  return Derived(ATensor::variable_data(), dynamic_sizes(), intmd_dim());
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_expand(const TraceableTensorShape & shape) const
{
  // We don't want to touch the other dimensions, so put -1 for them.
  auto net = shape.concrete();
  net.insert(net.end(), static_dim(), -1);

  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (std::size_t i = 0; i < shape.size(); ++i)
      if (const auto * const si = shape[i].traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);

  return Derived(expand(net), shape, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_expand(TensorShapeRef shape) const
{
  if (intmd_sizes() == shape)
    return *this;

  // Unsqueeze missing dimensions
  neml_assert_dbg(Size(shape.size()) >= intmd_dim(),
                  "Invalid intermediate shape to expand. Expected at least ",
                  intmd_dim(),
                  " dimensions.");
  auto tmp = intmd_unsqueeze(0, shape.size() - intmd_dim());

  // We don't want to touch the other dimensions, so put -1 for them.
  TensorShape net(dynamic_dim(), -1);
  net.insert(net.end(), shape.begin(), shape.end());
  net.insert(net.end(), base_dim(), -1);
  return Derived(tmp.expand(net), dynamic_sizes(), Size(shape.size()));
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(TensorShapeRef shape) const
{
  if (base_sizes() == shape)
    return *this;

  // Unsqueeze missing dimensions
  neml_assert_dbg(Size(shape.size()) >= base_dim(),
                  "Invalid base shape to expand. Expected at least ",
                  base_dim(),
                  " dimensions.");
  auto tmp = base_unsqueeze(0, shape.size() - base_dim());

  // We don't want to touch the batch dimensions, so put -1 for them.
  TensorShape net(batch_dim(), -1);
  net.insert(net.end(), shape.begin(), shape.end());
  return neml2::Tensor(tmp.expand(net), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand(const TraceableTensorShape & dynamic_shape,
                                  TensorShapeRef intmd_shape) const
{
  // Unsqueeze missing dimensions
  neml_assert_dbg(Size(intmd_shape.size()) >= intmd_dim(),
                  "Invalid intermediate shape to expand. Expected at least ",
                  intmd_dim(),
                  " dimensions.");
  auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());

  // We don't want to touch the other dimensions, so put -1 for them.
  auto net = utils::add_shapes(dynamic_shape.concrete(), intmd_shape);
  net.insert(net.end(), base_dim(), -1);

  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
      if (const auto * const si = dynamic_shape[i].traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);

  return Derived(tmp.expand(net), dynamic_shape, tmp.intmd_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
{
  auto net = utils::add_shapes(intmd_shape, base_shape);
  if (static_sizes() == net)
    return *this;

  // Unsqueeze missing dimensions
  neml_assert_dbg(Size(intmd_shape.size()) >= intmd_dim(),
                  "Invalid intermediate shape to expand. Expected at least ",
                  intmd_dim(),
                  " dimensions.");
  neml_assert_dbg(Size(base_shape.size()) >= base_dim(),
                  "Invalid base shape to expand. Expected at least ",
                  base_dim(),
                  " dimensions.");
  auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
  tmp = tmp.base_unsqueeze(0, base_shape.size() - base_dim());

  // We don't want to touch the other dimensions, so put -1 for them.
  net.insert(net.begin(), dynamic_dim(), -1);
  return neml2::Tensor(tmp.expand(net), dynamic_sizes(), tmp.intmd_dim());
}
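// Hedged sketch of the expand helpers above: expansion broadcasts size-1 dimensions of
// the chosen group to the requested sizes (unsqueezing missing leading dimensions
// first) without copying data. Shapes below are illustrative, assuming the braced lists
// convert to TraceableTensorShape/TensorShapeRef as elsewhere in this file:
//
//   auto b = a.dynamic_expand({100});  // broadcast the dynamic dimensions to (100,)
//   auto c = a.base_expand({3, 3});    // broadcast/unsqueeze the base dimensions to (3, 3)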
template <class Derived>
Derived
TensorBase<Derived>::dynamic_expand(Size d, const TraceableSize & size) const
{
  if (dynamic_size(d) == size)
    return Derived(*this);

  d = utils::normalize_dim(d, 0, dynamic_dim());
  auto shape = dynamic_sizes();
  shape[d] = size;
  return dynamic_expand(shape);
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_expand(Size d, Size size) const
{
  if (intmd_size(d) == size)
    return *this;

  d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
  TensorShape net(dim(), -1);
  net[d] = size;
  return Derived(expand(net), dynamic_sizes(), intmd_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand(Size d, Size size) const
{
  if (base_size(d) == size)
    return *this;

  d = utils::normalize_dim(d, batch_dim(), dim());
  TensorShape net(dim(), -1);
  net[d] = size;
  return neml2::Tensor(expand(net), dynamic_sizes(), intmd_dim());
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_expand_as(const neml2::Tensor & other) const
{
  return dynamic_expand(other.dynamic_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_expand_as(const neml2::Tensor & other) const
{
  return intmd_expand(other.intmd_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_expand_as(const neml2::Tensor & other) const
{
  return base_expand(other.base_sizes());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_expand_as(const neml2::Tensor & other) const
{
  return batch_expand(other.dynamic_sizes(), other.intmd_sizes());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::static_expand_as(const neml2::Tensor & other) const
{
  return static_expand(other.intmd_sizes(), other.base_sizes());
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_reshape(const TraceableTensorShape & shape) const
{
  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (std::size_t i = 0; i < shape.size(); ++i)
      if (const auto * const si = shape[i].traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem(
            "shape", shape.size() + static_dim(), i, *si);

  return Derived(reshape(utils::add_shapes(shape.concrete(), static_sizes())), shape, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_reshape(TensorShapeRef shape) const
{
  auto intmd_dim = Size(shape.size());

  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (Size i = 0; i < dynamic_dim(); ++i)
      if (const auto * const si = dynamic_size(i).traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem(
            "shape", dynamic_dim() + intmd_dim + base_dim(), i, *si);

  return neml2::Tensor(reshape(utils::add_shapes(dynamic_sizes().concrete(), shape, base_sizes())),
                       dynamic_sizes(),
                       intmd_dim);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_reshape(TensorShapeRef shape) const
{
  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (Size i = 0; i < dynamic_dim(); ++i)
      if (const auto * const si = dynamic_size(i).traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem(
            "shape", batch_dim() + shape.size(), i, *si);

  return neml2::Tensor(
      reshape(utils::add_shapes(batch_sizes().concrete(), shape)), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_reshape(const TraceableTensorShape & dynamic_shape,
                                   TensorShapeRef intmd_shape) const
{
  auto intmd_dim = Size(intmd_shape.size());

  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
      if (const auto * const si = dynamic_shape[i].traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem(
            "shape", dynamic_shape.size() + intmd_dim + base_dim(), i, *si);

  return Derived(reshape(utils::add_shapes(dynamic_shape.concrete(), intmd_shape, base_sizes())),
                 dynamic_shape,
                 intmd_dim);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
{
  auto intmd_dim = Size(intmd_shape.size());

  // Record the dynamic sizes in the traced graph if we are tracing
  if (jit::tracer::isTracing())
    for (Size i = 0; i < dynamic_dim(); ++i)
      if (const auto * const si = dynamic_size(i).traceable())
        jit::tracer::ArgumentStash::stashIntArrayRefElem(
            "shape", dynamic_dim() + intmd_dim + base_shape.size(), i, *si);

  return Derived(reshape(utils::add_shapes(dynamic_sizes().concrete(), intmd_shape, base_shape)),
                 dynamic_sizes(),
                 intmd_dim);
}
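// Hedged sketch of the reshape helpers above: each variant reshapes only its own group
// of dimensions, keeps the other groups intact, and stashes traceable dynamic sizes
// when JIT tracing is active. Shapes below are illustrative:
//
//   auto d = a.base_reshape({9});       // e.g. collapse a (3, 3) base shape into (9,)
//   auto e = a.batch_reshape({50, 2});  // reshape the dynamic dimensions (no intermediate dims requested)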
template <class Derived>
Derived
TensorBase<Derived>::dynamic_squeeze(Size d) const
{
  d = utils::normalize_dim(d, 0, dynamic_dim());
  neml_assert(dynamic_size(d) == 1,
              "Cannot squeeze dynamic dimension ",
              d,
              " with size ",
              dynamic_size(d),
              ". Only dimensions of size 1 can be squeezed.");
  auto sizes = dynamic_sizes();
  sizes.erase(sizes.begin() + d); // Remove the squeezed dimension
  return Derived(squeeze(d), sizes, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_squeeze(Size d) const
{
  d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
  return Derived(squeeze(d), dynamic_sizes(), intmd_dim() - 1);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_squeeze(Size d) const
{
  d = utils::normalize_dim(d, batch_dim(), dim());
  return neml2::Tensor(squeeze(d), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_squeeze(Size d) const
{
  d = utils::normalize_dim(d, 0, batch_dim());
  if (d < dynamic_dim())
    return dynamic_squeeze(d);
  return intmd_squeeze(d - dynamic_dim());
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_unsqueeze(Size d, Size n) const
{
  neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
  at::Tensor t = *this;
  d = utils::normalize_itr(d, 0, dynamic_dim());
  for (Size i = 0; i < n; ++i)
    t = t.unsqueeze(d);
  auto B = dynamic_sizes();
  B.insert(B.begin() + d, n, 1);
  return Derived(t, B, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_unsqueeze(Size d, Size n) const
{
  neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
  at::Tensor t = *this;
  d = utils::normalize_itr(d, dynamic_dim(), batch_dim());
  for (Size i = 0; i < n; ++i)
    t = t.unsqueeze(d);
  return Derived(t, dynamic_sizes(), intmd_dim() + n);
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_unsqueeze(Size d, Size n) const
{
  neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
  at::Tensor t = *this;
  d = utils::normalize_itr(d, batch_dim(), dim());
  for (Size i = 0; i < n; ++i)
    t = t.unsqueeze(d);
  return neml2::Tensor(t, dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_unsqueeze(Size d, Size n) const
{
  d = utils::normalize_itr(d, 0, batch_dim());
  if (d <= dynamic_dim())
    return dynamic_unsqueeze(d, n);
  return intmd_unsqueeze(d - dynamic_dim(), n);
}
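// Hedged sketch of the squeeze/unsqueeze helpers above: squeeze removes a size-1
// dimension from the chosen group, while unsqueeze(d, n) inserts n size-1 dimensions at
// position d within that group. The round trip below is illustrative:
//
//   auto u = a.dynamic_unsqueeze(0);  // prepend one size-1 dynamic dimension
//   auto v = u.dynamic_squeeze(0);    // and remove it again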
template <class Derived>
Derived
TensorBase<Derived>::dynamic_transpose(Size d1, Size d2) const
{
  d1 = utils::normalize_dim(d1, 0, dynamic_dim());
  d2 = utils::normalize_dim(d2, 0, dynamic_dim());

  auto sizes = dynamic_sizes();
  std::swap(sizes[d1], sizes[d2]);

  return Derived(transpose(d1, d2), sizes, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_transpose(Size d1, Size d2) const
{
  d1 = utils::normalize_dim(d1, dynamic_dim(), batch_dim());
  d2 = utils::normalize_dim(d2, dynamic_dim(), batch_dim());
  return Derived(transpose(d1, d2), dynamic_sizes(), intmd_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_transpose(Size d1, Size d2) const
{
  d1 = utils::normalize_dim(d1, batch_dim(), dim());
  d2 = utils::normalize_dim(d2, batch_dim(), dim());
  return neml2::Tensor(transpose(d1, d2), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_transpose(Size d1, Size d2) const
{
  neml_assert_dbg(_intmd_dim == 0,
                  "batch_transpose is only supported when there are no intermediate dimensions.");
  return dynamic_transpose(d1, d2);
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_movedim(Size old_dim, Size new_dim) const
{
  old_dim = utils::normalize_dim(old_dim, 0, dynamic_dim());
  new_dim = utils::normalize_dim(new_dim, 0, dynamic_dim());

  auto sizes = dynamic_sizes();
  auto from = sizes.begin() + old_dim;
  auto to = sizes.begin() + new_dim;
  if (from < to)
    std::rotate(from, from + 1, to + 1);
  else
    std::rotate(to, from, from + 1);

  return Derived(movedim(old_dim, new_dim), sizes, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_movedim(Size old_dim, Size new_dim) const
{
  old_dim = utils::normalize_dim(old_dim, dynamic_dim(), batch_dim());
  new_dim = utils::normalize_dim(new_dim, dynamic_dim(), batch_dim());
  return Derived(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_movedim(Size old_dim, Size new_dim) const
{
  old_dim = utils::normalize_dim(old_dim, batch_dim(), dim());
  new_dim = utils::normalize_dim(new_dim, batch_dim(), dim());
  return neml2::Tensor(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::batch_movedim(Size old_dim, Size new_dim) const
{
  neml_assert_dbg(_intmd_dim == 0,
                  "batch_movedim is only supported when there are no intermediate dimensions.");
  return dynamic_movedim(old_dim, new_dim);
}
template <class Derived>
Derived
TensorBase<Derived>::dynamic_flatten() const
{
  if (dynamic_dim() == 1)
    return *this;

  auto n = utils::traceable_numel(dynamic_sizes());
  if (const auto * const nt = n.traceable())
    jit::tracer::ArgumentStash::stashIntArrayRefElem("shape", 1 + static_dim(), 0, *nt);

  auto sizes = utils::add_shapes(n.concrete(), intmd_sizes(), base_sizes());
  return Derived(reshape(sizes), {n}, intmd_dim());
}

template <class Derived>
Derived
TensorBase<Derived>::intmd_flatten() const
{
  if (intmd_dim() == 1)
    return *this;
  return intmd_reshape(utils::numel(intmd_sizes()));
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::base_flatten() const
{
  if (base_dim() == 1)
    return *this;
  return base_reshape(utils::numel(base_sizes()));
}

template <class Derived>
Derived
TensorBase<Derived>::batch_flatten() const
{
  if (intmd_dim() == 0 && dynamic_dim() == 1)
    return *this;
  return batch_reshape({utils::traceable_numel(batch_sizes())}, {});
}

template <class Derived>
neml2::Tensor
TensorBase<Derived>::static_flatten() const
{
  if (intmd_dim() == 0 && base_dim() == 1)
    return *this;
  return static_reshape({}, utils::numel(static_sizes()));
}
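// Hedged sketch of the flatten helpers above: each one collapses its group of
// dimensions into a single dimension via the corresponding reshape. Shapes below are
// illustrative:
//
//   auto g = a.base_flatten();   // a (3, 3) base shape becomes (9,)
//   auto h = a.batch_flatten();  // all batch dimensions collapse into one dynamic dimension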
template <class Derived>
Derived
TensorBase<Derived>::operator-() const
{
  return Derived(-ATensor(*this), dynamic_sizes(), intmd_dim());
}

} // end namespace neml2