NEML2 2.0.0
Loading...
Searching...
No Matches
PrimitiveTensor.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/Tensor.h"
28
29namespace neml2
30{
36template <class Derived, Size... S>
37class PrimitiveTensor : public TensorBase<Derived>
38{
39public:
41 static inline const TensorShape const_base_sizes = {S...};
42
44 static constexpr Size const_base_dim = sizeof...(S);
45
47 static inline const Size const_base_storage = utils::storage_size({S...});
48
50 PrimitiveTensor() = default;
51
53 explicit PrimitiveTensor(const torch::Tensor & tensor, Size batch_dim);
54
56 explicit PrimitiveTensor(const torch::Tensor & tensor, const TraceableTensorShape & batch_shape);
57
59 template <class Derived2>
61
63 PrimitiveTensor(const torch::Tensor & tensor);
64
66 operator Tensor() const;
67
69 [[nodiscard]] static Derived
70 empty(const torch::TensorOptions & options = default_tensor_options());
72 [[nodiscard]] static Derived
74 const torch::TensorOptions & options = default_tensor_options());
76 [[nodiscard]] static Derived
77 zeros(const torch::TensorOptions & options = default_tensor_options());
79 [[nodiscard]] static Derived
81 const torch::TensorOptions & options = default_tensor_options());
83 [[nodiscard]] static Derived
84 ones(const torch::TensorOptions & options = default_tensor_options());
86 [[nodiscard]] static Derived
88 const torch::TensorOptions & options = default_tensor_options());
90 [[nodiscard]] static Derived
91 full(Real init, const torch::TensorOptions & options = default_tensor_options());
93 [[nodiscard]] static Derived
95 Real init,
96 const torch::TensorOptions & options = default_tensor_options());
97
99 [[nodiscard]] static Tensor identity_map(const torch::TensorOptions &)
100 {
101 throw NEMLException("Not implemented");
102 }
103};
104
// Implementations

109template <class Derived, Size... S>
110PrimitiveTensor<Derived, S...>::PrimitiveTensor(const torch::Tensor & tensor, Size batch_dim)
111 : TensorBase<Derived>(tensor, batch_dim)
112{
113 neml_assert_dbg(this->base_sizes() == const_base_sizes,
114 "Base shape mismatch: trying to create a tensor with base shape ",
116 " from a tensor with base shape ",
117 this->base_sizes());
118}
119
120template <class Derived, Size... S>
123 : TensorBase<Derived>(tensor, batch_shape)
124{
125 neml_assert_dbg(this->base_sizes() == const_base_sizes,
126 "Base shape mismatch: trying to create a tensor with base shape ",
128 " from a tensor with base shape ",
129 this->base_sizes());
130}
131
132template <class Derived, Size... S>
133template <class Derived2>
135 : TensorBase<Derived>(tensor)
136{
137 neml_assert_dbg(this->base_sizes() == const_base_sizes,
138 "Base shape mismatch: trying to create a tensor with base shape ",
140 " from a tensor with base shape ",
141 this->base_sizes());
142}
143
144template <class Derived, Size... S>
146 : TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
147{
148 neml_assert_dbg(this->base_sizes() == const_base_sizes,
149 "Base shape mismatch: trying to create a tensor with base shape ",
151 " from a tensor with shape ",
152 tensor.sizes());
153}
154
155template <class Derived, Size... S>
156PrimitiveTensor<Derived, S...>::operator Tensor() const
157{
158 return Tensor(*this, this->batch_sizes());
159}
160
161template <class Derived, Size... S>
162Derived
163PrimitiveTensor<Derived, S...>::empty(const torch::TensorOptions & options)
164{
165 return Tensor::empty(const_base_sizes, options);
166}
167
168template <class Derived, Size... S>
169Derived
171 const torch::TensorOptions & options)
172{
173 return Tensor::empty(batch_shape, const_base_sizes, options);
174}
175
176template <class Derived, Size... S>
177Derived
178PrimitiveTensor<Derived, S...>::zeros(const torch::TensorOptions & options)
179{
180 return Tensor::zeros(const_base_sizes, options);
181}
182
183template <class Derived, Size... S>
184Derived
186 const torch::TensorOptions & options)
187{
188 return Tensor::zeros(batch_shape, const_base_sizes, options);
189}
190
191template <class Derived, Size... S>
192Derived
193PrimitiveTensor<Derived, S...>::ones(const torch::TensorOptions & options)
194{
195 return Tensor::ones(const_base_sizes, options);
196}
197
198template <class Derived, Size... S>
199Derived
201 const torch::TensorOptions & options)
202{
203 return Tensor::ones(batch_shape, const_base_sizes, options);
204}
205
206template <class Derived, Size... S>
207Derived
208PrimitiveTensor<Derived, S...>::full(Real init, const torch::TensorOptions & options)
209{
210 return Tensor::full(const_base_sizes, init, options);
211}
212
213template <class Derived, Size... S>
214Derived
216 Real init,
217 const torch::TensorOptions & options)
218{
219 return Tensor::full(batch_shape, const_base_sizes, init, options);
220}
221} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Definition error.h:33
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor with the fixed base shape, filled with a given value.
Definition PrimitiveTensor.h:208
static Derived ones(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Tensor filled with ones, given batch shape.
Definition PrimitiveTensor.h:200
PrimitiveTensor(const TensorBase< Derived2 > &tensor)
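Converting constructor from a tensor of any TensorBase-derived type (a constructor template, hence not a copy constructor in the C++ sense).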
Definition PrimitiveTensor.h:134
PrimitiveTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition PrimitiveTensor.h:145
PrimitiveTensor(const torch::Tensor &tensor, Size batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition PrimitiveTensor.h:110
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:163
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:41
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:178
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:44
PrimitiveTensor(const torch::Tensor &tensor, const TraceableTensorShape &batch_shape)
Construct from another torch::Tensor given batch shape.
Definition PrimitiveTensor.h:121
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones.
Definition PrimitiveTensor.h:193
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:47
static Tensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:99
static Derived full(const TraceableTensorShape &batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Tensor filled with a given value, given batch shape.
Definition PrimitiveTensor.h:215
PrimitiveTensor()=default
Default constructor.
static Derived empty(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:170
static Derived zeros(const TraceableTensorShape &batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:185
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBaseImpl.h:158
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBaseImpl.h:192
Definition Tensor.h:47
static Tensor full(TensorShapeRef base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor of the given base shape, filled with a given value.
Definition Tensor.cxx:170
static Tensor ones(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition Tensor.cxx:149
static Tensor empty(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition Tensor.cxx:107
static Tensor zeros(TensorShapeRef base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition Tensor.cxx:128
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
Definition CrossRef.cxx:31
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
torch::TensorOptions & default_tensor_options()
Definition types.cxx:157
double Real
Definition types.h:31
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
Traceable tensor shape.
Definition types.h:81