#include <algorithm>
#include <array>
#include <cstddef>
#include <numeric>
#include <string>
#include <vector>

#include "neml2/misc/types.h"
#include "neml2/misc/errors.h"
/// @brief Helper function to sum up all elements of a fixed-size array.
///
/// @tparam T Element type; the sum starts from `T(0)` and combines elements with `operator+`.
/// @tparam N Number of elements in the array.
/// @param arr The array whose elements are summed.
/// @return The sum of all elements of @p arr (`T(0)` when `N == 0`).
template <typename T, std::size_t N>
static T sum_array(const std::array<T, N> & arr);
157template <
typename... S>
178template <
typename... S>
/// @brief Sum all elements of a fixed-size array.
///
/// @tparam T Element type; must be constructible from `0` and support `operator+`.
/// @tparam N Number of elements in the array.
/// @param arr The array whose elements are summed.
/// @return The sum of all elements of @p arr, or `T(0)` for an empty array.
template <typename T, std::size_t N>
T
sum_array(const std::array<T, N> & arr)
{
  // std::accumulate combines with operator+ by default, so no custom lambda is needed.
  return std::accumulate(arr.begin(), arr.end(), T(0));
}
200 auto dim = std::max({shapes.size()...});
201 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
203 for (std::size_t i = 0; i < dim; i++)
206 for (
const auto & s : all_shapes_padded)
212 throw NEMLException(
"Found a size equal or less than 0: " + std::to_string(s[i]));
217 else if (s[i] != 1 && s[i] != max_sz)
258 return std::max({tensor.dynamic_dim()...});
265 return std::max({tensor.intmd_dim()...});
272 return std::max({tensor.base_dim()...});
284 auto dim = std::max({shapes.size()...});
285 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
288 for (std::size_t i = 0; i < dim; i++)
289 for (
const auto & s : all_shapes_padded)
290 if (s[i] > bshape[i])
296template <
typename... S>
301 return details::add_shapes_impl(net, shape...);
306template <
typename... S>
310 net.insert(net.end(), s.begin(), s.end());
312 if constexpr (
sizeof...(rest) == 0)
313 return std::move(net);
315 return add_shapes_impl(net, rest...);
std::vector< TensorShape > shape_refs_to_shapes(const std::vector< TensorShapeRef > &)
std::vector< TensorShapeRef > shapes_to_shape_refs(const std::vector< TensorShape > &)
bool intmd_broadcastable(const T &... tensors)
Definition shape_utils.h:242
TensorShape add_shapes(const S &...)
bool dynamic_broadcastable(const T &... tensors)
Definition shape_utils.h:235
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:277
Size normalize_itr(Size d, Size dl, Size du)
Helper function to normalize an iterator-like index to be non-negative given the lower- and upper-boun...
Size normalize_dim(Size d, Size dl, Size du)
Helper function to normalize a dimension index to be non-negative given the lower- and upper-bound of...
TensorShape pad_prepend(TensorShapeRef s, std::size_t dim, Size pad=1)
Pad shape s to dimension dim by prepending sizes of pad.
TensorShape normalize_dims(ArrayRef< Size > d, Size dl, Size du)
Helper function to normalize multiple dimension indices to be non-negative given the lower- and upper...
Size broadcast_dynamic_dim(const T &...)
The dynamic dimension after broadcasting.
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:198
TensorShape normalize_itrs(ArrayRef< Size > d, Size dl, Size du)
Helper function to normalize multiple iterator-like indices to be non-negative given the lower- and u...
Size numel(TensorShapeRef shape)
Number of elements in a tensor with given shape.
Size broadcast_base_dim(const T &...)
The base dimension after broadcasting.
Size broadcast_intmd_dim(const T &...)
The intermediate dimension after broadcasting.
bool broadcastable(const T &... tensors)
Definition shape_utils.h:227
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:249
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
c10::ArrayRef< T > ArrayRef
Definition types.h:63
int64_t Size
Definition types.h:71
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73