Problem description
In NEML2, all tensor operations are traceable. The trace of an operation records the operator type, its arguments and outputs, together with additional context. Multiple operations performed in sequence can be traced together into a graph representing the flow of data through the operators. This graph representation serves two primary purposes:
- Just-in-time (JIT) compilation and optimization of the operations
- Backward automatic-differentiation (AD)
This tutorial illustrates the utility of JIT compilation of NEML2 models, and a later tutorial demonstrates the use of backward AD to calculate parameter derivatives.
In this tutorial, let us consider the following problem
\begin{align} \dot{a} & = \dfrac{a - a_n}{t - t_n}, \label{1} \\
\dot{b} & = \dfrac{b - b_n}{t - t_n}, \label{2} \\
\dot{c} & = \dfrac{c - c_n}{t - t_n}, \label{3}
\end{align}
where the subscript \( n \) represents the variable value from the previous time step.
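In other words, each rate is a backward-difference approximation of a time derivative. A minimal sketch in plain Python (the function name and values are made up for illustration):

```python
def rate(x, x_n, t, t_n):
    """Backward-difference rate, matching the equations above:
    x_dot = (x - x_n) / (t - t_n)."""
    return (x - x_n) / (t - t_n)

# Example: a steps from a_n = 0.0 to a = 1.0 over t_n = 0.0 to t = 0.1
print(rate(1.0, 0.0, 0.1, 0.0))  # 10.0
```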
Model structure
Each of the three equations translates to a ScalarVariableRate model. The input file looks like
[Models]
  [eq1]
    type = ScalarVariableRate
    variable = 'a'
  []
  [eq2]
    type = ScalarVariableRate
    variable = 'b'
  []
  [eq3]
    type = ScalarVariableRate
    variable = 'c'
  []
  [eq]
    type = ComposedModel
    models = 'eq1 eq2 eq3'
  []
[]
And the composed model correctly defines \( a \), \( a_n \), \( b \), \( b_n \), \( c \), \( c_n \), \( t \), \( t_n \) as input variables and \( \dot{a} \), \( \dot{b} \), \( \dot{c} \) as output variables.
C++
#include "neml2/neml2.h"

int
main()
{
  using namespace neml2;

  // Load the model named "eq" from the input file (path assumed here)
  auto model = load_model("input.i", "eq");

  std::cout << *model << std::endl;
}
Output:
Name: eq
Input: a [Scalar]
a~1 [Scalar]
b [Scalar]
b~1 [Scalar]
c [Scalar]
c~1 [Scalar]
t [Scalar]
t~1 [Scalar]
Output: a_rate [Scalar]
b_rate [Scalar]
c_rate [Scalar]
Python
import neml2

# Load the model named "eq" from the input file (path assumed here)
model = neml2.load_model("input.i", "eq")
print(model)
Output:
Name: eq
Input: a [Scalar]
a~1 [Scalar]
b [Scalar]
b~1 [Scalar]
c [Scalar]
c~1 [Scalar]
t [Scalar]
t~1 [Scalar]
Output: a_rate [Scalar]
b_rate [Scalar]
c_rate [Scalar]
Tracing
NEML2 traces tensor operations lazily: no tracing is performed when the model is first loaded from the input file; instead, the trace is recorded the first time the model is evaluated. The following code prints the traced graph in text form.
C++
#include "neml2/neml2.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/jit.h"

int
main()
{
  using namespace neml2;
  set_default_dtype(kFloat64);

  // Load the model named "eq" from the input file (path assumed here)
  auto model = load_model("input.i", "eq");

  auto a = Scalar::full(1.0);
  auto b = Scalar::full(2.0);
  auto c = Scalar::full(3.0);
  auto t = Scalar::full(0.1);
  auto a_n = Scalar::full(0.0);
  auto b_n = Scalar::full(1.0);
  auto c_n = Scalar::full(2.0);
  auto t_n = Scalar::full(0.0);

  model->value({{"a", a},
                {"b", b},
                {"c", c},
                {"t", t},
                {"a~1", a_n},
                {"b~1", b_n},
                {"c~1", c_n},
                {"t~1", t_n}});

  std::cout << *jit::last_executed_optimized_graph() << std::endl;
}
Output:
graph(%eq::a : Tensor,
%eq::a~1 : Tensor,
%eq::b : Tensor,
%eq::b~1 : Tensor,
%eq::c : Tensor,
%eq::c~1 : Tensor,
%eq::t : Tensor,
%eq::t~1 : Tensor):
%8 : int = prim::Constant[value=1]()
%17 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t)
%18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t~1)
%9 : Tensor = aten::sub(%17, %18, %8)
%19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a)
%20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a~1)
%10 : Tensor = aten::sub(%19, %20, %8)
%21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
%22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%11 : Tensor = aten::div(%21, %22)
%23 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::b)
%24 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::b~1)
%12 : Tensor = aten::sub(%23, %24, %8)
%25 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%12)
%26 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%13 : Tensor = aten::div(%25, %26)
%27 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::c)
%28 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::c~1)
%14 : Tensor = aten::sub(%27, %28, %8)
%29 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%14)
%30 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%15 : Tensor = aten::div(%29, %30)
%31 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
%32 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%13)
%33 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%15)
%16 : Tensor[] = prim::ListConstruct(%31, %32, %33)
= prim::profile()
return (%16)
Python
import neml2
from neml2.tensors import Scalar
import torch
torch.set_default_dtype(torch.double)
a = Scalar.full(1.0)
b = Scalar.full(2.0)
c = Scalar.full(3.0)
t = Scalar.full(0.1)
a_n = Scalar.full(0.0)
b_n = Scalar.full(1.0)
c_n = Scalar.full(2.0)
t_n = Scalar.full(0.0)
model.value(
{
"a": a,
"b": b,
"c": c,
"t": t,
"a~1": a_n,
"b~1": b_n,
"c~1": c_n,
"t~1": t_n,
}
)
print(torch.jit.last_executed_optimized_graph())
Output:
graph(%eq::a : Tensor,
%eq::a~1 : Tensor,
%eq::b : Tensor,
%eq::b~1 : Tensor,
%eq::c : Tensor,
%eq::c~1 : Tensor,
%eq::t : Tensor,
%eq::t~1 : Tensor):
%8 : int = prim::Constant[value=1]() # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%17 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t)
%18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t~1)
%9 : Tensor = aten::sub(%17, %18, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a)
%20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a~1)
%10 : Tensor = aten::sub(%19, %20, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
%22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%11 : Tensor = aten::div(%21, %22) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%23 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::b)
%24 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::b~1)
%12 : Tensor = aten::sub(%23, %24, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%25 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%12)
%26 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%13 : Tensor = aten::div(%25, %26) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%27 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::c)
%28 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::c~1)
%14 : Tensor = aten::sub(%27, %28, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%29 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%14)
%30 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%15 : Tensor = aten::div(%29, %30) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex4.py:20:0
%31 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
%32 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%13)
%33 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%15)
%16 : Tensor[] = prim::ListConstruct(%31, %32, %33)
= prim::profile()
return (%16)
Note that the above graph is called a profiling graph. While it is not the most human-friendly format to read, let us highlight some lines of the output and associate them with the equations.
The following lines
%17 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t)
%18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::t~1)
%9 : Tensor = aten::sub(%17, %18, %8)
%19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a)
%20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::a~1)
%10 : Tensor = aten::sub(%19, %20, %8)
%21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
%22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%9)
%11 : Tensor = aten::div(%21, %22)
cover three tensor operations, two aten::subs and one aten::div, which correspond to \( \eqref{1} \). Note that each variable is wrapped inside a profiling node denoted as prim::profile. These wrappers allow the graph executor to record and analyze the runtime statistics of tensor operations, in order to identify hot spots to optimize.
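The same profiling graph executor can be observed in plain PyTorch, outside of NEML2. A minimal sketch (the function f below is made up for illustration; it mirrors a single rate equation):

```python
import torch

torch.set_default_dtype(torch.double)

# A scripted function runs through PyTorch's profiling graph executor:
# early evaluations execute an instrumented graph whose prim::profile
# nodes record dtype, shape, and device information at each value.
@torch.jit.script
def f(a, a_n, t, t_n):
    return (a - a_n) / (t - t_n)

a, a_n = torch.tensor(1.0), torch.tensor(0.0)
t, t_n = torch.tensor(0.1), torch.tensor(0.0)
out = f(a, a_n, t, t_n)

# Inspect the graph that was actually executed
print(torch.jit.last_executed_optimized_graph())
```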
JIT optimization
With the profiling graph in place, subsequent executions of the same traced graph automatically identify opportunities for optimization. The following optimization passes are enabled in NEML2 by default:
- Inlining
- Constant pooling
- Removing expands
- Canonicalization
- Dead code elimination
- Constant propagation
- Input shape propagation
- Common subexpression elimination
- Peephole optimization
- Loop unrolling
See the PyTorch JIT design document for a detailed explanation of each optimization pass.
The code below shows that, after a few forward evaluations, the traced graph can be substantially optimized.
C++
#include "neml2/neml2.h"
#include "neml2/tensors/Scalar.h"
#include "neml2/tensors/jit.h"

int
main()
{
  using namespace neml2;
  set_default_dtype(kFloat64);

  // Load the model named "eq" from the input file (path assumed here)
  auto model = load_model("input.i", "eq");

  auto a = Scalar::full(1.0);
  auto b = Scalar::full(2.0);
  auto c = Scalar::full(3.0);
  auto t = Scalar::full(0.1);
  auto a_n = Scalar::full(0.0);
  auto b_n = Scalar::full(1.0);
  auto c_n = Scalar::full(2.0);
  auto t_n = Scalar::full(0.0);

  ValueMap inputs = {{"a", a},
                     {"b", b},
                     {"c", c},
                     {"t", t},
                     {"a~1", a_n},
                     {"b~1", b_n},
                     {"c~1", c_n},
                     {"t~1", t_n}};

  for (int i = 0; i < 10; i++)
    model->value(inputs);

  std::cout << *jit::last_executed_optimized_graph() << std::endl;
}
Output:
graph(%eq::a : Tensor,
%eq::a~1 : Tensor,
%eq::b : Tensor,
%eq::b~1 : Tensor,
%eq::c : Tensor,
%eq::c~1 : Tensor,
%eq::t : Tensor,
%eq::t~1 : Tensor):
%8 : int = prim::Constant[value=1]()
%11 : Tensor = aten::sub(%eq::t, %eq::t~1, %8)
%14 : Tensor = aten::sub(%eq::a, %eq::a~1, %8)
%17 : Tensor = aten::div(%14, %11)
%20 : Tensor = aten::sub(%eq::b, %eq::b~1, %8)
%23 : Tensor = aten::div(%20, %11)
%26 : Tensor = aten::sub(%eq::c, %eq::c~1, %8)
%29 : Tensor = aten::div(%26, %11)
%33 : Tensor[] = prim::ListConstruct(%17, %23, %29)
return (%33)
Python
import neml2
from neml2.tensors import Scalar
import torch
torch.set_default_dtype(torch.double)
a = Scalar.full(1.0)
b = Scalar.full(2.0)
c = Scalar.full(3.0)
t = Scalar.full(0.1)
a_n = Scalar.full(0.0)
b_n = Scalar.full(1.0)
c_n = Scalar.full(2.0)
t_n = Scalar.full(0.0)
inputs = {
"a": a,
"b": b,
"c": c,
"t": t,
"a~1": a_n,
"b~1": b_n,
"c~1": c_n,
"t~1": t_n,
}
for i in range(10):
model.value(inputs)
print(torch.jit.last_executed_optimized_graph())
Output:
graph(%eq::a : Tensor,
%eq::a~1 : Tensor,
%eq::b : Tensor,
%eq::b~1 : Tensor,
%eq::c : Tensor,
%eq::c~1 : Tensor,
%eq::t : Tensor,
%eq::t~1 : Tensor):
%8 : int = prim::Constant[value=1]() # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%11 : Tensor = aten::sub(%eq::t, %eq::t~1, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%14 : Tensor = aten::sub(%eq::a, %eq::a~1, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%17 : Tensor = aten::div(%14, %11) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%20 : Tensor = aten::sub(%eq::b, %eq::b~1, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%23 : Tensor = aten::div(%20, %11) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%26 : Tensor = aten::sub(%eq::c, %eq::c~1, %8) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%29 : Tensor = aten::div(%26, %11) # /home/runner/work/neml2/neml2/build/tutorials/models/just_in_time_compilation/ex6.py:31:0
%33 : Tensor[] = prim::ListConstruct(%17, %23, %29)
return (%33)
Note how the optimized graph successfully identifies the common subexpression \( t - t_n \) and reuses it in all three equations.
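The effect of this pass can be mimicked by hand: the profiling graph computes \( t - t_n \) once per equation, while the optimized graph hoists it into a single shared value. A rough plain-Python analogue of the transformation (function names and values are made up for illustration):

```python
def rates_unoptimized(a, a_n, b, b_n, c, c_n, t, t_n):
    # t - t_n recomputed in every equation, as in the profiling graph
    return ((a - a_n) / (t - t_n),
            (b - b_n) / (t - t_n),
            (c - c_n) / (t - t_n))

def rates_optimized(a, a_n, b, b_n, c, c_n, t, t_n):
    dt = t - t_n  # common subexpression computed once, as in the optimized graph
    return ((a - a_n) / dt, (b - b_n) / dt, (c - c_n) / dt)

args = (1.0, 0.0, 2.0, 1.0, 3.0, 2.0, 0.1, 0.0)
assert rates_unoptimized(*args) == rates_optimized(*args)  # same results, fewer subtractions
```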
Limitations
JIT compilation is not a silver bullet for improving the performance of every model. For tensor operations that branch based on variable data, the traced graph cannot capture the data dependency and may silently produce wrong results. In addition, NEML2 cannot generate traced graphs for models whose forward evaluation uses derivatives of other models when those derivatives are defined with automatic differentiation.
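The data-dependent branching pitfall can be demonstrated with plain torch.jit.trace (this toy function is not part of NEML2):

```python
import torch

# A function whose control flow depends on the *data* of x
def f(x):
    if x.sum() > 0:
        return x * 2
    return x * 3

# Tracing records only the branch taken for the example input
# (PyTorch emits a TracerWarning about the data-dependent branch)
traced = torch.jit.trace(f, torch.tensor([1.0]))

print(traced(torch.tensor([1.0])))   # tensor([2.]), matches f
print(traced(torch.tensor([-1.0])))  # tensor([-2.]), but f would return tensor([-3.])
```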
Due to these limitations, certain models disable JIT compilation. The most notable example is ImplicitUpdate, whose Newton-Raphson solvers are in general data dependent. However, the portion of the model defining the implicit function being solved can often still benefit from JIT compilation.
When multiple models are composed together, by default a single function graph is traced through all sub-models. If one of the sub-models does not allow JIT, e.g., it is of type ImplicitUpdate, the composed model falls back to tracing each sub-model individually, except for those that explicitly disable JIT. It is therefore generally recommended to compose JIT-enabled sub-models separately from JIT-disabled ones, leaving more opportunities for optimization.