Just-in-time compilation

Problem description

In NEML2, all tensor operations are traceable. The trace of an operation records the operator type, the stack of arguments and outputs, and additional context. Multiple operations performed in sequence can be traced together into a graph representing the flow of data through the operators. This graph representation is primarily used for two purposes:

  • Just-in-time (JIT) compilation and optimization of the operations
  • Backward automatic-differentiation (AD)

This tutorial illustrates the utility of JIT compilation of NEML2 models, and a later tutorial demonstrates the use of backward AD to calculate parameter derivatives.
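Before diving in, it may help to see what a trace looks like in isolation. Below is a minimal sketch using plain PyTorch (not NEML2 internals; the function f and its example inputs are made up for illustration) that records a short sequence of tensor operations into a graph:

    import torch

    # A function composed of traceable tensor operations,
    # mirroring the rate equations used later in this tutorial
    def f(x, x_n, t, t_n):
        return (x - x_n) / (t - t_n)

    # Tracing runs f once on example inputs and records the executed
    # operations (two aten::sub, one aten::div) into a graph
    traced = torch.jit.trace(f, (torch.tensor(1.0), torch.tensor(0.0),
                                 torch.tensor(0.1), torch.tensor(0.0)))
    print(traced.graph)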

In this tutorial, let us consider the following problem

\begin{align}
\dot{a} &= \frac{a - a_n}{t - t_n}, \tag{1} \\
\dot{b} &= \frac{b - b_n}{t - t_n}, \tag{2} \\
\dot{c} &= \frac{c - c_n}{t - t_n}, \tag{3}
\end{align}

where the subscript n represents the variable value from the previous time step.
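For example, with the sample values used for tracing later in this tutorial, \(a = 1\), \(a_n = 0\), \(t = 0.1\), and \(t_n = 0\), equation (1) gives \(\dot{a} = (1 - 0)/(0.1 - 0) = 10\).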

Model structure

Each of the three equations translates to a ScalarVariableRate model. The input file looks like

[Models]
  [eq1]
    type = ScalarVariableRate
    variable = 'state/a'
  []
  [eq2]
    type = ScalarVariableRate
    variable = 'state/b'
  []
  [eq3]
    type = ScalarVariableRate
    variable = 'state/c'
  []
  [eq]
    type = ComposedModel
    models = 'eq1 eq2 eq3'
  []
[]

The composed model then correctly defines \(a\), \(a_n\), \(b\), \(b_n\), \(c\), \(c_n\), \(t\), and \(t_n\) as input variables, and \(\dot{a}\), \(\dot{b}\), and \(\dot{c}\) as output variables.

  • #include "neml2/models/Model.h"

    #include <iostream>

    int
    main()
    {
      using namespace neml2;

      // Load the composed model from the input file
      auto & model = load_model("input.i", "eq");

      // Print a summary of the model's input and output variables
      std::cout << model << std::endl;
    }

    Output:

    Name: eq
    Input: forces/t [Scalar]
           old_forces/t [Scalar]
           old_state/a [Scalar]
           old_state/b [Scalar]
           old_state/c [Scalar]
           state/a [Scalar]
           state/b [Scalar]
           state/c [Scalar]
    Output: state/a_rate [Scalar]
            state/b_rate [Scalar]
            state/c_rate [Scalar]
  • import neml2
    model = neml2.load_model("input.i", "eq")
    print(model)

    Output:

    Name: eq
    Input: forces/t [Scalar]
           old_forces/t [Scalar]
           old_state/a [Scalar]
           old_state/b [Scalar]
           old_state/c [Scalar]
           state/a [Scalar]
           state/b [Scalar]
           state/c [Scalar]
    Output: state/a_rate [Scalar]
            state/b_rate [Scalar]
            state/c_rate [Scalar]

Tracing

NEML2 traces tensor operations lazily: no tracing is performed when the model is first loaded from the input file; instead, tracing takes place when the model is evaluated for the first time. The following code can be used to view the traced graph in text format.

  • #include "neml2/models/Model.h"
    #include "neml2/tensors/Scalar.h"
    #include "neml2/jit/utils.h"

    #include <iostream>

    int
    main()
    {
      using namespace neml2;
      set_default_dtype(kFloat64);
      auto & model = load_model("input.i", "eq");

      // Create example input variables for tracing
      auto a = Scalar::full(1.0);
      auto b = Scalar::full(2.0);
      auto c = Scalar::full(3.0);
      auto t = Scalar::full(0.1);
      auto a_n = Scalar::full(0.0);
      auto b_n = Scalar::full(1.0);
      auto c_n = Scalar::full(2.0);
      auto t_n = Scalar::full(0.0);

      // Evaluate the model for the first time
      // This is when tracing takes place
      model.value({{VariableName("state", "a"), a},
                   {VariableName("state", "b"), b},
                   {VariableName("state", "c"), c},
                   {VariableName("forces", "t"), t},
                   {VariableName("old_state", "a"), a_n},
                   {VariableName("old_state", "b"), b_n},
                   {VariableName("old_state", "c"), c_n},
                   {VariableName("old_forces", "t"), t_n}});

      // Print the traced graph
      std::cout << *last_executed_optimized_graph() << std::endl;
    }

    Output:

    graph(%eq::forces/t : Tensor,
          %eq::old_forces/t : Tensor,
          %eq::old_state/a : Tensor,
          %eq::old_state/b : Tensor,
          %eq::old_state/c : Tensor,
          %eq::state/a : Tensor,
          %eq::state/b : Tensor,
          %eq::state/c : Tensor):
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = prim::Constant[value= 1 1 1 [ CPULongType{3} ]]()
      %18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/a)
      %19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/a)
      %10 : Tensor = aten::sub(%18, %19, %8)
      %20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::forces/t)
      %21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_forces/t)
      %11 : Tensor = aten::sub(%20, %21, %8)
      %22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
      %23 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %12 : Tensor = aten::div(%22, %23)
      %24 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/b)
      %25 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/b)
      %13 : Tensor = aten::sub(%24, %25, %8)
      %26 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%13)
      %27 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %14 : Tensor = aten::div(%26, %27)
      %28 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/c)
      %29 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/c)
      %15 : Tensor = aten::sub(%28, %29, %8)
      %30 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%15)
      %31 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %16 : Tensor = aten::div(%30, %31)
      %32 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%12)
      %33 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%14)
      %34 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%16)
      %35 : Tensor = prim::profile[profiled_type=Long(3, strides=[1], requires_grad=0, device=cpu), seen_none=0](%9)
      %17 : Tensor[] = prim::ListConstruct(%32, %33, %34, %35)
       = prim::profile()
      return (%17)
  • import neml2
    from neml2.tensors import Scalar
    import torch

    torch.set_default_dtype(torch.double)
    model = neml2.load_model("input.i", "eq")

    # Create example input variables for tracing
    a = Scalar.full(1.0)
    b = Scalar.full(2.0)
    c = Scalar.full(3.0)
    t = Scalar.full(0.1)
    a_n = Scalar.full(0.0)
    b_n = Scalar.full(1.0)
    c_n = Scalar.full(2.0)
    t_n = Scalar.full(0.0)

    # Evaluate the model for the first time
    # This is when tracing takes place
    model.value({"state/a": a,
                 "state/b": b,
                 "state/c": c,
                 "forces/t": t,
                 "old_state/a": a_n,
                 "old_state/b": b_n,
                 "old_state/c": c_n,
                 "old_forces/t": t_n})

    # Print the traced graph
    print(torch.jit.last_executed_optimized_graph())

    Output:

    graph(%eq::forces/t : Tensor,
          %eq::old_forces/t : Tensor,
          %eq::old_state/a : Tensor,
          %eq::old_state/b : Tensor,
          %eq::old_state/c : Tensor,
          %eq::state/a : Tensor,
          %eq::state/b : Tensor,
          %eq::state/c : Tensor):
      %8 : int = prim::Constant[value=1]() # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %9 : Tensor = prim::Constant[value= 1 1 1 [ CPULongType{3} ]]() # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/a)
      %19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/a)
      %10 : Tensor = aten::sub(%18, %19, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::forces/t)
      %21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_forces/t)
      %11 : Tensor = aten::sub(%20, %21, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
      %23 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %12 : Tensor = aten::div(%22, %23) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %24 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/b)
      %25 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/b)
      %13 : Tensor = aten::sub(%24, %25, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %26 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%13)
      %27 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %14 : Tensor = aten::div(%26, %27) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %28 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/c)
      %29 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/c)
      %15 : Tensor = aten::sub(%28, %29, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %30 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%15)
      %31 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
      %16 : Tensor = aten::div(%30, %31) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src4.py:17:0
      %32 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%12)
      %33 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%14)
      %34 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%16)
      %35 : Tensor = prim::profile[profiled_type=Long(3, strides=[1], requires_grad=0, device=cpu), seen_none=0](%9)
      %17 : Tensor[] = prim::ListConstruct(%32, %33, %34, %35)
       = prim::profile()
      return (%17)

Note that the above graph is called a profiling graph. While it is not the most human-friendly format to read, let us highlight some lines of the output and associate them with the equations.

The following lines

%18 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::state/a)
%19 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_state/a)
%10 : Tensor = aten::sub(%18, %19, %8)
%20 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::forces/t)
%21 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%eq::old_forces/t)
%11 : Tensor = aten::sub(%20, %21, %8)
%22 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%10)
%23 : Tensor = prim::profile[profiled_type=Double(requires_grad=0, device=cpu), seen_none=0](%11)
%12 : Tensor = aten::div(%22, %23)

cover three tensor operations: two aten::subs and one aten::div, which together correspond to equation (1). Note that each variable is wrapped inside a profiling node denoted prim::profile. These wrappers allow the graph executor to record and analyze runtime statistics of the tensor operations in order to identify hot spots to optimize.
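Stripping away the prim::profile wrappers, those three operations are exactly the right-hand side of equation (1). The following small sketch reproduces the same dataflow with plain torch calls (this is illustration, not NEML2 code; the v-names mirror the graph's value numbers):

    import torch

    a, a_n = torch.tensor(1.0), torch.tensor(0.0)
    t, t_n = torch.tensor(0.1), torch.tensor(0.0)

    v10 = torch.sub(a, a_n)    # %10 : aten::sub -> a - a_n
    v11 = torch.sub(t, t_n)    # %11 : aten::sub -> t - t_n
    v12 = torch.div(v10, v11)  # %12 : aten::div -> the rate in equation (1)
    print(v12)                 # approximately 10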

JIT optimization

With the profiling graph in place, subsequent executions of the same traced graph automatically identify opportunities for optimization. In summary, the following optimization passes are enabled in NEML2 by default:

  • Inlining
  • Constant pooling
  • Removing expands
  • Canonicalization
  • Dead code elimination
  • Constant propagation
  • Input shape propagation
  • Common subexpression extraction
  • Peephole optimization
  • Loop unrolling

See the PyTorch JIT design document for a detailed explanation of each optimization pass.
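For a flavor of what these passes do, here is a tiny self-contained PyTorch sketch (independent of NEML2; the function g is made up for illustration). The repeated subexpression gives the common-subexpression-extraction pass an opportunity to fire, and after a few warm-up calls the optimized graph can be inspected the same way this tutorial does:

    import torch

    @torch.jit.script
    def g(x, y):
        # (x - y) appears twice, an opportunity for common
        # subexpression extraction to compute it only once
        return (x - y) * (x - y) + (x - y)

    # Warm up so the profiling executor can optimize the graph
    for _ in range(10):
        g(torch.rand(3), torch.rand(3))

    print(torch.jit.last_executed_optimized_graph())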

The code below shows that, after a few forward evaluations, the traced graph can be substantially optimized.

  • #include "neml2/models/Model.h"
    #include "neml2/tensors/Scalar.h"
    #include "neml2/jit/utils.h"

    #include <iostream>

    int
    main()
    {
      using namespace neml2;
      set_default_dtype(kFloat64);
      auto & model = load_model("input.i", "eq");

      // Create example input variables for tracing
      auto a = Scalar::full(1.0);
      auto b = Scalar::full(2.0);
      auto c = Scalar::full(3.0);
      auto t = Scalar::full(0.1);
      auto a_n = Scalar::full(0.0);
      auto b_n = Scalar::full(1.0);
      auto c_n = Scalar::full(2.0);
      auto t_n = Scalar::full(0.0);

      // Evaluate the model multiple times so the graph executor
      // can profile and optimize the traced graph
      auto inputs = ValueMap({{VariableName("state", "a"), a},
                              {VariableName("state", "b"), b},
                              {VariableName("state", "c"), c},
                              {VariableName("forces", "t"), t},
                              {VariableName("old_state", "a"), a_n},
                              {VariableName("old_state", "b"), b_n},
                              {VariableName("old_state", "c"), c_n},
                              {VariableName("old_forces", "t"), t_n}});
      for (int i = 0; i < 10; i++)
        model.value(inputs);

      // Print the optimized graph
      std::cout << *last_executed_optimized_graph() << std::endl;
    }

    Output:

    graph(%eq::forces/t : Tensor,
          %eq::old_forces/t : Tensor,
          %eq::old_state/a : Tensor,
          %eq::old_state/b : Tensor,
          %eq::old_state/c : Tensor,
          %eq::state/a : Tensor,
          %eq::state/b : Tensor,
          %eq::state/c : Tensor):
      %9 : Long(3, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1 1 1 [ CPULongType{3} ]]()
      %8 : int = prim::Constant[value=1]()
      %12 : Tensor = aten::sub(%eq::state/a, %eq::old_state/a, %8)
      %15 : Tensor = aten::sub(%eq::forces/t, %eq::old_forces/t, %8)
      %18 : Tensor = aten::div(%12, %15)
      %21 : Tensor = aten::sub(%eq::state/b, %eq::old_state/b, %8)
      %24 : Tensor = aten::div(%21, %15)
      %27 : Tensor = aten::sub(%eq::state/c, %eq::old_state/c, %8)
      %30 : Tensor = aten::div(%27, %15)
      %35 : Tensor[] = prim::ListConstruct(%18, %24, %30, %9)
      return (%35)
  • import neml2
    from neml2.tensors import Scalar
    import torch

    torch.set_default_dtype(torch.double)
    model = neml2.load_model("input.i", "eq")

    # Create example input variables for tracing
    a = Scalar.full(1.0)
    b = Scalar.full(2.0)
    c = Scalar.full(3.0)
    t = Scalar.full(0.1)
    a_n = Scalar.full(0.0)
    b_n = Scalar.full(1.0)
    c_n = Scalar.full(2.0)
    t_n = Scalar.full(0.0)

    # Evaluate the model multiple times so the graph executor
    # can profile and optimize the traced graph
    inputs = {"state/a": a,
              "state/b": b,
              "state/c": c,
              "forces/t": t,
              "old_state/a": a_n,
              "old_state/b": b_n,
              "old_state/c": c_n,
              "old_forces/t": t_n}
    for i in range(10):
        model.value(inputs)
    print(torch.jit.last_executed_optimized_graph())

    Output:

    graph(%eq::forces/t : Tensor,
          %eq::old_forces/t : Tensor,
          %eq::old_state/a : Tensor,
          %eq::old_state/b : Tensor,
          %eq::old_state/c : Tensor,
          %eq::state/a : Tensor,
          %eq::state/b : Tensor,
          %eq::state/c : Tensor):
      %9 : Long(3, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1 1 1 [ CPULongType{3} ]]() # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %8 : int = prim::Constant[value=1]() # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %12 : Tensor = aten::sub(%eq::state/a, %eq::old_state/a, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %15 : Tensor = aten::sub(%eq::forces/t, %eq::old_forces/t, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %18 : Tensor = aten::div(%12, %15) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %21 : Tensor = aten::sub(%eq::state/b, %eq::old_state/b, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %24 : Tensor = aten::div(%21, %15) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %27 : Tensor = aten::sub(%eq::state/c, %eq::old_state/c, %8) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %30 : Tensor = aten::div(%27, %15) # /home/runner/work/neml2/neml2/build/dev/doc/content/tutorials/models/just_in_time_compilation/src6.py:26:0
      %35 : Tensor[] = prim::ListConstruct(%18, %24, %30, %9)
      return (%35)

Note how the optimized graph successfully identifies the common subexpression \(t - t_n\) and reuses it in all three equations.
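The practical payoff is that, once warmed up, subsequent calls run the leaner optimized graph. A quick (and unscientific) way to observe this is to time the first evaluation against a warmed-up one. The sketch below repeats the Python setup from the listing above; for a model this small, the measured gap mostly reflects one-time tracing overhead rather than steady-state speedup:

    import time
    import neml2
    from neml2.tensors import Scalar
    import torch

    torch.set_default_dtype(torch.double)
    model = neml2.load_model("input.i", "eq")
    inputs = {"state/a": Scalar.full(1.0), "state/b": Scalar.full(2.0),
              "state/c": Scalar.full(3.0), "forces/t": Scalar.full(0.1),
              "old_state/a": Scalar.full(0.0), "old_state/b": Scalar.full(1.0),
              "old_state/c": Scalar.full(2.0), "old_forces/t": Scalar.full(0.0)}

    # First evaluation: pays the cost of tracing
    start = time.perf_counter()
    model.value(inputs)
    first = time.perf_counter() - start

    # Warm up so the profiling executor can optimize the graph
    for _ in range(20):
        model.value(inputs)

    # Steady state: runs the optimized graph
    start = time.perf_counter()
    model.value(inputs)
    steady = time.perf_counter() - start
    print(f"first: {first:.2e} s, steady state: {steady:.2e} s")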

Limitations

JIT optimization and compilation are not a cure-all for improving the performance of every model. For tensor operations that branch based on variable data, the traced graph cannot capture the data dependency and can silently produce wrong results. In addition, NEML2 cannot generate traced graphs for models whose forward evaluation includes derivatives of other models when those derivatives are defined through automatic differentiation.
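The first limitation is easy to reproduce with plain PyTorch tracing, outside of NEML2. In the standalone sketch below (the function step is made up for illustration), the trace silently bakes in whichever branch was taken with the example input; PyTorch emits a TracerWarning for exactly this reason:

    import torch

    def step(x):
        # Branches on the *value* of x -- a data-dependent operation
        if x.sum() > 0:
            return x + 1
        return x - 1

    # Tracing records only the branch taken for this example input
    traced = torch.jit.trace(step, torch.tensor([1.0]))  # takes "+1"

    # The trace is now wrong for inputs that take the other branch
    print(step(torch.tensor([-5.0])))    # tensor([-6.]), correct
    print(traced(torch.tensor([-5.0])))  # tensor([-4.]), silently wrong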

Due to these limitations, certain models disable JIT compilation. The most notable case is ImplicitUpdate, because the underlying Newton-Raphson solve is in general data dependent. However, the portion of the complete model that defines the implicit function being solved can often still benefit from JIT compilation.

When multiple models are composed together, a single function graph is by default traced through all sub-models. However, if one of the sub-models does not allow JIT, e.g., one of type ImplicitUpdate, the composed model falls back to tracing each sub-model individually, skipping those that explicitly disable JIT. It is therefore generally recommended to compose JIT-enabled sub-models separately from JIT-disabled ones, leaving more room for optimization; a sketch of this layout follows.
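As an illustration of that recommendation (the model and solver names below are hypothetical, and the definitions of eq1, eq2, and eq3 are elided), the JIT-friendly equations are grouped into their own ComposedModel, which serves as the implicit function handed to a JIT-disabled ImplicitUpdate:

    [Solvers]
      [newton]
        type = Newton
      []
    []

    [Models]
      # ... definitions of eq1, eq2, eq3 elided ...
      [residual]
        # JIT-enabled sub-models composed together: traced and
        # optimized as a single graph
        type = ComposedModel
        models = 'eq1 eq2 eq3'
      []
      [update]
        # JIT is disabled for the implicit solve itself, but the
        # 'residual' model it evaluates can still benefit from JIT
        type = ImplicitUpdate
        implicit_model = 'residual'
        solver = 'newton'
      []
    []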
