27#include "neml2/misc/types.h"
28#include "neml2/tensors/tensors_fwd.h"
// Helper macro: for a tensor type T, declares two sum-to-size functions.
//   - dynamic_sum_to_size(a, shape): target shape given as a TraceableTensorShape
//   - intmd_sum_to_size(a, shape):   target shape given as a TensorShapeRef
// Both take the input tensor `a` by const reference and return a T by value.
// NOTE(review): presumably these sum-reduce `a` down to `shape` over the dynamic and
// intermediate dimensions respectively -- confirm against the definitions.
34#define DECLARE_SUM_TO_SIZE(T) \
35 T dynamic_sum_to_size(const T & a, const TraceableTensorShape & shape); \
36 T intmd_sum_to_size(const T & a, TensorShapeRef shape);
// Instantiate the declarations through FOR_ALL_TENSORBASE (defined elsewhere;
// presumably it applies the given macro to each tensor type -- TODO confirm).
37FOR_ALL_TENSORBASE(DECLARE_SUM_TO_SIZE);
// The helper is only needed for the declarations above; undefine it so the macro
// name does not leak to code that includes this header.
38#undef DECLARE_SUM_TO_SIZE
Definition DiagnosticsInterface.h:31
Tensor base_sum_to_size(const Tensor &a, TensorShapeRef shape)
TensorShapeRef: type alias for c10::ArrayRef< Size >
Definition types.h:73
TraceableTensorShape: Traceable tensor shape.
Definition TraceableTensorShape.h:38