Dispatching across devices
Modern compute nodes are heterogeneous: one or more CPUs alongside one or more GPUs. The work scheduler/dispatcher lets a single compiled model spread a large batched evaluation across devices. It slices the batch into sub-batches, runs each on a device, and stitches the results back together.
This is a C++ runtime feature. It serves the compiled path (AOTI packages) embedded in a host application (e.g. MOOSE), where the hot loop runs without Python. The Python authoring path (neml2.load_model, neml2-run) stays eager and single-device. To spread work from Python, run your own per-device loop.
Note
Only CPU and CUDA devices are supported. Two scheduling modes are available:
synchronous single-device (SimpleScheduler, MPISimpleScheduler) and
asynchronous multi-device (StaticHybridScheduler, which runs CPU + GPU(s)
concurrently via a thread-per-device pool).
Compile one artifact per device
AOTI graphs are pinned to a device at export time, so a dispatcher that targets
several devices needs one artifact per device. neml2-compile accepts multiple
--device values and emits the artifacts side by side. For example, given the
input file tests/aoti/forward_single/model.i:
# forward_single: a one-leaf forward-only model. Smoke test for the
# bake-by-default path on the simplest possible shape.
[Models]
  [model]
    type = LinearIsotropicElasticity
    strain = 'strain'
    stress = 'stress'
    coefficients = '100 0.3'
    coefficient_types = 'YOUNGS_MODULUS POISSONS_RATIO'
  []
[]
$ neml2-compile tests/aoti/forward_single/model.i --model model --device cpu cuda -d :
(-d : compiles every Jacobian/JVP pair so that the dispatched jacobian /
jvp calls work. Drop it for a forward-only artifact, or name specific pairs instead.)
A batch-independent block is returned unbatched, and the dispatcher passes it
through unchanged because it is identical across batch chunks.
The command produces a standalone stub next to a per-device artifact folder:
aoti/
  model_aoti.i     # standalone stub; points at the folder via artifact_path
  model/
    cpu/   model_meta.json + *.pt2
    cuda/  model_meta.json + *.pt2
The loader resolves <artifact_path>/<device>/ for the device it runs on. The
Python shim picks the subfolder matching torch.get_default_device() (so
neml2-run --device cuda loads cuda/); the C++ loader picks it from the
scheduler (below).
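The following is a rough sketch of that lookup. It joins artifact_path with the device type. resolve_artifact_dir is a hypothetical name, not a NEML2 function. The sketch also assumes that an indexed device such as cuda:1 reuses the cuda/ artifact, because the layout above has one folder per device type.

```cpp
#include <filesystem>
#include <string>

// Hypothetical helper (not a NEML2 API): pick the per-device subfolder that
// neml2-compile emits under artifact_path. Assumption: an indexed device such
// as "cuda:1" uses the "cuda" artifact, so only the device type is kept.
std::filesystem::path resolve_artifact_dir(const std::filesystem::path & artifact_path,
                                           const std::string & device)
{
  // "cuda:1" -> "cuda"; a string without ':' (e.g. "cpu") is kept whole.
  const std::string type = device.substr(0, device.find(':'));
  return artifact_path / type;
}
```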
Load and dispatch from C++
neml2::aoti::load_model mirrors Python’s load_model(path, name) and returns a
DispatchedModel. This is a Model-shaped handle that exposes the same
forward / jvp / jacobian methods. The optional third argument, a scheduler, is the
opt-in for dispatch. It is supplied in C++ source, never from the .i file.
#include <memory>
#include "neml2/csrc/dispatchers/factory.h"
#include "neml2/csrc/dispatchers/SimpleScheduler.h"
using namespace neml2::aoti;
// `inputs` is the model's input map, prepared by the caller (not shown).
// No scheduler -> no dispatch: runs the whole batch on cpu (zero-overhead
// pass-through over the underlying Model).
auto m = load_model("aoti/model_aoti.i", "model");
auto out = m.forward(inputs);
// With a scheduler -> the batch is chunked along its leading axis, each chunk
// moved to the compute device, run, and the results concatenated back on the
// input device.
auto sched = std::make_shared<SimpleScheduler>(SimpleScheduler::Config{"cuda", 4096});
auto md = load_model("aoti/model_aoti.i", "model", sched);
auto out_gpu = md.forward(inputs); // cpu inputs -> gpu compute -> cpu results
When the scheduler’s device equals the input device and the whole batch fits in
one chunk, DispatchedModel short-circuits to a direct Model call. The
no-dispatch case therefore carries no slicing or transfer cost.
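The short-circuit condition can be written as a small predicate. This is a hypothetical sketch, not DispatchedModel's code. It compares device strings literally, so "cuda" and "cuda:0" count as different devices. It also treats batch_size = 0 as a single whole-batch chunk, following the convention of SimpleScheduler's Config.

```cpp
#include <cstddef>
#include <string>

// Hypothetical predicate (not NEML2's code): skip slicing and device
// transfers when nothing would actually be dispatched.
// batch_size == 0 is read as "no chunking" (one chunk holding the whole batch).
bool can_short_circuit(const std::string & compute_device,
                       const std::string & input_device,
                       std::size_t batch,
                       std::size_t batch_size)
{
  const bool same_device = compute_device == input_device;
  const bool single_chunk = batch_size == 0 || batch <= batch_size;
  return same_device && single_chunk;
}
```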
Schedulers
A scheduler decides which device(s) a workload runs on and how large each
sub-batch chunk is. Every scheduler is a plain C++ object configured by its own
Config struct. DispatchedModel picks its execution mode from the scheduler’s type.
A synchronous scheduler (SimpleScheduler, MPISimpleScheduler) runs the
chunk loop on the calling thread. An asynchronous one
(StaticHybridScheduler) drives a thread-per-device pool.
SimpleScheduler
Sends the whole workload to a single device in fixed-size chunks. It is configured with Config{device, batch_size}, e.g. {"cuda:0", 1024}. Setting batch_size = 0 means “no chunking”. Use it to:
process a batch too large for device memory in fixed-size pieces;
empirically tune the per-call batch size for a model + device; or
drive one device per process when a host pins devices by hand.
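Chunking along the leading axis reduces to computing half-open index ranges. The following is a minimal sketch in which the last chunk is simply allowed to be shorter. chunk_ranges is a hypothetical helper, not part of NEML2.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical sketch: split a leading batch axis of length `batch` into
// half-open [begin, end) ranges of at most `batch_size` entries.
// batch_size == 0 means "no chunking": the whole batch is one range.
std::vector<std::pair<std::size_t, std::size_t>>
chunk_ranges(std::size_t batch, std::size_t batch_size)
{
  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  if (batch == 0)
    return ranges; // nothing to dispatch
  if (batch_size == 0)
    batch_size = batch;
  for (std::size_t begin = 0; begin < batch; begin += batch_size)
    ranges.emplace_back(begin, std::min(begin + batch_size, batch));
  return ranges;
}
```

The dispatcher stitches results back together by running the model on each range and concatenating the outputs in range order, which restores the original batch order.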
Note
Illustrative schematic: the batch is split into fixed-size chunks that are fed to one device. A new chunk is dispatched as in-flight ones finish and free capacity.
MPISimpleScheduler
For MPI jobs that run one rank per GPU. Config{devices, batch_sizes} lists the
CUDA devices to choose from. Each rank is assigned one device based on its rank within
its node: ranks are grouped by hostname, and the resulting local rank indexes into the
list. After that, the rank chunks exactly like SimpleScheduler. This scheduler requires NEML2 built
with -DNEML2_MPI=ON and at least as many listed devices as there are ranks on each
node. Otherwise, the constructor throws.
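The local-rank rule can be sketched without MPI by passing in the hostname of every rank. In a real job those hostnames would first be gathered from all ranks. assign_device is hypothetical, and the sketch assumes that local ranks on each host are ordered by global rank.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch (not NEML2's code): hosts[r] is the hostname of global
// rank r. A rank's local rank is the number of lower-numbered ranks on the
// same host, and it indexes into the device list.
std::string assign_device(const std::vector<std::string> & hosts,
                          std::size_t rank,
                          const std::vector<std::string> & devices)
{
  std::size_t local_rank = 0;
  for (std::size_t r = 0; r < rank; ++r)
    if (hosts[r] == hosts[rank])
      ++local_rank;
  if (local_rank >= devices.size())
    throw std::runtime_error("not enough devices for the ranks on host " + hosts[rank]);
  return devices[local_rank];
}
```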
StaticHybridScheduler
Spreads one batch across several devices concurrently. A single
DispatchedModel runs CPU + GPU(s) at once via a thread-per-device pool.
It is configured with Config{devices, batch_sizes, capacities, priorities}. The last two
are optional, and a list given with a single entry is broadcast to every device.
Assignment is greedy: each chunk goes to the highest-priority device that still has
spare capacity (load + batch_size <= capacity), so faster devices stay filled.
Capacity bounds a device's in-flight load. A capacity of k times batch_size therefore lets
up to k chunks be in flight on that device at once, which overlaps the next chunk’s
host→device copy with the current chunk’s compute.
#include <memory>
#include "neml2/csrc/dispatchers/factory.h"
#include "neml2/csrc/dispatchers/StaticHybridScheduler.h"
using namespace neml2::aoti;

StaticHybridScheduler::Config cfg;
cfg.devices = {"cpu", "cuda:0", "cuda:1"};
cfg.batch_sizes = {512, 4096, 4096}; // tune per device, e.g. via SimpleScheduler
auto m = load_model("aoti/model_aoti.i", "model",
                    std::make_shared<StaticHybridScheduler>(cfg));
auto out = m.forward(inputs); // dispatched across all three, gathered back
Note
Illustrative schematic: each chunk goes to the highest-priority device that
still has spare capacity, so faster devices stay filled; chunks process
concurrently across devices and are gathered back as they finish.
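The greedy rule can be modeled with a little bookkeeping. This is a hypothetical sketch, not StaticHybridScheduler's implementation. It makes three assumptions: a larger priority value means higher priority, ties go to the device listed first, and the caller waits when every device is full (here, acquire() returns -1).

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical model of greedy priority/capacity assignment (not NEML2 code).
// Assumptions: a larger priority value wins, ties go to the earlier device,
// and load/capacity are counted in batch entries like batch_size.
struct DevicePool
{
  std::vector<long> batch_sizes, capacities, priorities, loads;

  // A one-entry list applies to every device, mirroring the Config broadcast.
  static std::vector<long> broadcast(std::vector<long> v, std::size_t n)
  {
    if (v.size() == 1)
    {
      const long value = v.front();
      v.assign(n, value);
    }
    if (v.size() != n)
      throw std::invalid_argument("list length must be 1 or the device count");
    return v;
  }

  DevicePool(const std::vector<long> & bs, const std::vector<long> & cap,
             const std::vector<long> & prio)
    : batch_sizes(bs),
      capacities(broadcast(cap, bs.size())),
      priorities(broadcast(prio, bs.size())),
      loads(bs.size(), 0L)
  {
  }

  // Pick the highest-priority device with load + batch_size <= capacity,
  // mark one chunk in flight there, and return its index (-1 if all are full).
  int acquire()
  {
    int best = -1;
    for (std::size_t i = 0; i < loads.size(); ++i)
    {
      const bool has_room = loads[i] + batch_sizes[i] <= capacities[i];
      if (has_room && (best < 0 || priorities[i] > priorities[best]))
        best = static_cast<int>(i);
    }
    if (best >= 0)
      loads[best] += batch_sizes[best];
    return best;
  }

  // A finished chunk frees its share of the device's capacity.
  void release(int i) { loads[i] -= batch_sizes[i]; }
};
```

For example, DevicePool pool({512, 4096}, {512, 8192}, {0, 1}) hands the first two chunks to device 1, whose capacity is twice its batch size. It then hands one chunk to device 0 and reports that both devices are full until release() is called.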
A hybrid pool admits at most one CPU device plus any number of distinct GPUs. The CPU AOTI graph already saturates torch’s intra-op (OpenMP) thread pool, so a second CPU worker would only oversubscribe the same cores.
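A constructor-time check for this rule might look like the following. It is a hypothetical sketch that compares device strings literally. A real check would also need to normalize aliases such as "cuda" and "cuda:0".

```cpp
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical validation (not NEML2's code): at most one CPU worker, and
// each GPU listed at most once. Device strings are compared literally.
void validate_hybrid_devices(const std::vector<std::string> & devices)
{
  int cpus = 0;
  std::set<std::string> gpus;
  for (const auto & d : devices)
  {
    if (d == "cpu")
    {
      if (++cpus > 1)
        throw std::invalid_argument("a hybrid pool admits at most one CPU worker");
    }
    else if (!gpus.insert(d).second)
      throw std::invalid_argument("duplicate GPU in hybrid pool: " + d);
  }
}
```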
Promoted parameters under hybrid. named_parameters() is a single master
map. In-place mutations of it are broadcast to every device copy before the next
dispatch, so the usual single-device idiom keeps working:
m.named_parameters().at("E").fill_(150e3); // reflected on every device next call
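One simple way to obtain these semantics is to recopy the master map into every device replica at the start of each dispatch. That way, in-place edits made between calls cannot be missed. The sketch illustrates this policy with plain doubles standing in for tensors. ReplicatedParams is hypothetical, and NEML2 may track changes differently.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical illustration (not NEML2's code) of one master parameter map
// shared by per-device replicas. Doubles stand in for tensors.
struct ReplicatedParams
{
  std::map<std::string, double> master;                // what callers mutate
  std::vector<std::map<std::string, double>> replicas; // one copy per device

  std::map<std::string, double> & named_parameters() { return master; }

  // Called at the start of every dispatch: push the master values to each
  // device copy, so edits made since the last call are always picked up.
  void sync_before_dispatch()
  {
    for (auto & replica : replicas)
      replica = master;
  }
};
```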
Error handling
Every exception that leaves forward / jvp / jacobian is a neml2::aoti::Exception (itself a
std::runtime_error) carrying a recoverable() flag. This holds on both the
synchronous and asynchronous paths. The flag is the contract
a downstream consumer branches on:
ConvergenceError (recoverable() == true): the nonlinear solve diverged or hit its iteration cap. A time-stepping consumer can cut the step and retry.
FatalError (recoverable() == false): a shape / device mismatch, a missing input, or a malformed artifact. A retry would fail identically, so the consumer must hard-fail. Foreign errors (a torch c10::Error, std::bad_alloc, …) are normalized to FatalError at the boundary, so a single catch covers everything.
try
{
auto out = m.forward(inputs);
}
catch (const neml2::aoti::Exception & e)
{
if (e.recoverable()) { /* e.g. dt *= 0.5; retry */ }
else { throw; } // fatal: give up
}
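In a time-stepping host, that branch typically sits inside a cut-back loop. The sketch is hypothetical. SketchError stands in for neml2::aoti::Exception so the example compiles on its own, and the halving factor and retry limit are arbitrary choices.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for neml2::aoti::Exception so this compiles alone.
class SketchError : public std::runtime_error
{
public:
  SketchError(const std::string & msg, bool rec) : std::runtime_error(msg), rec_(rec) {}
  bool recoverable() const { return rec_; }

private:
  bool rec_;
};

// Try a step of size dt. On a recoverable failure, halve dt and retry, up to
// max_cuts times. Fatal errors (and the last recoverable one) propagate.
// Returns the step size that finally succeeded.
template <typename Solve>
double step_with_cutback(Solve solve, double dt, int max_cuts = 5)
{
  for (int cut = 0;; ++cut)
  {
    try
    {
      solve(dt);
      return dt;
    }
    catch (const SketchError & e)
    {
      if (!e.recoverable() || cut == max_cuts)
        throw; // fatal, or out of retries: give up
      dt *= 0.5;
    }
  }
}
```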
Under asynchronous dispatch this contract stays well-defined even when several chunks run
at once. A failing chunk is caught inside its worker, because a C++ exception escaping a
std::thread would call std::terminate. The scheduler is still drained, so the
pool can never deadlock. Only then does the dispatcher decide what to throw:
one failure → it is re-thrown verbatim, so its dynamic type (e.g. ConvergenceError) is preserved;
several failures → an AggregateError carrying them all. It reports recoverable() only if every sub-error is recoverable, so a lone fatal error among otherwise-recoverable failures still forces a hard stop. The individual errors are available via AggregateError::errors().
Either way the DispatchedModel and its scheduler are left clean and reusable
for the next call.
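The drain-then-decide behavior can be reproduced with std::exception_ptr. This hypothetical sketch spawns one thread per chunk rather than using a thread-per-device pool. FlaggedError and AggregateSketch only stand in for neml2::aoti::Exception and AggregateError. The sketch shows the pattern, not NEML2's code.

```cpp
#include <cstddef>
#include <exception>
#include <functional>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-ins (not NEML2 classes) for an exception carrying a
// recoverable() flag and for an aggregate of several such failures.
class FlaggedError : public std::runtime_error
{
public:
  FlaggedError(const std::string & msg, bool rec) : std::runtime_error(msg), rec_(rec) {}
  bool recoverable() const { return rec_; }

private:
  bool rec_;
};

// Recoverable only if every sub-error is; foreign errors count as fatal.
inline bool all_recoverable(const std::vector<std::exception_ptr> & errs)
{
  for (const auto & e : errs)
  {
    try
    {
      std::rethrow_exception(e);
    }
    catch (const FlaggedError & f)
    {
      if (!f.recoverable())
        return false;
    }
    catch (...)
    {
      return false;
    }
  }
  return true;
}

struct AggregateSketch : FlaggedError
{
  explicit AggregateSketch(std::vector<std::exception_ptr> errs)
    : FlaggedError("several chunks failed", all_recoverable(errs)), errors(std::move(errs))
  {
  }
  std::vector<std::exception_ptr> errors; // the individual failures
};

// Run each chunk on its own thread, catching failures inside the worker,
// join every thread (drain), then rethrow one failure or aggregate several.
inline void run_chunks(const std::vector<std::function<void()>> & chunks)
{
  std::vector<std::exception_ptr> failures(chunks.size()); // one slot per chunk
  std::vector<std::thread> workers;
  for (std::size_t i = 0; i < chunks.size(); ++i)
    workers.emplace_back([&chunks, &failures, i] {
      try
      {
        chunks[i]();
      }
      catch (...)
      {
        failures[i] = std::current_exception();
      }
    });
  for (auto & w : workers)
    w.join();

  std::vector<std::exception_ptr> errs;
  for (const auto & f : failures)
    if (f)
      errs.push_back(f);
  if (errs.size() == 1)
    std::rethrow_exception(errs.front()); // dynamic type preserved
  if (errs.size() > 1)
    throw AggregateSketch(std::move(errs));
}
```

Each worker stores std::current_exception() in its own slot, so no locking is needed. Joining every thread before rethrowing ensures that a failure never leaves work in flight.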
See also
AOTI packages: the per-device artifact + metadata layout.
Compilation pipeline: what neml2-compile does internally.
Compiled models: the end-to-end compile-and-load how-to.