27#include "neml2/misc/types.h"
28#include "neml2/misc/errors.h"
126template <
typename... S>
141template <
typename... S>
156 auto dim = std::max({shapes.size()...});
157 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
159 for (
size_t i = 0; i < dim; i++)
162 for (
const auto & s : all_shapes_padded)
168 throw NEMLException(
"Found a size equal or less than 0: " + std::to_string(s[i]));
173 else if (s[i] != 1 && s[i] != max_sz)
214 return std::max({tensor.dynamic_dim()...});
221 return std::max({tensor.intmd_dim()...});
228 return std::max({tensor.base_dim()...});
240 auto dim = std::max({shapes.size()...});
241 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
244 for (
size_t i = 0; i < dim; i++)
245 for (
const auto & s : all_shapes_padded)
246 if (s[i] > bshape[i])
252template <
typename... S>
257 return details::add_shapes_impl(net, shape...);
262template <
typename... S>
266 net.insert(net.end(), s.begin(), s.end());
268 if constexpr (
sizeof...(rest) == 0)
269 return std::move(net);
271 return add_shapes_impl(net, rest...);
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition shape_utils.cxx:71
bool intmd_broadcastable(const T &... tensors)
Definition shape_utils.h:198
TensorShape add_shapes(const S &...)
bool dynamic_broadcastable(const T &... tensors)
Definition shape_utils.h:191
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:233
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize an iterator-like index to be non-negative given the lower- and upper-bound...
Definition shape_utils.cxx:49
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
Definition shape_utils.cxx:34
Size broadcast_dynamic_dim(const T &...)
The dynamic dimension after broadcasting.
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:154
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Definition shape_utils.cxx:64
Size broadcast_base_dim(const T &...)
The base dimension after broadcasting.
Size broadcast_intmd_dim(const T &...)
The intermediate dimension after broadcasting.
bool broadcastable(const T &... tensors)
Definition shape_utils.h:183
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:205
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
int64_t Size
Definition types.h:65
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67