NEML2 2.0.0
Loading...
Searching...
No Matches
utils.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/misc/error.h"
29
30namespace neml2
31{
32
38template <class... T>
39bool broadcastable(const T &... tensors);
40
46template <class... T>
48
56template <class... T>
57void neml_assert_broadcastable(const T &...);
58
66template <class... T>
68
76template <class... T>
78
86template <class... T>
88
89namespace utils
90{
92std::string demangle(const char * name);
93
95template <class... T>
96bool sizes_same(T &&... shapes);
97
105template <class... T>
106bool sizes_broadcastable(const T &... shapes);
107
111template <class... T>
113
116TraceableTensorShape extract_batch_sizes(const torch::Tensor & tensor, Size batch_dim);
117
131
132template <typename... S>
134
135template <typename... S>
137
147torch::Tensor pad_prepend(const torch::Tensor & s, Size dim, Size pad = 1);
148
149std::string indentation(int level, int indent = 2);
150
151template <typename T>
152std::string stringify(const T & t);
153
154namespace details
155{
156template <typename... S>
157TensorShape add_shapes_impl(TensorShape &, TensorShapeRef, S &&...);
158TensorShape add_shapes_impl(TensorShape &);
159
160template <typename... S>
162add_traceable_shapes_impl(TraceableTensorShape &, const TraceableTensorShape &, S &&...);
163TraceableTensorShape add_traceable_shapes_impl(TraceableTensorShape &);
164} // namespace details
165} // namespace utils
166} // namespace neml2
167
169// Implementation
171
172namespace neml2
173{
174template <class... T>
175bool
177{
178 if (!utils::sizes_same(tensors.base_sizes()...))
179 return false;
180 return utils::sizes_broadcastable(tensors.batch_sizes().concrete()...);
181}
182
183template <class... T>
184Size
185broadcast_batch_dim(const T &... tensor)
186{
187 return std::max({tensor.batch_dim()...});
188}
189
190template <class... T>
191void
193{
195 "The ",
196 sizeof...(tensors),
197 " operands are not broadcastable. The batch shapes are ",
198 tensors.batch_sizes()...,
199 ", and the base shapes are ",
200 tensors.base_sizes()...);
201}
202
203template <class... T>
204void
206{
207#ifndef NDEBUG
209 "The ",
210 sizeof...(tensors),
211 " operands are not broadcastable. The batch shapes are ",
212 tensors.batch_sizes()...,
213 ", and the base shapes are ",
214 tensors.base_sizes()...);
215#endif
216}
217
218template <class... T>
219void
221{
222 neml_assert(utils::sizes_broadcastable(tensors.batch_sizes().concrete()...),
223 "The ",
224 sizeof...(tensors),
225 " operands are not batch-broadcastable. The batch shapes are ",
226 tensors.batch_sizes()...);
227}
228
229template <class... T>
230void
232{
233#ifndef NDEBUG
234 neml_assert_dbg(utils::sizes_broadcastable(tensors.batch_sizes().concrete()...),
235 "The ",
236 sizeof...(tensors),
237 " operands are not batch-broadcastable. The batch shapes are ",
238 tensors.batch_sizes()...);
239#endif
240}
241
242namespace utils
243{
244template <class... T>
245bool
247{
248 auto all_shapes = std::vector<TensorShapeRef>{shapes...};
249 for (size_t i = 0; i < all_shapes.size() - 1; i++)
250 if (all_shapes[i] != all_shapes[i + 1])
251 return false;
252 return true;
253}
254
255template <class... T>
256bool
258{
259 auto dim = std::max({shapes.size()...});
260 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
261
262 for (size_t i = 0; i < dim; i++)
263 {
264 Size max_sz = 1;
265 for (const auto & s : all_shapes_padded)
266 {
267 if (max_sz == 1)
268 {
269 neml_assert_dbg(s[i] > 0, "Found a size equal or less than 0.");
270 if (s[i] > 1)
271 max_sz = s[i];
272 }
273 else if (s[i] != 1 && s[i] != max_sz)
274 return false;
275 }
276 }
277
278 return true;
279}
280
281template <class... T>
284{
285 neml_assert_dbg(sizes_broadcastable(shapes...), "Shapes not broadcastable: ", shapes...);
286
287 auto dim = std::max({shapes.size()...});
288 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
289 auto bshape = TensorShape(dim, 1);
290
291 for (size_t i = 0; i < dim; i++)
292 for (const auto & s : all_shapes_padded)
293 if (s[i] > bshape[i])
294 bshape[i] = s[i];
295
296 return bshape;
297}
298
299template <typename... S>
302{
304 return details::add_shapes_impl(net, std::forward<S>(shape)...);
305}
306
307template <typename... S>
310{
312 return details::add_traceable_shapes_impl(net, std::forward<S>(shape)...);
313}
314
// Convert a value to a string via its stream insertion operator.
template <typename T>
std::string
stringify(const T & t)
{
  std::stringstream buffer;
  buffer << t;
  return buffer.str();
}

// Booleans are spelled out as "true"/"false" rather than streamed as 1/0.
template <>
inline std::string
stringify(const bool & t)
{
  if (t)
    return "true";
  return "false";
}
330
331namespace details
332{
333template <typename... S>
335add_shapes_impl(TensorShape & net, TensorShapeRef s, S &&... rest)
336{
337 net.insert(net.end(), s.begin(), s.end());
338 return add_shapes_impl(net, std::forward<S>(rest)...);
339}
340
341template <typename... S>
342TraceableTensorShape
343add_traceable_shapes_impl(TraceableTensorShape & net, const TraceableTensorShape & s, S &&... rest)
344{
345 net.insert(net.end(), s.begin(), s.end());
346 return add_traceable_shapes_impl(net, std::forward<S>(rest)...);
347}
348} // namespace details
349} // namespace utils
350} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition utils.h:246
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending entries equal to pad.
Definition utils.cxx:62
TensorShape add_shapes(S &&... shape)
Definition utils.h:301
TraceableTensorShape add_traceable_shapes(S &&... shape)
Definition utils.h:309
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition utils.h:283
TraceableTensorShape extract_batch_sizes(const torch::Tensor &tensor, Size batch_dim)
Extract the batch shape of a tensor given the batch dimension. The extracted batch shape will be traceable...
Definition utils.cxx:39
std::string stringify(const T &t)
Definition utils.h:317
std::string indentation(int level, int indent)
Definition utils.cxx:80
std::string demangle(const char *name)
Demangle a piece of cxx abi type information.
Definition utils.cxx:32
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition utils.h:257
Definition CrossRef.cxx:31
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
void neml_assert_batch_broadcastable(const T &...)
A helper function to assert that all tensors are batch-broadcastable.
void neml_assert_batch_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are batch-broadcastable.
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
void neml_assert_broadcastable(const T &...)
A helper function to assert that all tensors are broadcastable.
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
bool broadcastable(const T &... tensors)
Definition utils.h:176
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
Traceable tensor shape.
Definition types.h:81