NEML2 2.0.0
Loading...
Searching...
No Matches
shape_utils.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/misc/errors.h"
29
30namespace neml2::utils
31{
37template <class... T>
38bool broadcastable(const T &... tensors);
39
44template <class... T>
45bool batch_broadcastable(const T &... tensors);
46
51template <class... T>
52bool base_broadcastable(const T &... tensors);
53
59template <class... T>
61
63template <class... T>
64bool sizes_same(T &&... shapes);
65
73template <class... T>
74bool sizes_broadcastable(const T &... shapes);
75
79template <class... T>
80TensorShape broadcast_sizes(const T &... shapes);
81
95
96template <typename... S>
97TensorShape add_shapes(const S &...);
98
108
109namespace details
110{
111template <typename... S>
112TensorShape add_shapes_impl(TensorShape &, TensorShapeRef, const S &...);
113} // namespace details
114} // namespace neml2::utils
115
// Implementation

120namespace neml2::utils
121{
122template <class... T>
123bool
124broadcastable(const T &... tensors)
125{
126 if (!sizes_same(tensors.base_sizes()...))
127 return false;
128 return batch_broadcastable(tensors...);
129}
130
131template <class... T>
132bool
133batch_broadcastable(const T &... tensors)
134{
135 return sizes_broadcastable(tensors.batch_sizes().concrete()...);
136}
137
138template <class... T>
139bool
140base_broadcastable(const T &... tensors)
141{
142 return sizes_broadcastable(tensors.base_sizes()...);
143}
144
145template <class... T>
146Size
147broadcast_batch_dim(const T &... tensor)
148{
149 return std::max({tensor.batch_dim()...});
150}
151
152template <class... T>
153bool
154sizes_same(T &&... shapes)
155{
156 auto all_shapes = std::vector<TensorShapeRef>{std::forward<T>(shapes)...};
157 for (size_t i = 0; i < all_shapes.size() - 1; i++)
158 if (all_shapes[i] != all_shapes[i + 1])
159 return false;
160 return true;
161}
162
163template <class... T>
164bool
165sizes_broadcastable(const T &... shapes)
166{
167 auto dim = std::max({shapes.size()...});
168 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
169
170 for (size_t i = 0; i < dim; i++)
171 {
172 Size max_sz = 1;
173 for (const auto & s : all_shapes_padded)
174 {
175 if (max_sz == 1)
176 {
177#ifndef NDEBUG
178 if (s[i] <= 0)
179 throw NEMLException("Found a size equal or less than 0: " + std::to_string(s[i]));
180#endif
181 if (s[i] > 1)
182 max_sz = s[i];
183 }
184 else if (s[i] != 1 && s[i] != max_sz)
185 return false;
186 }
187 }
188
189 return true;
190}
191
192template <class... T>
194broadcast_sizes(const T &... shapes)
195{
196#ifndef NDEBUG
197 if (!sizes_broadcastable(shapes...))
198 throw NEMLException("Shapes not broadcastable");
199#endif
200
201 auto dim = std::max({shapes.size()...});
202 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
203 auto bshape = TensorShape(dim, 1);
204
205 for (size_t i = 0; i < dim; i++)
206 for (const auto & s : all_shapes_padded)
207 if (s[i] > bshape[i])
208 bshape[i] = s[i];
209
210 return bshape;
211}
212
213template <typename... S>
215add_shapes(const S &... shape)
216{
217 TensorShape net;
218 return details::add_shapes_impl(net, shape...);
219}
220
221namespace details
222{
223template <typename... S>
225add_shapes_impl(TensorShape & net, TensorShapeRef s, const S &... rest)
226{
227 net.insert(net.end(), s.begin(), s.end());
228
229 if constexpr (sizeof...(rest) == 0)
230 return std::move(net);
231 else
232 return add_shapes_impl(net, rest...);
233}
234} // namespace details
235} // namespace neml2::utils
Definition errors.h:34
Definition Parser.cxx:36
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:32
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition shape_utils.h:154
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition shape_utils.cxx:39
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
TensorShape add_shapes(const S &...)
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:194
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:165
bool batch_broadcastable(const T &... tensors)
Definition shape_utils.h:133
bool broadcastable(const T &... tensors)
Definition shape_utils.h:124
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:140
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
int64_t Size
Definition types.h:65
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67