28#include "neml2/tensors/indexing.h"
29#include "neml2/misc/errors.h"
97template <
typename... S>
112template <
typename... S>
150 return std::max({tensor.batch_dim()...});
157 auto all_shapes = std::vector<TensorShapeRef>{std::forward<T>(shapes)...};
158 for (
size_t i = 0; i < all_shapes.size() - 1; i++)
159 if (all_shapes[i] != all_shapes[i + 1])
168 auto dim = std::max({shapes.size()...});
169 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
171 for (
size_t i = 0; i < dim; i++)
174 for (
const auto & s : all_shapes_padded)
180 throw NEMLException(
"Found a size equal or less than 0: " + std::to_string(s[i]));
185 else if (s[i] != 1 && s[i] != max_sz)
202 auto dim = std::max({shapes.size()...});
203 auto all_shapes_padded = std::vector<TensorShape>{
pad_prepend(shapes, dim)...};
206 for (
size_t i = 0; i < dim; i++)
207 for (
const auto & s : all_shapes_padded)
208 if (s[i] > bshape[i])
214template <
typename... S>
219 return details::add_shapes_impl(net, shape...);
224template <
typename... S>
228 net.insert(net.end(), s.begin(), s.end());
230 if constexpr (
sizeof...(rest) == 0)
231 return std::move(net);
233 return add_shapes_impl(net, rest...);
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition shape_utils.cxx:30
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition shape_utils.h:155
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition shape_utils.cxx:37
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
TensorShape add_shapes(const S &...)
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition shape_utils.h:195
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition shape_utils.h:166
bool batch_broadcastable(const T &... tensors)
Definition shape_utils.h:134
bool broadcastable(const T &... tensors)
Definition shape_utils.h:125
bool base_broadcastable(const T &... tensors)
Definition shape_utils.h:141
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:71
int64_t Size
Definition types.h:69
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72