NEML2 2.0.0
TensorBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/utils.h"
28
29namespace neml2
30{
// Forward declarations of the tensor types named in this header before their definitions.
class Tensor;

template <class Derived>
class TensorBase;
36
44template <class Derived>
45class TensorBase : public torch::Tensor
46{
47public:
49 TensorBase() = default;
50
52 TensorBase(const torch::Tensor & tensor, Size batch_dim);
53
55 TensorBase(const torch::Tensor & tensor, const TraceableTensorShape & batch_shape);
56
58 template <class Derived2>
60
61 TensorBase(Real) = delete;
62
64 [[nodiscard]] static Derived empty_like(const Derived & other);
66 [[nodiscard]] static Derived zeros_like(const Derived & other);
68 [[nodiscard]] static Derived ones_like(const Derived & other);
71 [[nodiscard]] static Derived full_like(const Derived & other, Real init);
72
93 [[nodiscard]] static Derived
94 linspace(const Derived & start, const Derived & end, Size nstep, Size dim = 0);
96 [[nodiscard]] static Derived
97 logspace(const Derived & start, const Derived & end, Size nstep, Size dim = 0, Real base = 10);
98
102 Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
104 Derived detach() const;
106 using torch::Tensor::detach_;
108 Derived to(const torch::TensorOptions & options) const;
110 using torch::Tensor::copy_;
112 using torch::Tensor::zero_;
114 using torch::Tensor::requires_grad;
116 using torch::Tensor::requires_grad_;
118 Derived operator-() const;
120
124 using torch::Tensor::options;
126 using torch::Tensor::scalar_type;
128 using torch::Tensor::device;
130 using torch::Tensor::dim;
132 using torch::Tensor::sizes;
134 using torch::Tensor::size;
136 bool batched() const;
138 Size batch_dim() const;
140 Size base_dim() const;
142 const TraceableTensorShape & batch_sizes() const;
144 TraceableSize batch_size(Size index) const;
148 Size base_size(Size index) const;
150 Size base_storage() const;
152
156 using torch::Tensor::index;
157 using torch::Tensor::index_put_;
164 void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor & other);
169 void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor & other);
173 Derived variable_data() const;
175
179 Derived batch_expand(const TraceableTensorShape & batch_shape) const;
181 Derived batch_expand(const TraceableSize & batch_size, Size dim) const;
187 template <class Derived2>
188 Derived batch_expand_as(const Derived2 & other) const;
190 template <class Derived2>
197 Derived batch_reshape(const TraceableTensorShape & batch_shape) const;
201 Derived batch_unsqueeze(Size d) const;
205 Derived batch_transpose(Size d1, Size d2) const;
211
212private:
214 Size _batch_dim = {};
215
217 TraceableTensorShape _batch_sizes;
218};
219
220template <class Derived>
221template <class Derived2>
223 : _batch_dim(tensor.batch_dim()),
224 _batch_sizes(tensor.batch_sizes())
225{
226 torch::Tensor::operator=(tensor);
227}
228
229template <class Derived>
230template <class Derived2>
231Derived
233{
234 return batch_expand(other.batch_sizes());
235}
236
237template <class Derived>
238template <class Derived2>
241{
242 return base_expand(other.base_sizes());
243}
244
245template <class Derived,
246 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
247Derived
248operator+(const Derived & a, const Real & b)
249{
250 return Derived(torch::operator+(a, b), a.batch_sizes());
251}
252
253template <class Derived,
254 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
255Derived
256operator+(const Real & a, const Derived & b)
257{
258 return b + a;
259}
260
261template <class Derived,
262 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
263Derived
264operator+(const Derived & a, const Derived & b)
265{
267 return Derived(torch::operator+(a, b), broadcast_batch_dim(a, b));
268}
269
270template <class Derived,
271 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
272Derived
273operator-(const Derived & a, const Real & b)
274{
275 return Derived(torch::operator-(a, b), a.batch_sizes());
276}
277
278template <class Derived,
279 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
280Derived
281operator-(const Real & a, const Derived & b)
282{
283 return -b + a;
284}
285
286template <class Derived,
287 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
288Derived
289operator-(const Derived & a, const Derived & b)
290{
292 return Derived(torch::operator-(a, b), broadcast_batch_dim(a, b));
293}
294
295template <class Derived,
296 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
297Derived
298operator*(const Derived & a, const Real & b)
299{
300 return Derived(torch::operator*(a, b), a.batch_sizes());
301}
302
303template <class Derived,
304 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
305Derived
306operator*(const Real & a, const Derived & b)
307{
308 return b * a;
309}
310
311template <class Derived,
312 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
313Derived
314operator/(const Derived & a, const Real & b)
315{
316 return Derived(torch::operator/(a, b), a.batch_sizes());
317}
318
319template <class Derived,
320 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
321Derived
322operator/(const Real & a, const Derived & b)
323{
324 return Derived(torch::operator/(a, b), b.batch_sizes());
325}
326
327template <class Derived,
328 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
329Derived
330operator/(const Derived & a, const Derived & b)
331{
333 return Derived(torch::operator/(a, b), broadcast_batch_dim(a, b));
334}
335} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition TensorBaseImpl.h:130
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Real base=10)
Log-space equivalent of the linspace named constructor.
Definition TensorBaseImpl.h:121
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:413
Derived batch_reshape(const TraceableTensorShape &batch_shape) const
Reshape batch dimensions.
Definition TensorBaseImpl.h:350
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBaseImpl.h:213
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBaseImpl.h:404
TraceableSize batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBaseImpl.h:179
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBase.h:232
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBaseImpl.h:206
TensorBase()=default
Default constructor.
bool batched() const
Whether the tensor is batched.
Definition TensorBaseImpl.h:151
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
Derived detach() const
Discard function graph.
Definition TensorBaseImpl.h:137
Derived operator-() const
Negation.
Definition TensorBaseImpl.h:423
TensorBase(Real)=delete
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBaseImpl.h:395
neml2::Tensor base_expand_copy(TensorShapeRef base_shape) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:343
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBaseImpl.h:387
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:144
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBaseImpl.h:199
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:75
const TraceableTensorShape & batch_sizes() const
Return the batch size.
Definition TensorBaseImpl.h:172
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:68
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBaseImpl.h:379
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBaseImpl.h:96
static Derived full_like(const Derived &other, Real init)
Definition TensorBaseImpl.h:89
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:232
Derived batch_expand_copy(const TraceableTensorShape &batch_shape) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBaseImpl.h:336
Size base_dim() const
Return the number of base dimensions.
Definition TensorBaseImpl.h:165
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:192
TensorBase(const TensorBase< Derived2 > &tensor)
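Converting constructor from a TensorBase of any derived type (as a constructor template, it is never the implicit copy constructor).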
Definition TensorBase.h:222
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBaseImpl.h:363
void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Definition TensorBaseImpl.h:251
Derived batch_expand(const TraceableTensorShape &batch_shape) const
Definition TensorBaseImpl.h:277
neml2::Tensor base_expand(TensorShapeRef base_shape) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBaseImpl.h:309
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBaseImpl.h:82
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBaseImpl.h:223
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBase.h:240
Derived variable_data() const
Definition TensorBaseImpl.h:270
Definition Tensor.h:47
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:31
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
auto operator/(const T1 &a, const T2 &b)
Definition Variable.h:367
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
auto operator+(const T1 &a, const T2 &b)
Definition Variable.h:364
auto operator-(const T1 &a, const T2 &b)
Definition Variable.h:365
TraceableSize — traceable size.
Definition types.h:52
TraceableTensorShape — traceable tensor shape.
Definition types.h:81