27#include "neml2/misc/types.h"
28#include "neml2/misc/errors.h"
29#include "neml2/jit/types.h"
51template <
typename... S>
59template <
typename... S>
61add_traceable_shapes_impl(TraceableTensorShape &,
const TraceableTensorShape &,
const S &...);
72template <
typename... S>
77 return details::add_traceable_shapes_impl(net, shape...);
82template <
typename... S>
88 net.insert(net.end(), s.begin(), s.end());
90 if constexpr (
sizeof...(rest) == 0)
91 return std::move(net);
93 return add_traceable_shapes_impl(net, rest...);
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition utils.h:74
std::shared_ptr< jit::Graph > last_executed_optimized_graph()
Print last evaluated optimized graph.
Definition utils.cxx:78
TraceableTensorShape extract_batch_sizes(const ATensor &tensor, Size batch_dim)
Extract the batch shape of a tensor given batch dimension The extracted batch shape will be traceable...
Definition utils.cxx:62
Definition DiagnosticsInterface.cxx:30
void neml_assert_not_tracing_dbg()
Assert that we are currently NOT tracing (only effective in debug mode)
Definition utils.cxx:52
at::Tensor ATensor
Definition types.h:42
void neml_assert_tracing()
Assert that we are currently tracing.
Definition utils.cxx:32
void neml_assert_tracing_dbg()
Assert that we are currently tracing (only effective in debug mode)
Definition utils.cxx:44
int64_t Size
Definition types.h:69
void neml_assert_not_tracing()
Assert that we are currently NOT tracing.
Definition utils.cxx:38
Traceable tensor shape.
Definition TraceableTensorShape.h:38