NEML2 2.0.0
Loading...
Searching...
No Matches
PrimitiveTensor.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <ATen/ScalarOps.h>
28#include <torch/csrc/autograd/variable.h>
29#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
30#include <tuple>
31#include <type_traits>
32
33#include "neml2/misc/defaults.h"
34#include "neml2/misc/errors.h"
35#include "neml2/misc/types.h"
36#include "neml2/tensors/Tensor.h"
37#include "neml2/tensors/functions/utils.h"
38#include "neml2/tensors/functions/stack.h"
39
40namespace neml2
41{
42using TensorDataContainer = torch::detail::TensorDataContainer;
43
49template <class Derived, Size... S>
50class PrimitiveTensor : public TensorBase<Derived>
51{
52public:
54 using base_sizes_sequence = std::integer_sequence<Size, S...>;
55
57 static inline const TensorShape const_base_sizes = {S...};
58
60 static constexpr Size const_base_dim = sizeof...(S);
61
63 static constexpr Size const_base_numel = (1 * ... * S);
64
66 PrimitiveTensor() = default;
67
70
73
75 PrimitiveTensor(const ATensor & tensor,
76 const TraceableTensorShape & dynamic_shape,
78
80 template <class Derived2>
82
84 operator neml2::Tensor() const;
85
87 [[nodiscard]] static Derived create(const TensorDataContainer & data,
88 Size intmd_dim = 0,
89 const TensorOptions & options = default_tensor_options());
90
93 [[nodiscard]] static Derived empty(const TensorOptions & options = default_tensor_options());
94 [[nodiscard]] static Derived empty(const TraceableTensorShape & dynamic_shape,
95 TensorShapeRef intmd_shape = {},
96 const TensorOptions & options = default_tensor_options());
98
101 [[nodiscard]] static Derived zeros(const TensorOptions & options = default_tensor_options());
102 [[nodiscard]] static Derived zeros(const TraceableTensorShape & dynamic_shape,
103 TensorShapeRef intmd_shape = {},
104 const TensorOptions & options = default_tensor_options());
106
109 [[nodiscard]] static Derived ones(const TensorOptions & options = default_tensor_options());
110 [[nodiscard]] static Derived ones(const TraceableTensorShape & dynamic_shape,
111 TensorShapeRef intmd_shape = {},
112 const TensorOptions & options = default_tensor_options());
114
117 [[nodiscard]] static Derived full(const CScalar & init,
118 const TensorOptions & options = default_tensor_options());
119 [[nodiscard]] static Derived full(const TraceableTensorShape & dynamic_shape,
120 TensorShapeRef intmd_shape,
121 const CScalar & init,
122 const TensorOptions & options = default_tensor_options());
124
127 [[nodiscard]] static Derived rand(const TensorOptions & options = default_tensor_options());
128 [[nodiscard]] static Derived rand(const TraceableTensorShape & dynamic_shape,
129 TensorShapeRef intmd_shape,
130 const TensorOptions & options = default_tensor_options());
132
135 template <typename... Args,
136 typename = std::enable_if_t<(sizeof...(Args) == const_base_numel ||
137 sizeof...(Args) == const_base_numel + 1)>>
138 [[nodiscard]] static Derived fill(Args &&... args);
140
142 [[nodiscard]] static Derived einsum(c10::string_view equation, TensorList tensors);
143
145 template <typename... Args>
146 Scalar operator()(Args... i) const;
147
148protected:
151};
152
154// Implementations
156
157template <class Derived, Size... S>
159 : TensorBase<Derived>(tensor, tensor.dim() - const_base_dim - intmd_dim, intmd_dim)
160{
162}
163
164template <class Derived, Size... S>
166 Size dynamic_dim,
167 Size intmd_dim)
168 : TensorBase<Derived>(tensor, dynamic_dim, intmd_dim)
169{
170 if (dynamic_dim + intmd_dim + const_base_dim != tensor.dim())
171 throw NEMLException("Inconsistent dimensions when constructing PrimitiveTensor. Expected "
172 "tensor to have dynamic dimension " +
173 std::to_string(dynamic_dim) + ", intmd dimension " +
174 std::to_string(intmd_dim) + ", and base dimension " +
175 std::to_string(const_base_dim) + ", but tensor has " +
176 std::to_string(tensor.dim()) + " dimensions.");
178}
179
180template <class Derived, Size... S>
182 const TraceableTensorShape & dynamic_shape,
183 Size intmd_dim)
184 : TensorBase<Derived>(tensor, dynamic_shape, intmd_dim)
185{
187}
188
189template <class Derived, Size... S>
190template <class Derived2>
196
197template <class Derived, Size... S>
198void
200{
201#ifndef NDEBUG
202 if (this->base_sizes() != const_base_sizes)
203 throw NEMLException("Base shape mismatch");
204#endif
205}
206
207template <class Derived, Size... S>
208PrimitiveTensor<Derived, S...>::operator neml2::Tensor() const
209{
210 return neml2::Tensor(*this, this->dynamic_sizes(), this->intmd_dim());
211}
212
213template <class Derived, Size... S>
214Derived
216 Size intmd_dim,
217 const TensorOptions & options)
218{
219 return Derived(torch::autograd::make_variable(
220 data.convert_to_tensor(options.requires_grad(false)), options.requires_grad()),
221 intmd_dim)
222 .clone(); // clone to take ownership of the data
223}
224
225template <class Derived, Size... S>
226Derived
228{
229 return Tensor::empty(const_base_sizes, options);
230}
231
232template <class Derived, Size... S>
233Derived
235 TensorShapeRef intmd_shape,
236 const TensorOptions & options)
237{
238 return Tensor::empty(dynamic_shape, intmd_shape, const_base_sizes, options);
239}
240
241template <class Derived, Size... S>
242Derived
244{
245 return Tensor::zeros(const_base_sizes, options);
246}
247
248template <class Derived, Size... S>
249Derived
251 TensorShapeRef intmd_shape,
252 const TensorOptions & options)
253{
254 return Tensor::zeros(dynamic_shape, intmd_shape, const_base_sizes, options);
255}
256
257template <class Derived, Size... S>
258Derived
260{
261 return Tensor::ones(const_base_sizes, options);
262}
263
264template <class Derived, Size... S>
265Derived
267 TensorShapeRef intmd_shape,
268 const TensorOptions & options)
269{
270 return Tensor::ones(dynamic_shape, intmd_shape, const_base_sizes, options);
271}
272
273template <class Derived, Size... S>
274Derived
276{
277 return Tensor::full(const_base_sizes, init, options);
278}
279
280template <class Derived, Size... S>
281Derived
283 TensorShapeRef intmd_shape,
284 const CScalar & init,
285 const TensorOptions & options)
286{
287 return Tensor::full(dynamic_shape, intmd_shape, const_base_sizes, init, options);
288}
289
290template <class Derived, Size... S>
291Derived
293{
294 return Tensor::rand(const_base_sizes, options);
295}
296
297template <class Derived, Size... S>
298Derived
300 TensorShapeRef intmd_shape,
301 const TensorOptions & options)
302{
303 return Tensor::rand(dynamic_shape, intmd_shape, const_base_sizes, options);
304}
305
306template <class Derived, Size... S>
307Derived
308PrimitiveTensor<Derived, S...>::einsum(c10::string_view equation, TensorList tensors)
309{
310 const auto [tensors_aligned, i] = utils::align_intmd_dim(tensors);
311 std::vector<ATensor> tensors_einsum(tensors_aligned.size());
312 for (std::size_t j = 0; j < tensors_aligned.size(); ++j)
313 tensors_einsum[j] = tensors_aligned[j];
314 auto res = at::einsum(equation, tensors_einsum);
315 return Derived(res, i);
316}
317
318template <class Tuple, std::size_t... I>
319auto
320make_tensors(Tuple && t, std::index_sequence<I...>, const TensorOptions & options)
321{
322 return std::vector<neml2::Tensor>{
323 neml2::Tensor(at::scalar_to_tensor(std::get<I>(std::forward<Tuple>(t)), options.device())
324 .to(options.dtype()),
325 0)...};
326}
327
328template <class Derived, Size... S>
329template <typename... Args, typename>
330Derived
332{
333 if constexpr (sizeof...(Args) == const_base_numel)
334 {
335 if constexpr ((std::is_convertible_v<Args, neml2::Tensor> && ...))
336 {
337#ifndef NDEBUG
338 neml_assert_dbg(((args.base_dim() == 0) && ...),
339 "All input tensors must be scalar-like (no base dimensions)");
340#endif
341 return base_stack({std::forward<Args>(args)...}).base_reshape(const_base_sizes);
342 }
343 else if constexpr ((std::is_convertible_v<Args, CScalar> && ...))
344 {
345 auto t = neml2::Tensor::create({std::forward<Args>(args)...}, default_tensor_options());
346 return t.base_reshape(const_base_sizes);
347 }
348 }
349 else if constexpr (sizeof...(Args) == const_base_numel + 1)
350 {
351 auto tup = std::forward_as_tuple(std::forward<Args>(args)...);
352 const auto & options = std::get<sizeof...(Args) - 1>(tup);
353 auto vals = make_tensors(tup, std::make_index_sequence<sizeof...(Args) - 1>{}, options);
354 return base_stack(vals).base_reshape(const_base_sizes);
355 }
356
357 throw NEMLException("Invalid argument types to PrimitiveTensor::fill");
358}
359
360} // namespace neml2
Definition errors.h:34
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:51
static Derived full(const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:275
static Derived empty(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:234
static Derived empty(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:227
PrimitiveTensor(const TensorBase< Derived2 > &tensor)
Converting constructor from any TensorBase (including tensors of other derived types).
Definition PrimitiveTensor.h:191
PrimitiveTensor(const ATensor &tensor, Size dynamic_dim, Size intmd_dim)
Construct from an ATensor and extract dynamic shape given dynamic dimension.
Definition PrimitiveTensor.h:165
static Derived zeros(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:250
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:57
static Derived rand(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:292
static Derived ones(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:259
Scalar operator()(Args... i) const
Single-element accessor.
Definition Scalar.h:53
std::integer_sequence< Size, S... > base_sizes_sequence
Base shape sequence.
Definition PrimitiveTensor.h:54
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:60
static Derived fill(Args &&... args)
Definition PrimitiveTensor.h:331
static Derived einsum(c10::string_view equation, TensorList tensors)
Einstein summation along base dimensions.
Definition PrimitiveTensor.h:308
static Derived ones(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape={}, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:266
static Derived rand(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:299
PrimitiveTensor(const ATensor &tensor, Size intmd_dim)
Construct from an ATensor and infer dynamic shape.
Definition PrimitiveTensor.h:158
void validate_shapes_and_dims() const
Validate shapes and dimensions.
Definition PrimitiveTensor.h:199
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &dynamic_shape, Size intmd_dim)
Construct from an ATensor given dynamic shape.
Definition PrimitiveTensor.h:181
static Derived zeros(const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:243
PrimitiveTensor()=default
Special member functions.
static constexpr Size const_base_numel
The base numel.
Definition PrimitiveTensor.h:63
static Derived create(const TensorDataContainer &data, Size intmd_dim=0, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred dynamic dimension.
Definition PrimitiveTensor.h:215
static Derived full(const TraceableTensorShape &dynamic_shape, TensorShapeRef intmd_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition PrimitiveTensor.h:282
Scalar.
Definition Scalar.h:38
NEML2's enhanced tensor type.
Definition TensorBase.h:77
Size dynamic_dim() const
Definition TensorBaseImpl.h:168
Size intmd_dim() const
Definition TensorBaseImpl.h:182
neml2::Tensor base_reshape(TensorShapeRef shape) const
Definition TensorBaseImpl.h:603
Definition Tensor.h:47
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:172
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:99
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:124
static Tensor create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary (unbatched) tensor from a nested container.
Definition Tensor.cxx:80
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:148
static Tensor rand(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Definition Tensor.cxx:197
std::pair< std::vector< Tensor >, Size > align_intmd_dim(TensorList tensors)
Definition utils.cxx:30
Definition DiagnosticsInterface.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
at::Tensor ATensor
Definition types.h:38
c10::ArrayRef< neml2::Tensor > TensorList
Definition Tensor.h:37
auto make_tensors(Tuple &&t, std::index_sequence< I... >, const TensorOptions &options)
Definition PrimitiveTensor.h:320
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
Tensor base_stack(const std::vector< Tensor > &tensors, Size d)
Definition stack.cxx:64
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:42
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
Traceable tensor shape.
Definition TraceableTensorShape.h:38