27#include "neml2/misc/types.h"
28#include "neml2/jit/TraceableTensorShape.h"
29#include <torch/csrc/jit/api/function_impl.h>
33using namespace torch::jit;
56template <
typename... S>
64template <
typename... S>
66add_traceable_shapes_impl(TraceableTensorShape &,
const TraceableTensorShape &,
const S &...);
77template <
typename... S>
82 return details::add_traceable_shapes_impl(net, shape...);
87template <
typename... S>
93 net.insert(net.end(), s.begin(), s.end());
95 if constexpr (
sizeof...(rest) == 0)
96 return std::move(net);
98 return add_traceable_shapes_impl(net, rest...);
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition utils.h:79
std::shared_ptr< jit::Graph > last_executed_optimized_graph()
Print last evaluated optimized graph.
Definition utils.cxx:78
TraceableTensorShape extract_batch_sizes(const ATensor &tensor, Size batch_dim)
Extract the batch shape of a tensor given batch dimension The extracted batch shape will be traceable...
Definition utils.cxx:62
Definition DiagnosticsInterface.cxx:30
void neml_assert_not_tracing_dbg()
Assert that we are currently NOT tracing (only effective in debug mode)
Definition utils.cxx:52
void neml_assert_tracing()
Assert that we are currently tracing.
Definition utils.cxx:32
void neml_assert_tracing_dbg()
Assert that we are currently tracing (only effective in debug mode)
Definition utils.cxx:44
int64_t Size
Definition types.h:65
void neml_assert_not_tracing()
Assert that we are currently NOT tracing.
Definition utils.cxx:38
Traceable tensor shape.
Definition TraceableTensorShape.h:38