27#include "neml2/misc/types.h"
28#include "neml2/misc/errors.h"
96template <
typename... S>
111template <
typename... S>
149 return std::max({tensor.batch_dim()...});
156 auto all_shapes = std::vector<TensorShapeRef>{std::forward<T>(shapes)...};
157 for (
size_t i = 0; i < all_shapes.size() - 1; i++)
158 if (all_shapes[i] != all_shapes[i + 1])
167 auto dim = std::max({shapes.size()...});
168 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
170 for (
size_t i = 0; i < dim; i++)
173 for (
const auto & s : all_shapes_padded)
179 throw NEMLException(
"Found a size equal or less than 0: " + std::to_string(s[i]));
184 else if (s[i] != 1 && s[i] != max_sz)
201 auto dim = std::max({shapes.size()...});
202 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
205 for (
size_t i = 0; i < dim; i++)
206 for (
const auto & s : all_shapes_padded)
207 if (s[i] > bshape[i])
213template <
typename... S>
218 return details::add_shapes_impl(net, shape...);
223template <
typename... S>
227 net.insert(net.end(), s.begin(), s.end());
229 if constexpr (
sizeof...(rest) == 0)
230 return std::move(net);
232 return add_shapes_impl(net, rest...);
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:32
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition shape_utils.h:154
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition shape_utils.cxx:39
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
TensorShape add_shapes(const S &...)
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:194
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:165
bool batch_broadcastable(const T &... tensors)
Definition shape_utils.h:133
bool broadcastable(const T &... tensors)
Definition shape_utils.h:124
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:140
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
int64_t Size
Definition types.h:65
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67