28#include "neml2/tensors/assertions.h"
29#include "neml2/tensors/shape_utils.h"
30#include "neml2/tensors/Tensor.h"
34template <
typename T,
typename... Ts>
35std::tuple<T, Ts...,
Size>
40 if ((... && (a.intmd_dim() == ts.intmd_dim())))
41 return {a, ts..., a.intmd_dim()};
44 return {a.intmd_unsqueeze(0, dmax - a.intmd_dim()),
45 ts.intmd_unsqueeze(0, dmax - ts.intmd_dim())...,
51template <
typename T,
typename... Ts>
52std::tuple<T, Ts...,
Size>
57 if ((... && (a.intmd_dim() == ts.intmd_dim() && a.base_dim() == ts.base_dim())))
58 return {a, ts..., a.intmd_dim()};
62 return {a.intmd_unsqueeze(0, imax - a.intmd_dim()).base_unsqueeze(0, bmax - a.base_dim()),
63 ts.intmd_unsqueeze(0, imax - ts.intmd_dim()).base_unsqueeze(0, bmax - ts.base_dim())...,
std::pair< std::vector< Tensor >, Size > align_static_dim(TensorList tensors)
Definition utils.cxx:44
Size broadcast_base_dim(const T &...)
The base dimension after broadcasting.
std::pair< std::vector< Tensor >, Size > align_intmd_dim(TensorList tensors)
Definition utils.cxx:30
Size broadcast_intmd_dim(const T &...)
The intermediate dimension after broadcasting.
c10::ArrayRef< neml2::Tensor > TensorList
Definition Tensor.h:37
void neml_assert_intmd_broadcastable_dbg(const T &...)
int64_t Size
Definition types.h:65
void neml_assert_static_broadcastable_dbg(const T &...)