NEML2 2.0.0
Loading...
Searching...
No Matches
TensorBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <ATen/core/Tensor.h>
28
29#include "neml2/tensors/TraceableSize.h"
30#include "neml2/tensors/TraceableTensorShape.h"
31#include "neml2/tensors/functions/operators.h"
32#include "neml2/tensors/functions/logical.h"
33#include "neml2/tensors/indexing.h"
34#include "neml2/tensors/macros.h"
35
36namespace neml2
37{
38// Forward declarations
39template <class Derived>
40class TensorBase;
41
42class Tensor;
43
75template <class Derived>
76class TensorBase : public ATensor
77{
78public:
80 TensorBase() = default;
81
84
86 TensorBase(const ATensor & tensor, TraceableTensorShape dynamic_shape, Size intmd_dim);
87
89 template <class Derived2>
91 : TensorBase(tensor, tensor.dynamic_sizes(), tensor.intmd_dim())
92 {
93 }
94
95 TensorBase(double) = delete;
96 TensorBase(float) = delete;
97 TensorBase(int) = delete;
98
102 [[nodiscard]] static Derived empty_like(const Derived & other);
104 [[nodiscard]] static Derived zeros_like(const Derived & other);
106 [[nodiscard]] static Derived ones_like(const Derived & other);
109 [[nodiscard]] static Derived full_like(const Derived & other, const CScalar & init);
112 [[nodiscard]] static Derived rand_like(const Derived & other);
114
118 Derived contiguous() const;
120 Derived clone() const;
122 Derived detach() const;
124 using ATensor::detach_;
126 Derived to(const TensorOptions & options) const;
128 using ATensor::copy_;
130 using ATensor::zero_;
132 using ATensor::requires_grad;
134 using ATensor::requires_grad_;
136 Derived operator-() const;
138
142 using ATensor::defined;
144 using ATensor::options;
146 using ATensor::scalar_type;
148 using ATensor::device;
150
153 using ATensor::dim;
154 Size batch_dim() const;
155 Size base_dim() const;
156 Size dynamic_dim() const;
157 Size static_dim() const;
158 Size intmd_dim() const;
160
163 using ATensor::sizes;
166 const TraceableTensorShape & dynamic_sizes() const;
170
173 using ATensor::size;
175 Size base_size(Size i) const;
176 const TraceableSize & dynamic_size(Size i) const;
177 Size static_size(Size i) const;
178 Size intmd_size(Size i) const;
180
182 using ATensor::index;
183 using ATensor::index_put_;
184
187 Derived dynamic_index(indexing::TensorIndicesRef indices) const;
188 Derived intmd_index(indexing::TensorIndicesRef indices) const;
191
194 Derived dynamic_slice(Size d, const indexing::Slice & index) const;
195 Derived intmd_slice(Size d, const indexing::Slice & index) const;
196 neml2::Tensor base_slice(Size d, const indexing::Slice & index) const;
198
201 void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
203 void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
204 void intmd_index_put_(indexing::TensorIndicesRef indices, const CScalar & v);
205 void base_index_put_(indexing::TensorIndicesRef indices, const ATensor & other);
206 void base_index_put_(indexing::TensorIndicesRef indices, const CScalar & v);
208
210 Derived variable_data() const;
211
214 Derived dynamic_expand(const TraceableTensorShape & shape) const;
215 Derived intmd_expand(TensorShapeRef shape) const;
217 Derived batch_expand(const TraceableTensorShape & dynamic_shape,
218 TensorShapeRef intmd_shape) const;
219 neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const;
221
224 Derived dynamic_expand(const TraceableSize & size, Size d) const;
225 Derived intmd_expand(Size size, Size d) const;
226 neml2::Tensor base_expand(Size size, Size d) const;
228
231 Derived dynamic_expand_as(const neml2::Tensor & other) const;
232 Derived intmd_expand_as(const neml2::Tensor & other) const;
233 neml2::Tensor base_expand_as(const neml2::Tensor & other) const;
234 Derived batch_expand_as(const neml2::Tensor & other) const;
235 neml2::Tensor static_expand_as(const neml2::Tensor & other) const;
237
240 Derived dynamic_reshape(const TraceableTensorShape & shape) const;
241 Derived intmd_reshape(TensorShapeRef shape) const;
243 Derived batch_reshape(const TraceableTensorShape & dynamic_shape,
244 TensorShapeRef intmd_shape) const;
245 neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const;
247
250 Derived dynamic_squeeze(Size d) const;
251 Derived intmd_squeeze(Size d) const;
254
257 Derived dynamic_unsqueeze(Size d, Size n = 1) const;
258 Derived intmd_unsqueeze(Size d, Size n = 1) const;
259 neml2::Tensor base_unsqueeze(Size d, Size n = 1) const;
261
264 Derived dynamic_transpose(Size d1, Size d2) const;
265 Derived intmd_transpose(Size d1, Size d2) const;
268
271 Derived dynamic_movedim(Size old_dim, Size new_dim) const;
272 Derived intmd_movedim(Size old_dim, Size new_dim) const;
273 neml2::Tensor base_movedim(Size old_dim, Size new_dim) const;
275
278 Derived dynamic_flatten() const;
279 Derived intmd_flatten() const;
287 Derived batch_flatten() const;
296
297protected:
299 void validate_shapes_and_dims() const;
300
301private:
303 TraceableTensorShape _dynamic_sizes;
304
306 // BTW, "intmd" is an abbreviation for "intermediate" found in the Merriam-Webster dictionary :)
307 Size _intmd_dim = 0;
308};
309
310// Export TensorBase so other TU don't repeat the instantiation
311#define EXPORT_TENSORBASE(T) extern template class TensorBase<T>
312FOR_ALL_TENSORBASE(EXPORT_TENSORBASE);
313#undef EXPORT_TENSORBASE
314} // namespace neml2
NEML2's enhanced tensor type.
Definition TensorBase.h:77
neml2::Tensor static_flatten() const
Flatten static dimensions.
Definition TensorBaseImpl.h:834
Derived intmd_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:306
TraceableTensorShape batch_sizes() const
Definition TensorBaseImpl.h:189
neml2::Tensor base_flatten() const
Definition TensorBaseImpl.h:816
Size static_size(Size i) const
Definition TensorBaseImpl.h:250
Derived batch_expand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:447
Size intmd_size(Size i) const
Definition TensorBaseImpl.h:258
neml2::Tensor base_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:783
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
neml2::Tensor base_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:713
TensorBase(double)=delete
neml2::Tensor base_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:747
TensorShapeRef intmd_sizes() const
Definition TensorBaseImpl.h:217
const TraceableSize & dynamic_size(Size i) const
Definition TensorBaseImpl.h:242
Derived dynamic_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:266
Derived dynamic_flatten() const
Definition TensorBaseImpl.h:792
TensorBase(float)=delete
TensorBase()=default
Default constructor.
Derived dynamic_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:296
Size batch_dim() const
Definition TensorBaseImpl.h:154
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:140
TraceableSize batch_size(Size i) const
Definition TensorBaseImpl.h:224
const TraceableTensorShape & dynamic_sizes() const
Definition TensorBaseImpl.h:203
Derived intmd_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:774
neml2::Tensor static_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:564
Derived dynamic_reshape(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:571
Derived dynamic_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:687
TensorShapeRef static_sizes() const
Definition TensorBaseImpl.h:210
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:843
neml2::Tensor static_expand(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:472
Derived contiguous() const
Definition TensorBaseImpl.h:126
TensorBase(int)=delete
Derived intmd_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:406
Derived batch_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:557
neml2::Tensor base_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:550
Size base_size(Size i) const
Definition TensorBaseImpl.h:234
Derived clone() const
Clone (take ownership)
Definition TensorBaseImpl.h:133
Derived dynamic_expand(const TraceableTensorShape &shape) const
Definition TensorBaseImpl.h:389
neml2::Tensor base_expand(TensorShapeRef shape) const
Definition TensorBaseImpl.h:427
Derived intmd_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:276
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:147
Derived intmd_flatten() const
Definition TensorBaseImpl.h:807
Derived intmd_unsqueeze(Size d, Size n=1) const
Definition TensorBaseImpl.h:701
Derived dynamic_movedim(Size old_dim, Size new_dim) const
Definition TensorBaseImpl.h:756
Derived intmd_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:585
Derived dynamic_squeeze(Size d) const
Definition TensorBaseImpl.h:655
void base_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:364
Derived intmd_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:738
Derived dynamic_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:536
Derived batch_flatten() const
Flatten batch dimensions.
Definition TensorBaseImpl.h:825
neml2::Tensor base_slice(Size d, const indexing::Slice &index) const
Definition TensorBaseImpl.h:316
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor static_reshape(TensorShapeRef intmd_shape, TensorShapeRef base_shape) const
Definition TensorBaseImpl.h:637
Size base_dim() const
Definition TensorBaseImpl.h:161
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition TensorBaseImpl.h:71
Size static_dim() const
Definition TensorBaseImpl.h:175
TensorShapeRef base_sizes() const
Definition TensorBaseImpl.h:196
TensorBase(const TensorBase< Derived2 > &tensor)
Copy constructor.
Definition TensorBase.h:90
Derived intmd_squeeze(Size d) const
Definition TensorBaseImpl.h:671
void dynamic_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:326
Derived batch_reshape(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape) const
Definition TensorBaseImpl.h:618
void intmd_index_put_(indexing::TensorIndicesRef indices, const ATensor &other)
Definition TensorBaseImpl.h:344
Derived dynamic_transpose(Size d1, Size d2) const
Definition TensorBaseImpl.h:725
neml2::Tensor base_squeeze(Size d) const
Definition TensorBaseImpl.h:679
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Definition TensorBaseImpl.h:287
Derived intmd_expand_as(const neml2::Tensor &other) const
Definition TensorBaseImpl.h:543
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
Derived variable_data() const
Variable data without function graph.
Definition TensorBaseImpl.h:382
Definition Tensor.h:47
static Derived full_like(const Derived &other, const CScalar &init)
Definition TensorBaseImpl.h:112
static Derived rand_like(const Derived &other)
Definition TensorBaseImpl.h:119
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:98
static Derived empty_like(const Derived &other)
Definition TensorBaseImpl.h:91
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:105
c10::ArrayRef< TensorIndex > TensorIndicesRef
Definition indexing.h:39
Definition DiagnosticsInterface.cxx:30
at::Tensor ATensor
Definition types.h:38
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38