NEML2 2.0.0
Loading...
Searching...
No Matches
shape_utils.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/misc/errors.h"
29
30namespace neml2::utils
31{
42
53
/// Check if the shapes are broadcastable.
template <class... T>
bool sizes_broadcastable(const T &... shapes);
63
/// Check if all tensors are broadcastable in their dynamic, intermediate, and base sizes.
template <class... T>
bool broadcastable(const T &... tensors);
72
/// Check if the dynamic sizes of all tensors are broadcastable.
template <class... T>
bool dynamic_broadcastable(const T &... tensors);
79
/// Check if the intermediate sizes of all tensors are broadcastable.
template <class... T>
bool intmd_broadcastable(const T &... tensors);
86
/// Check if the base sizes of all tensors are broadcastable.
template <class... T>
bool base_broadcastable(const T &... tensors);
93
95template <class... T>
97
99template <class... T>
101
103template <class... T>
105
109template <class... T>
110TensorShape broadcast_sizes(const T &... shapes);
111
125
126template <typename... S>
128
138
139namespace details
140{
141template <typename... S>
142TensorShape add_shapes_impl(TensorShape &, TensorShapeRef, const S &...);
143} // namespace details
144} // namespace neml2::utils
145
147// Implementation
149
150namespace neml2::utils
151{
152template <class... T>
153bool
154sizes_broadcastable(const T &... shapes)
155{
156 auto dim = std::max({shapes.size()...});
157 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
158
159 for (size_t i = 0; i < dim; i++)
160 {
161 Size max_sz = 1;
162 for (const auto & s : all_shapes_padded)
163 {
164 if (max_sz == 1)
165 {
166#ifndef NDEBUG
167 if (s[i] <= 0)
168 throw NEMLException("Found a size equal or less than 0: " + std::to_string(s[i]));
169#endif
170 if (s[i] > 1)
171 max_sz = s[i];
172 }
173 else if (s[i] != 1 && s[i] != max_sz)
174 return false;
175 }
176 }
177
178 return true;
179}
180
/// Check if all tensors are broadcastable.
///
/// The dynamic, intermediate, and base sizes are checked in that order; the first failing
/// check short-circuits the remaining ones.
template <class... T>
bool
broadcastable(const T &... tensors)
{
  if (!dynamic_broadcastable(tensors...))
    return false;

  if (!intmd_broadcastable(tensors...))
    return false;

  return base_broadcastable(tensors...);
}
188
/// Check if the dynamic sizes of all the tensors are mutually broadcastable.
template <class... T>
bool
dynamic_broadcastable(const T &... tensors)
{
  // concrete() is applied to each tensor's dynamic sizes before the shapes are compared
  const bool compatible = sizes_broadcastable(tensors.dynamic_sizes().concrete()...);
  return compatible;
}
195
/// Check if the intermediate sizes of all the tensors are mutually broadcastable.
template <class... T>
bool
intmd_broadcastable(const T &... tensors)
{
  const bool compatible = sizes_broadcastable(tensors.intmd_sizes()...);
  return compatible;
}
202
/// Check if the base sizes of all the tensors are mutually broadcastable.
template <class... T>
bool
base_broadcastable(const T &... tensors)
{
  const bool compatible = sizes_broadcastable(tensors.base_sizes()...);
  return compatible;
}
209
210template <class... T>
211Size
212broadcast_dynamic_dim(const T &... tensor)
213{
214 return std::max({tensor.dynamic_dim()...});
215}
216
217template <class... T>
218Size
219broadcast_intmd_dim(const T &... tensor)
220{
221 return std::max({tensor.intmd_dim()...});
222}
223
224template <class... T>
225Size
226broadcast_base_dim(const T &... tensor)
227{
228 return std::max({tensor.base_dim()...});
229}
230
231template <class... T>
233broadcast_sizes(const T &... shapes)
234{
235#ifndef NDEBUG
236 if (!sizes_broadcastable(shapes...))
237 throw NEMLException("Shapes not broadcastable");
238#endif
239
240 auto dim = std::max({shapes.size()...});
241 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
242 auto bshape = TensorShape(dim, 1);
243
244 for (size_t i = 0; i < dim; i++)
245 for (const auto & s : all_shapes_padded)
246 if (s[i] > bshape[i])
247 bshape[i] = s[i];
248
249 return bshape;
250}
251
252template <typename... S>
254add_shapes(const S &... shape)
255{
256 TensorShape net;
257 return details::add_shapes_impl(net, shape...);
258}
259
260namespace details
261{
262template <typename... S>
264add_shapes_impl(TensorShape & net, TensorShapeRef s, const S &... rest)
265{
266 net.insert(net.end(), s.begin(), s.end());
267
268 if constexpr (sizeof...(rest) == 0)
269 return std::move(net);
270 else
271 return add_shapes_impl(net, rest...);
272}
273} // namespace details
274} // namespace neml2::utils
Definition errors.h:34
Definition Parser.cxx:36
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition shape_utils.cxx:71
bool intmd_broadcastable(const T &... tensors)
Definition shape_utils.h:198
TensorShape add_shapes(const S &...)
bool dynamic_broadcastable(const T &... tensors)
Definition shape_utils.h:191
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:233
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize a iterator-like index to be non-negative given the lower- and upper-boun...
Definition shape_utils.cxx:49
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Definition shape_utils.cxx:34
Size broadcast_dynamic_dim(const T &...)
The dynamic dimension after broadcasting.
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:154
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Definition shape_utils.cxx:64
Size broadcast_base_dim(const T &...)
The base dimension after broadcasting.
Size broadcast_intmd_dim(const T &...)
The intermediate dimension after broadcasting.
bool broadcastable(const T &... tensors)
Definition shape_utils.h:183
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:205
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
int64_t Size
Definition types.h:65
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67