27#include <torch/csrc/jit/frontend/tracer.h>
28#include <torch/csrc/jit/api/function_impl.h>
29#include <torch/csrc/jit/serialization/import.h>
30#include <torch/csrc/jit/api/function_impl.h>
32#include "neml2/misc/types.h"
33#include "neml2/tensors/TraceableTensorShape.h"
37using namespace torch::jit;
63template <
typename... S>
71template <
typename... S>
73add_traceable_shapes_impl(TraceableTensorShape &,
const TraceableTensorShape &,
const S &...);
84template <
typename... S>
89 return details::add_traceable_shapes_impl(net, shape...);
94template <
typename... S>
100 net.insert(net.end(), s.begin(), s.end());
102 if constexpr (
sizeof...(rest) == 0)
103 return std::move(net);
105 return add_traceable_shapes_impl(net, rest...);
Definition BufferStore.h:43
TraceableTensorShape extract_traceable_sizes(const ATensor &tensor, std::size_t n, std::size_t m)
Definition jit.cxx:68
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition jit.cxx:61
std::shared_ptr< jit::Graph > last_executed_optimized_graph()
Print last evaluated optimized graph.
Definition jit.cxx:83
Definition DiagnosticsInterface.cxx:30
void neml_assert_not_tracing_dbg()
Assert that we are currently NOT tracing (only effective in debug mode)
Definition jit.cxx:51
void neml_assert_tracing()
Assert that we are currently tracing.
Definition jit.cxx:31
void neml_assert_tracing_dbg()
Assert that we are currently tracing (only effective in debug mode)
Definition jit.cxx:43
void neml_assert_not_tracing()
Assert that we are currently NOT tracing.
Definition jit.cxx:37
Traceable tensor shape.
Definition TraceableTensorShape.h:38