NEML2 2.0.0
Loading...
Searching...
No Matches
TensorBaseImpl.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
30#pragma once
31
32#include <torch/csrc/jit/frontend/tracer.h>
33
34#include "neml2/tensors/TraceableTensorShape.h"
35#include "neml2/tensors/Tensor.h"
36#include "neml2/tensors/TensorBase.h"
37#include "neml2/tensors/Scalar.h"
38#include "neml2/tensors/jit.h"
39#include "neml2/tensors/shape_utils.h"
40#include "neml2/misc/assertions.h"
41
42namespace neml2::jit
43{
44using namespace torch::jit;
45}
46
47namespace neml2
48{
49template <class Derived>
50TensorBase<Derived>::TensorBase(const ATensor & tensor, Size dynamic_dim, Size intmd_dim)
51 : ATensor(tensor),
52 _dynamic_sizes(utils::extract_traceable_sizes(tensor, 0, dynamic_dim)),
53 _intmd_dim(intmd_dim)
54{
56}
57
58template <class Derived>
60 TraceableTensorShape dynamic_shape,
61 Size intmd_dim)
62 : ATensor(tensor),
63 _dynamic_sizes(std::move(dynamic_shape)),
64 _intmd_dim(intmd_dim)
65{
67}
68
69template <class Derived>
70void
72{
73 neml_assert(dim() >= dynamic_dim() + intmd_dim(),
74 "Tensor dimension ",
75 dim(),
76 " is not sufficient for the requested number of dynamic dimensions (",
77 dynamic_dim(),
78 ") and intermediate dimensions (",
79 intmd_dim(),
80 ")");
81 neml_assert(dynamic_sizes() == sizes().slice(0, dynamic_dim()),
82 "Tensor of shape ",
83 sizes(),
84 " is incompatible with dynamic shape ",
85 dynamic_sizes(),
86 ". The leading dimensions must match.");
87}
88
89template <class Derived>
90Derived
91TensorBase<Derived>::empty_like(const Derived & other)
92{
93 return Derived(at::empty_like(other), other.dynamic_sizes(), other.intmd_dim());
94}
95
96template <class Derived>
97Derived
98TensorBase<Derived>::zeros_like(const Derived & other)
99{
100 return Derived(at::zeros_like(other), other.dynamic_sizes(), other.intmd_dim());
101}
103template <class Derived>
104Derived
105TensorBase<Derived>::ones_like(const Derived & other)
107 return Derived(at::ones_like(other), other.dynamic_sizes(), other.intmd_dim());
108}
110template <class Derived>
111Derived
112TensorBase<Derived>::full_like(const Derived & other, const CScalar & init)
113{
114 return Derived(at::full_like(other, init), other.dynamic_sizes(), other.intmd_dim());
115}
116
117template <class Derived>
118Derived
119TensorBase<Derived>::rand_like(const Derived & other)
121 return Derived(at::rand_like(other), other.dynamic_sizes(), other.intmd_dim());
123
124template <class Derived>
125Derived
127{
128 return Derived(ATensor::contiguous(), dynamic_sizes(), intmd_dim());
129}
130
131template <class Derived>
132Derived
134{
135 return Derived(ATensor::clone(), dynamic_sizes(), intmd_dim());
137
138template <class Derived>
139Derived
141{
142 return Derived(ATensor::detach(), dynamic_sizes(), intmd_dim());
143}
144
145template <class Derived>
146Derived
148{
149 return Derived(ATensor::to(options), dynamic_sizes(), intmd_dim());
150}
151
152template <class Derived>
153Size
156 return dynamic_dim() + intmd_dim();
159template <class Derived>
160Size
162{
163 return dim() - batch_dim();
166template <class Derived>
169{
170 return static_cast<Size>(_dynamic_sizes.size());
171}
172
173template <class Derived>
177 return dim() - dynamic_dim();
179
180template <class Derived>
181Size
183{
184 return _intmd_dim;
185}
186
187template <class Derived>
190{
191 return utils::add_traceable_shapes(dynamic_sizes(), intmd_sizes());
192}
193
194template <class Derived>
197{
198 return sizes().slice(batch_dim());
199}
200
201template <class Derived>
205 return _dynamic_sizes;
207
208template <class Derived>
211{
212 return sizes().slice(dynamic_dim());
213}
215template <class Derived>
218{
219 return sizes().slice(dynamic_dim(), intmd_dim());
220}
221
222template <class Derived>
226 i = utils::normalize_dim(i, 0, batch_dim());
227 if (i < dynamic_dim())
228 return _dynamic_sizes[i];
229 return size(i);
230}
232template <class Derived>
236 i = utils::normalize_dim(i, batch_dim(), dim());
237 return size(i);
238}
239
240template <class Derived>
244 i = utils::normalize_dim(i, 0, dynamic_dim());
245 return _dynamic_sizes[i];
246}
247
248template <class Derived>
249Size
252 i = utils::normalize_dim(i, dynamic_dim(), dim());
253 return size(i);
254}
255
256template <class Derived>
260 i = utils::normalize_dim(i, dynamic_dim(), batch_dim());
261 return size(i);
262}
263
264template <class Derived>
265Derived
267{
268 indexing::TensorIndices indices_vec(indices);
269 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
270 auto res = this->index(indices_vec);
271 return Derived(res, res.dim() - static_dim(), intmd_dim());
274template <class Derived>
275Derived
277{
278 indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
279 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
280 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
281 auto res = this->index(indices_vec);
282 return Derived(res, dynamic_sizes(), res.dim() - dynamic_dim() - base_dim());
283}
284
285template <class Derived>
288{
289 indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
290 indices2.insert(indices2.end(), indices.begin(), indices.end());
291 return neml2::Tensor(this->index(indices2), dynamic_sizes(), intmd_dim());
292}
293
294template <class Derived>
295Derived
296TensorBase<Derived>::dynamic_slice(Size d, const indexing::Slice & index) const
297{
298 d = utils::normalize_dim(d, 0, dynamic_dim());
299 auto res = this->slice(
300 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
301 return Derived(res, res.dim() - static_dim(), intmd_dim());
302}
303
304template <class Derived>
305Derived
306TensorBase<Derived>::intmd_slice(Size d, const indexing::Slice & index) const
307{
308 d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
309 auto res = this->slice(
310 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
311 return Derived(res, dynamic_sizes(), intmd_dim());
312}
313
314template <class Derived>
316TensorBase<Derived>::base_slice(Size d, const indexing::Slice & index) const
317{
318 d = utils::normalize_dim(d, batch_dim(), dim());
319 auto res = this->slice(
320 d, index.start().expect_int(), index.stop().expect_int(), index.step().expect_int());
321 return neml2::Tensor(res, dynamic_sizes(), intmd_dim());
322}
323
324template <class Derived>
325void
327{
328 indexing::TensorIndices indices_vec(indices);
329 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
330 this->index_put_(indices_vec, other);
331}
332
333template <class Derived>
334void
336{
337 indexing::TensorIndices indices_vec(indices);
338 indices_vec.insert(indices_vec.end(), static_dim(), indexing::Slice());
339 this->index_put_(indices_vec, v);
340}
341
342template <class Derived>
343void
345{
346 indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
347 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
348 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
349 this->index_put_(indices_vec, other);
350}
351
352template <class Derived>
353void
355{
356 indexing::TensorIndices indices_vec(dynamic_dim(), indexing::Slice());
357 indices_vec.insert(indices_vec.end(), indices.begin(), indices.end());
358 indices_vec.insert(indices_vec.end(), base_dim(), indexing::Slice());
359 this->index_put_(indices_vec, v);
360}
361
362template <class Derived>
363void
365{
366 indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
367 indices2.insert(indices2.end(), indices.begin(), indices.end());
368 this->index_put_(indices2, other);
369}
370
371template <class Derived>
372void
374{
375 indexing::TensorIndices indices2(batch_dim(), indexing::Slice());
376 indices2.insert(indices2.end(), indices.begin(), indices.end());
377 this->index_put_(indices2, v);
378}
379
380template <class Derived>
381Derived
383{
384 return Derived(ATensor::variable_data(), dynamic_sizes(), intmd_dim());
385}
386
387template <class Derived>
388Derived
390{
391 // We don't want to touch the other dimensions, so put -1 for them.
392 auto net = shape.concrete();
393 net.insert(net.end(), static_dim(), -1);
394
395 // Record the dynamic sizes in the traced graph if we are tracing
396 if (jit::tracer::isTracing())
397 for (std::size_t i = 0; i < shape.size(); ++i)
398 if (const auto * const si = shape[i].traceable())
399 jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);
400
401 return Derived(expand(net), shape, intmd_dim());
402}
403
404template <class Derived>
405Derived
407{
408 if (intmd_sizes() == shape)
409 return *this;
410
411 // Unsqueeze missing dimensions
412 neml_assert_dbg(Size(shape.size()) >= intmd_dim(),
413 "Invalid intermediate shape to expand. Expected at least ",
414 intmd_dim(),
415 " dimensions.");
416 auto tmp = intmd_unsqueeze(0, shape.size() - intmd_dim());
417
418 // We don't want to touch the other dimensions, so put -1 for them.
419 TensorShape net(dynamic_dim(), -1);
420 net.insert(net.end(), shape.begin(), shape.end());
421 net.insert(net.end(), base_dim(), -1);
422 return Derived(tmp.expand(net), dynamic_sizes(), Size(shape.size()));
423}
424
425template <class Derived>
428{
429 if (base_sizes() == shape)
430 return *this;
431
432 // Unsqueeze missing dimensions
433 neml_assert_dbg(Size(shape.size()) >= base_dim(),
434 "Invalid base shape to expand. Expected at least ",
435 base_dim(),
436 " dimensions.");
437 auto tmp = base_unsqueeze(0, shape.size() - base_dim());
438
439 // We don't want to touch the batch dimensions, so put -1 for them.
440 TensorShape net(batch_dim(), -1);
441 net.insert(net.end(), shape.begin(), shape.end());
442 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), intmd_dim());
443}
444
445template <class Derived>
446Derived
448 TensorShapeRef intmd_shape) const
449{
450 // Unsqueeze missing dimensions
451 neml_assert_dbg(Size(intmd_shape.size()) >= intmd_dim(),
452 "Invalid intermediate shape to expand. Expected at least ",
453 intmd_dim(),
454 " dimensions.");
455 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
456
457 // We don't want to touch the other dimensions, so put -1 for them.
458 auto net = utils::add_shapes(dynamic_shape.concrete(), intmd_shape);
459 net.insert(net.end(), base_dim(), -1);
460
461 // Record the dynamic sizes in the traced graph if we are tracing
462 if (jit::tracer::isTracing())
463 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
464 if (const auto * const si = dynamic_shape[i].traceable())
465 jit::tracer::ArgumentStash::stashIntArrayRefElem("size", net.size(), i, *si);
466
467 return Derived(tmp.expand(net), dynamic_shape, tmp.intmd_dim());
468}
469
470template <class Derived>
473{
474 auto net = utils::add_shapes(intmd_shape, base_shape);
475 if (static_sizes() == net)
476 return *this;
477
478 // Unsqueeze missing dimensions
479 neml_assert_dbg(Size(intmd_shape.size()) >= intmd_dim(),
480 "Invalid intermediate shape to expand. Expected at least ",
481 intmd_dim(),
482 " dimensions.");
483 neml_assert_dbg(Size(base_shape.size()) >= base_dim(),
484 "Invalid base shape to expand. Expected at least ",
485 base_dim(),
486 " dimensions.");
487 auto tmp = intmd_unsqueeze(0, intmd_shape.size() - intmd_dim());
488 tmp = tmp.base_unsqueeze(0, base_shape.size() - base_dim());
489
490 // We don't want to touch the other dimensions, so put -1 for them.
491 net.insert(net.begin(), dynamic_dim(), -1);
492 return neml2::Tensor(tmp.expand(net), dynamic_sizes(), tmp.intmd_dim());
493}
494
495template <class Derived>
496Derived
498{
499 if (dynamic_size(d) == size)
500 return Derived(*this);
501
502 d = utils::normalize_dim(d, 0, dynamic_dim());
503 auto shape = dynamic_sizes();
504 shape[d] = size;
505 return dynamic_expand(shape);
506}
507
508template <class Derived>
509Derived
511{
512 if (intmd_size(d) == size)
513 return *this;
514
515 d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
516 TensorShape net(dim(), -1);
517 net[d] = size;
518 return Derived(expand(net), dynamic_sizes(), intmd_dim());
519}
520
521template <class Derived>
524{
525 if (base_size(d) == size)
526 return *this;
527
528 d = utils::normalize_dim(d, batch_dim(), dim());
529 TensorShape net(dim(), -1);
530 net[d] = size;
531 return neml2::Tensor(expand(net), dynamic_sizes(), intmd_dim());
532}
533
534template <class Derived>
535Derived
537{
538 return dynamic_expand(other.dynamic_sizes());
539}
540
541template <class Derived>
542Derived
544{
545 return intmd_expand(other.intmd_sizes());
546}
547
548template <class Derived>
551{
552 return base_expand(other.base_sizes());
553}
554
555template <class Derived>
556Derived
558{
559 return batch_expand(other.dynamic_sizes(), other.intmd_sizes());
560}
561
562template <class Derived>
565{
566 return static_expand(other.intmd_sizes(), other.base_sizes());
567}
568
569template <class Derived>
570Derived
572{
573 // Record the dynamic sizes in the traced graph if we are tracing
574 if (jit::tracer::isTracing())
575 for (std::size_t i = 0; i < shape.size(); ++i)
576 if (const auto * const si = shape[i].traceable())
577 jit::tracer::ArgumentStash::stashIntArrayRefElem(
578 "shape", shape.size() + static_dim(), i, *si);
579
580 return Derived(reshape(utils::add_shapes(shape.concrete(), static_sizes())), shape, intmd_dim());
581}
582
583template <class Derived>
584Derived
586{
587 auto intmd_dim = Size(shape.size());
588
589 // Record the dynamic sizes in the traced graph if we are tracing
590 if (jit::tracer::isTracing())
591 for (Size i = 0; i < dynamic_dim(); ++i)
592 if (const auto * const si = dynamic_size(i).traceable())
593 jit::tracer::ArgumentStash::stashIntArrayRefElem(
594 "shape", dynamic_dim() + intmd_dim + base_dim(), i, *si);
595
596 return neml2::Tensor(reshape(utils::add_shapes(dynamic_sizes().concrete(), shape, base_sizes())),
597 dynamic_sizes(),
598 intmd_dim);
599}
600
601template <class Derived>
604{
605 // Record the dynamic sizes in the traced graph if we are tracing
606 if (jit::tracer::isTracing())
607 for (Size i = 0; i < dynamic_dim(); ++i)
608 if (const auto * const si = dynamic_size(i).traceable())
609 jit::tracer::ArgumentStash::stashIntArrayRefElem(
610 "shape", batch_dim() + shape.size(), i, *si);
611
612 return neml2::Tensor(
613 reshape(utils::add_shapes(batch_sizes().concrete(), shape)), dynamic_sizes(), intmd_dim());
614}
615
616template <class Derived>
617Derived
619 TensorShapeRef intmd_shape) const
620{
621 auto intmd_dim = Size(intmd_shape.size());
622
623 // Record the dynamic sizes in the traced graph if we are tracing
624 if (jit::tracer::isTracing())
625 for (std::size_t i = 0; i < dynamic_shape.size(); ++i)
626 if (const auto * const si = dynamic_shape[i].traceable())
627 jit::tracer::ArgumentStash::stashIntArrayRefElem(
628 "shape", dynamic_shape.size() + intmd_dim + base_dim(), i, *si);
629
630 return Derived(reshape(utils::add_shapes(dynamic_shape.concrete(), intmd_shape, base_sizes())),
631 dynamic_shape,
632 intmd_dim);
633}
634
635template <class Derived>
638{
639 auto intmd_dim = Size(intmd_shape.size());
640
641 // Record the dynamic sizes in the traced graph if we are tracing
642 if (jit::tracer::isTracing())
643 for (Size i = 0; i < dynamic_dim(); ++i)
644 if (const auto * const si = dynamic_size(i).traceable())
645 jit::tracer::ArgumentStash::stashIntArrayRefElem(
646 "shape", dynamic_dim() + intmd_dim + base_shape.size(), i, *si);
647
648 return Derived(reshape(utils::add_shapes(dynamic_sizes().concrete(), intmd_shape, base_shape)),
649 dynamic_sizes(),
650 intmd_dim);
651}
652
653template <class Derived>
654Derived
656{
657 d = utils::normalize_dim(d, 0, dynamic_dim());
658 neml_assert(dynamic_size(d) == 1,
659 "Cannot squeeze dynamic dimension ",
660 d,
661 " with size ",
662 dynamic_size(d),
663 ". Only dimensions of size 1 can be squeezed.");
664 auto sizes = dynamic_sizes();
665 sizes.erase(sizes.begin() + d); // Remove the squeezed dimension
666 return Derived(squeeze(d), sizes, intmd_dim());
667}
668
669template <class Derived>
670Derived
672{
673 d = utils::normalize_dim(d, dynamic_dim(), batch_dim());
674 return Derived(squeeze(d), dynamic_sizes(), intmd_dim() - 1);
675}
676
677template <class Derived>
680{
681 d = utils::normalize_dim(d, batch_dim(), dim());
682 return neml2::Tensor(squeeze(d), dynamic_sizes(), intmd_dim());
683}
684
685template <class Derived>
686Derived
688{
689 neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
690 at::Tensor t = *this;
691 d = utils::normalize_itr(d, 0, dynamic_dim());
692 for (Size i = 0; i < n; ++i)
693 t = t.unsqueeze(d);
694 auto B = dynamic_sizes();
695 B.insert(B.begin() + d, n, 1);
696 return Derived(t, B, intmd_dim());
697}
698
699template <class Derived>
700Derived
702{
703 neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
704 at::Tensor t = *this;
705 d = utils::normalize_itr(d, dynamic_dim(), batch_dim());
706 for (Size i = 0; i < n; ++i)
707 t = t.unsqueeze(d);
708 return Derived(t, dynamic_sizes(), intmd_dim() + n);
709}
710
711template <class Derived>
714{
715 neml_assert(n >= 0, "Number of dimensions to unsqueeze must be non-negative.");
716 at::Tensor t = *this;
717 d = utils::normalize_itr(d, batch_dim(), dim());
718 for (Size i = 0; i < n; ++i)
719 t = t.unsqueeze(d);
720 return neml2::Tensor(t, dynamic_sizes(), intmd_dim());
721}
722
723template <class Derived>
724Derived
726{
727 d1 = utils::normalize_dim(d1, 0, dynamic_dim());
728 d2 = utils::normalize_dim(d2, 0, dynamic_dim());
729
730 auto sizes = dynamic_sizes();
731 std::swap(sizes[d1], sizes[d2]);
732
733 return Derived(transpose(d1, d2), sizes, intmd_dim());
734}
735
736template <class Derived>
737Derived
739{
740 d1 = utils::normalize_dim(d1, dynamic_dim(), batch_dim());
741 d2 = utils::normalize_dim(d2, dynamic_dim(), batch_dim());
742 return Derived(transpose(d1, d2), dynamic_sizes(), intmd_dim());
743}
744
745template <class Derived>
748{
749 d1 = utils::normalize_dim(d1, batch_dim(), dim());
750 d2 = utils::normalize_dim(d2, batch_dim(), dim());
751 return neml2::Tensor(transpose(d1, d2), dynamic_sizes(), intmd_dim());
752}
753
754template <class Derived>
755Derived
757{
758 old_dim = utils::normalize_dim(old_dim, 0, dynamic_dim());
759 new_dim = utils::normalize_dim(new_dim, 0, dynamic_dim());
760
761 auto sizes = dynamic_sizes();
762 auto from = sizes.begin() + old_dim;
763 auto to = sizes.begin() + new_dim;
764 if (from < to)
765 std::rotate(from, from + 1, to + 1);
766 else
767 std::rotate(to, from, from + 1);
768
769 return Derived(movedim(old_dim, new_dim), sizes, intmd_dim());
770}
771
772template <class Derived>
773Derived
775{
776 old_dim = utils::normalize_dim(old_dim, dynamic_dim(), batch_dim());
777 new_dim = utils::normalize_dim(new_dim, dynamic_dim(), batch_dim());
778 return Derived(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
779}
780
781template <class Derived>
784{
785 old_dim = utils::normalize_dim(old_dim, batch_dim(), dim());
786 new_dim = utils::normalize_dim(new_dim, batch_dim(), dim());
787 return neml2::Tensor(movedim(old_dim, new_dim), dynamic_sizes(), intmd_dim());
788}
789
790template <class Derived>
791Derived
793{
794 if (dynamic_dim() == 1)
795 return *this;
796
797 auto n = utils::traceable_numel(dynamic_sizes());
798 if (const auto * const nt = n.traceable())
799 jit::tracer::ArgumentStash::stashIntArrayRefElem("shape", 1 + static_dim(), 0, *nt);
800
801 auto sizes = utils::add_shapes(n.concrete(), intmd_sizes(), base_sizes());
802 return Derived(reshape(sizes), {n}, intmd_dim());
803}
804
805template <class Derived>
806Derived
808{
809 if (intmd_dim() == 1)
810 return *this;
811 return intmd_reshape(utils::numel(intmd_sizes()));
812}
813
814template <class Derived>
817{
818 if (base_dim() == 1)
819 return *this;
820 return base_reshape(utils::numel(base_sizes()));
821}
822
823template <class Derived>
824Derived
826{
827 if (intmd_dim() == 0 && dynamic_dim() == 1)
828 return *this;
829 return batch_reshape({utils::traceable_numel(batch_sizes())}, {});
830}
831
832template <class Derived>
835{
836 if (intmd_dim() == 0 && base_dim() == 1)
837 return *this;
838 return static_reshape({}, utils::numel(static_sizes()));
839}
840
841template <class Derived>
842Derived
844{
845 return Derived(-ATensor(*this), dynamic_sizes(), intmd_dim());
846}
847
848} // end namespace neml2
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:834
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:306
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:816
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:447
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:783
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:713
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:747
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:792
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:296
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:774
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:564
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:571
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:687
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:843
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:472
Derived contiguous() const
Definition TensorBaseImpl.h:126
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:406
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:557
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:550
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:389
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:427
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:807
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:701
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:756
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:585
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:364
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:738
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:536
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:825
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:316
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:637
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:671
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:326
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:618
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:344
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:725
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:679
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:543
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:382
Definition Tensor.h:47
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
c10::SmallVector< TensorIndex, 8 > TensorIndices
Definition indexing.h:38
Definition BufferStore.h:43
TensorShape add_shapes(const S &...)
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize an iterator-like index to be non-negative given the lower- and upper-bound...
Definition shape_utils.cxx:49
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Definition shape_utils.cxx:34
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Definition shape_utils.cxx:64
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition jit.cxx:61
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
at::Tensor ATensor
Definition types.h:38
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
void neml_assert(bool assertion, Args &&... args)
Definition assertions.h:47
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38
TensorShape concrete() const
Definition TraceableTensorShape.cxx:71
TraceableTensorShape slice(std::size_t N, std::size_t M) const
Slice the shape.
Definition TraceableTensorShape.cxx:59