NEML2 2.0.0
Loading...
Searching...
No Matches
PrimitiveTensor.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <torch/csrc/autograd/variable.h>
28#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
29
30#include "neml2/misc/errors.h"
31#include "neml2/tensors/Tensor.h"
32#include "neml2/tensors/shape_utils.h"
33
34namespace neml2
35{
36using TensorDataContainer = torch::detail::TensorDataContainer;
37
43template <class Derived, Size... S>
44class PrimitiveTensor : public TensorBase<Derived>
45{
46public:
48 static inline const TensorShape const_base_sizes = {S...};
49
51 static constexpr Size const_base_dim = sizeof...(S);
52
54 static inline const Size const_base_storage = utils::storage_size({S...});
55
57 PrimitiveTensor() = default;
58
61
63 PrimitiveTensor(const ATensor & tensor, const TraceableTensorShape & batch_shape);
64
66 PrimitiveTensor(const Tensor & tensor);
67
69 explicit PrimitiveTensor(const ATensor & tensor);
70
72 operator Tensor() const;
73
75 [[nodiscard]] static Derived create(const TensorDataContainer & data,
76 const TensorOptions & options = default_tensor_options());
78 [[nodiscard]] static Derived empty(const TensorOptions & options = default_tensor_options());
80 [[nodiscard]] static Derived empty(const TraceableTensorShape & batch_shape,
81 const TensorOptions & options = default_tensor_options());
83 [[nodiscard]] static Derived zeros(const TensorOptions & options = default_tensor_options());
85 [[nodiscard]] static Derived zeros(const TraceableTensorShape & batch_shape,
86 const TensorOptions & options = default_tensor_options());
88 [[nodiscard]] static Derived ones(const TensorOptions & options = default_tensor_options());
90 [[nodiscard]] static Derived ones(const TraceableTensorShape & batch_shape,
91 const TensorOptions & options = default_tensor_options());
93 [[nodiscard]] static Derived full(const CScalar & init,
94 const TensorOptions & options = default_tensor_options());
96 [[nodiscard]] static Derived full(const TraceableTensorShape & batch_shape,
97 const CScalar & init,
98 const TensorOptions & options = default_tensor_options());
99
101 [[nodiscard]] static Tensor identity_map(const TensorOptions &)
102 {
103 throw NEMLException("Not implemented");
104 }
105};
106
108// Implementations
110
111template <class Derived, Size... S>
113 : TensorBase<Derived>(tensor, batch_dim)
114{
115#ifndef NDEBUG
116 if (this->base_sizes() != const_base_sizes)
117 throw NEMLException("Base shape mismatch");
118#endif
119}
120
121template <class Derived, Size... S>
123 const TraceableTensorShape & batch_shape)
124 : TensorBase<Derived>(tensor, batch_shape)
125{
126#ifndef NDEBUG
127 if (this->base_sizes() != const_base_sizes)
128 throw NEMLException("Base shape mismatch");
129#endif
130}
131
132template <class Derived, Size... S>
134 : TensorBase<Derived>(tensor)
135{
136#ifndef NDEBUG
137 if (this->base_sizes() != const_base_sizes)
138 throw NEMLException("Base shape mismatch");
139#endif
140}
141
142template <class Derived, Size... S>
144 : TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
145{
146#ifndef NDEBUG
147 if (this->base_sizes() != const_base_sizes)
148 throw NEMLException("Base shape mismatch");
149#endif
150}
151
152template <class Derived, Size... S>
153PrimitiveTensor<Derived, S...>::operator Tensor() const
154{
155 return Tensor(*this, this->batch_sizes());
156}
157
158template <class Derived, Size... S>
159Derived
161 const TensorOptions & options)
162{
163 return Derived(torch::autograd::make_variable(
164 data.convert_to_tensor(options.requires_grad(false)), options.requires_grad()));
165}
166
167template <class Derived, Size... S>
168Derived
170{
171 return Tensor::empty(const_base_sizes, options);
172}
173
174template <class Derived, Size... S>
175Derived
177 const TensorOptions & options)
178{
179 return Tensor::empty(batch_shape, const_base_sizes, options);
180}
181
182template <class Derived, Size... S>
183Derived
185{
186 return Tensor::zeros(const_base_sizes, options);
187}
188
189template <class Derived, Size... S>
190Derived
192 const TensorOptions & options)
193{
194 return Tensor::zeros(batch_shape, const_base_sizes, options);
195}
196
197template <class Derived, Size... S>
198Derived
200{
201 return Tensor::ones(const_base_sizes, options);
202}
203
204template <class Derived, Size... S>
205Derived
207 const TensorOptions & options)
208{
209 return Tensor::ones(batch_shape, const_base_sizes, options);
210}
211
212template <class Derived, Size... S>
213Derived
215{
216 return Tensor::full(const_base_sizes, init, options);
217}
218
219template <class Derived, Size... S>
220Derived
222 const CScalar & init,
223 const TensorOptions & options)
224{
225 return Tensor::full(batch_shape, const_base_sizes, init, options);
226}
227} // namespace neml2
Definition errors.h:34
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:45
static Derived full(const CScalar &init, const TensorOptions &options=default_tensor_options())
Unbatched tensor with the class's compile-time base shape, filled with a given value.
Definition PrimitiveTensor.h:214
static Derived ones(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:206
PrimitiveTensor(const ATensor &tensor)
Construct from another ATensor and infer batch dimension.
Definition PrimitiveTensor.h:143
static Tensor identity_map(const TensorOptions &)
Derived tensor classes should define identity_map where appropriate; this default implementation throws NEMLException ("Not implemented").
Definition PrimitiveTensor.h:101
static Derived empty(const TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:169
PrimitiveTensor(const ATensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another ATensor given batch shape.
Definition PrimitiveTensor.h:122
static Derived create(const TensorDataContainer &data, const TensorOptions &options=default_tensor_options())
Arbitrary tensor from a nested container with inferred batch dimension.
Definition PrimitiveTensor.h:160
PrimitiveTensor(const Tensor &tensor)
Converting constructor from a Tensor.
Definition PrimitiveTensor.h:133
static Derived full(const TraceableTensorShape &batch_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:221
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:48
static Derived empty(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:176
static Derived ones(const TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:199
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:51
PrimitiveTensor(const ATensor &tensor, Size batch_dim)
Construct from another ATensor given batch dimension.
Definition PrimitiveTensor.h:112
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:54
static Derived zeros(const TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:184
static Derived zeros(const TraceableTensorShape &batch_shape, const TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:191
PrimitiveTensor()=default
Special member functions.
NEML2's enhanced tensor type.
Definition TensorBase.h:50
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:168
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:202
Definition Tensor.h:46
static Tensor full(TensorShapeRef base_shape, const CScalar &init, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition Tensor.cxx:156
static Tensor empty(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:93
static Tensor zeros(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:114
static Tensor ones(TensorShapeRef base_shape, const TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:135
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:32
Definition DiagnosticsInterface.cxx:30
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
TensorOptions default_tensor_options()
Default floating point tensor options.
Definition defaults.cxx:42
int64_t Size
Definition types.h:65
c10::Scalar CScalar
Definition types.h:39
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
torch::detail::TensorDataContainer TensorDataContainer
Definition PrimitiveTensor.h:36
Traceable tensor shape.
Definition TraceableTensorShape.h:38