NEML2 2.0.0
Loading...
Searching...
No Matches
jit.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <torch/csrc/jit/frontend/tracer.h>
28#include <torch/csrc/jit/api/function_impl.h>
29#include <torch/csrc/jit/serialization/import.h>
30#include <torch/csrc/jit/api/function_impl.h>
31
32#include "neml2/misc/types.h"
33#include "neml2/tensors/TraceableTensorShape.h"
34
// Re-export the TorchScript JIT API inside neml2::jit, so that code within namespace neml2 can
// refer to its types with the short qualifier `jit::` (e.g. `jit::Graph`).
35namespace neml2::jit
36{
37using namespace torch::jit;
38}
39
40namespace neml2
41{
44
47
50
53
54namespace utils
55{
57TraceableSize traceable_numel(const TraceableTensorShape & shape);
58
61TraceableTensorShape extract_traceable_sizes(const ATensor & tensor, std::size_t n, std::size_t m);
62
63template <typename... S>
64TraceableTensorShape add_traceable_shapes(const S &... shape);
65
67std::shared_ptr<jit::Graph> last_executed_optimized_graph();
68
69namespace details
70{
71template <typename... S>
72TraceableTensorShape
73add_traceable_shapes_impl(TraceableTensorShape &, const TraceableTensorShape &, const S &...);
74} // namespace details
75} // namespace utils
76} // namespace neml2
77
79// Implementation
81
82namespace neml2::utils
83{
84template <typename... S>
85TraceableTensorShape
86add_traceable_shapes(const S &... shape)
87{
89 return details::add_traceable_shapes_impl(net, shape...);
90}
91
92namespace details
93{
94template <typename... S>
96add_traceable_shapes_impl(TraceableTensorShape & net,
97 const TraceableTensorShape & s,
98 const S &... rest)
99{
100 net.insert(net.end(), s.begin(), s.end());
101
102 if constexpr (sizeof...(rest) == 0)
103 return std::move(net);
104 else
105 return add_traceable_shapes_impl(net, rest...);
106}
107} // namespace details
108} // namespace neml2::utils
Definition BufferStore.h:43
Definition Parser.cxx:36
TraceableTensorShape extract_traceable_sizes(const ATensor &tensor, std::size_t n, std::size_t m)
Definition jit.cxx:68
TraceableTensorShape add_traceable_shapes(const S &... shape)
Definition jit.h:86
TraceableSize traceable_numel(const TraceableTensorShape &shape)
Get the number of elements in a tensor shape.
Definition jit.cxx:61
std::shared_ptr< jit::Graph > last_executed_optimized_graph()
Print last evaluated optimized graph.
Definition jit.cxx:83
Definition DiagnosticsInterface.cxx:30
void neml_assert_not_tracing_dbg()
Assert that we are currently NOT tracing (only effective in debug mode)
Definition jit.cxx:51
void neml_assert_tracing()
Assert that we are currently tracing.
Definition jit.cxx:31
void neml_assert_tracing_dbg()
Assert that we are currently tracing (only effective in debug mode)
Definition jit.cxx:43
void neml_assert_not_tracing()
Assert that we are currently NOT tracing.
Definition jit.cxx:37
Traceable tensor shape.
Definition TraceableTensorShape.h:38