27#include "neml2/misc/types.h"
28#include "neml2/misc/error.h"
132template <
typename...
S>
135template <
typename...
S>
156template <
typename...
S>
160template <
typename...
S>
187 return std::max({tensor.batch_dim()...});
197 " operands are not broadcastable. The batch shapes are ",
199 ", and the base shapes are ",
211 " operands are not broadcastable. The batch shapes are ",
213 ", and the base shapes are ",
225 " operands are not batch-broadcastable. The batch shapes are ",
237 " operands are not batch-broadcastable. The batch shapes are ",
259 auto dim = std::max({
shapes.size()...});
262 for (
size_t i = 0;
i < dim;
i++)
287 auto dim = std::max({
shapes.size()...});
291 for (
size_t i = 0;
i < dim;
i++)
299template <
typename...
S>
304 return details::add_shapes_impl(
net, std::forward<S>(
shape)...);
307template <
typename...
S>
312 return details::add_traceable_shapes_impl(
net, std::forward<S>(
shape)...);
319 std::ostringstream
os;
328 return t ?
"true" :
"false";
333template <
typename...
S>
337 net.insert(
net.end(),
s.begin(),
s.end());
338 return add_shapes_impl(
net, std::forward<S>(
rest)...);
341template <
typename... S>
343add_traceable_shapes_impl(TraceableTensorShape & net,
const TraceableTensorShape & s, S &&... rest)
345 net.insert(net.end(), s.begin(), s.end());
346 return add_traceable_shapes_impl(net, std::forward<S>(rest)...);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition utils.h:246
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition utils.cxx:62
TensorShape add_shapes(S &&... shape)
Definition utils.h:301
TraceableTensorShape add_traceable_shapes(S &&... shape)
Definition utils.h:309
TensorShape broadcast_sizes(const T &... shapes)
Return the broadcast shape of all the shapes.
Definition utils.h:283
TraceableTensorShape extract_batch_sizes(const torch::Tensor &tensor, Size batch_dim)
Extract the batch shape of a tensor given batch dimension The extracted batch shape will be traceable...
Definition utils.cxx:39
std::string stringify(const T &t)
Definition utils.h:317
std::string indentation(int level, int indent)
Definition utils.cxx:80
std::string demangle(const char *name)
Demangle a piece of cxx abi type information.
Definition utils.cxx:32
bool sizes_broadcastable(const T &... shapes)
Check if the shapes are broadcastable.
Definition utils.h:257
Definition CrossRef.cxx:31
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
void neml_assert_batch_broadcastable(const T &...)
A helper function to assert that all tensors are batch-broadcastable.
void neml_assert_batch_broadcastable_dbg(const T &...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
Size broadcast_batch_dim(const T &...)
The batch dimension after broadcasting.
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
void neml_assert_broadcastable(const T &...)
A helper function to assert that all tensors are broadcastable.
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(const T &...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
bool broadcastable(const T &... tensors)
Definition utils.h:176
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
Traceable tensor shape.
Definition types.h:81