27#include <ATen/ops/linalg_inv.h>
28#include <ATen/ops/ones.h>
33#include <condition_variable>
35#include <ATen/Parallel.h>
37#include "neml2/dispatchers/WorkGenerator.h"
38#include "neml2/dispatchers/WorkScheduler.h"
39#include "neml2/misc/assertions.h"
40#include "neml2/misc/types.h"
41#include "neml2/base/TracingInterface.h"
104 typename Of =
typename std::vector<O>,
122 std::function<O(I &&,
Device)> && do_work,
123 std::function<O(std::vector<O> &&)> && reduce)
136 std::function<O(I &&,
Device)> && do_work,
137 std::function<Of(std::vector<Op> &&)> && reduce,
138 std::function<I(Ip &&,
Device)> && preprocess,
139 std::function<Op(O &&)> && postprocess)
154 std::function<O(I &&,
Device)> && do_work,
155 std::function<Of(std::vector<Op> &&)> && reduce,
156 std::function<I(Ip &&,
Device)> && preprocess,
157 std::function<Op(O &&)> && postprocess,
158 std::function<
void(
Device)> && thread_init)
171 _async,
"Custom thread initialization functor is only supported in asynchronous mode");
221 std::function<Of(std::vector<Op> &&)>
_reduce;
256template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
264 if (event_tracing_enabled())
265 event_trace_writer().trace_duration_begin(
"thread pool",
"WorkDispatcher");
269 for (
const auto & device : _devices)
270 _tasks[device] = std::queue<std::function<void()>>();
272 auto nthread = _devices.size();
273 _thread_pool.reserve(nthread);
274 for (std::size_t i = 0; i < nthread; ++i)
278 auto res = at::linalg_inv(at::ones({1, 1}));
279 _thread_pool.emplace_back([
this, i] { thread_pool_main(_devices[i]); });
285 for (std::size_t i = 0; i < nthread; ++i)
287 auto device = _devices[i];
288 auto task = [
this, device = device]()
mutable
291 if (event_tracing_enabled())
292 event_trace_writer().trace_duration_begin(
296 _thread_init(device);
297 _scheduler.completed_work(device, 1);
300 if (event_tracing_enabled())
301 event_trace_writer().trace_duration_end(
"thread init",
"WorkDispatcher");
304 _scheduler.dispatched_work(device, 1);
306 std::lock_guard<std::mutex> lock(_qmutex);
307 _tasks.at(device).push(task);
309 _thread_condition.notify_all();
311 _scheduler.wait_for_completion();
315template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
320 if (event_tracing_enabled())
321 event_trace_writer().trace_duration_begin(
327 std::function<void()> task;
329 std::unique_lock<std::mutex> lock(_qmutex);
330 _thread_condition.wait(lock, [
this, &device] {
return _stop || !_tasks.at(device).empty(); });
331 if (_stop && _tasks.at(device).empty())
333 task = std::move(_tasks.at(device).front());
334 _tasks.at(device).pop();
340 if (event_tracing_enabled())
341 event_trace_writer().trace_duration_end(
"thread main",
"WorkDispatcher");
345template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
352 std::unique_lock<std::mutex> lock(_qmutex);
355 _thread_condition.notify_all();
356 for (
auto & thread : _thread_pool)
360 if (event_tracing_enabled())
361 event_trace_writer().trace_duration_end(
"thread pool",
"WorkDispatcher");
365template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
370 if (event_tracing_enabled())
371 event_trace_writer().trace_duration_begin(
"run",
"WorkDispatcher");
376 auto result = run_async(generator);
379 if (event_tracing_enabled())
380 event_trace_writer().trace_duration_end(
"run",
"WorkDispatcher");
386 auto result = run_sync(generator);
389 if (event_tracing_enabled())
390 event_trace_writer().trace_duration_end(
"run",
"WorkDispatcher");
396template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
403 if constexpr (!std::is_same_v<I, Ip>)
407 if constexpr (!std::is_same_v<O, Op>)
411 if constexpr (!std::is_same_v<Of, std::vector<Op>>)
416template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
427 _scheduler.schedule_work(device, n);
429 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
431 auto && [m, work] = generator.
next(n);
434 work = _preprocess(std::move(work), device);
437 auto result = _do_work(std::move(work), device);
440 result = _postprocess(std::move(result));
441 _results.push_back(result);
445 return _reduce(std::move(_results));
447 if constexpr (std::is_same<Of, std::vector<Op>>::value)
453template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
468 _scheduler.schedule_work(device, n);
470 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
472 auto && [m, work] = generator.
next(n);
474 _results.resize(_results.size() + 1);
475 auto i = _results.size() - 1;
477 auto task = [
this, work = std::move(work), device = device, m = m, i = i]()
mutable
480 if (event_tracing_enabled())
481 event_trace_writer().trace_duration_begin(
482 "task",
"WorkDispatcher", {{
"device",
utils::stringify(device)}, {
"batch size", m}});
486 work = _preprocess(std::move(work), device);
488 auto result = _do_work(std::move(work), device);
491 result = _postprocess(std::move(result));
493 _results[i] = std::move(result);
495 _scheduler.completed_work(device, m);
498 if (event_tracing_enabled())
499 event_trace_writer().trace_duration_end(
"task",
"WorkDispatcher");
503 _scheduler.dispatched_work(device, m);
506 std::lock_guard<std::mutex> lock(_qmutex);
507 _tasks.at(device).push(task);
512 _thread_condition.notify_all();
516 _scheduler.wait_for_completion();
519 return _reduce(std::move(_results));
521 if constexpr (std::is_same<Of, std::vector<Op>>::value)
Definition TracingInterface.h:112
A work dispatcher that dispatches work to a worker and reduces the results.
Definition WorkDispatcher.h:108
const std::vector< Device > _devices
Device pool requested by the scheduler.
Definition WorkDispatcher.h:212
std::vector< std::thread > _thread_pool
Definition WorkDispatcher.h:247
const bool _async
Flag to enable asynchronous execution.
Definition WorkDispatcher.h:215
bool _stop
Flag to stop the thread pool.
Definition WorkDispatcher.h:241
std::vector< Op > _results
Results to be reduced.
Definition WorkDispatcher.h:233
std::function< void(Device)> _thread_init
Function to initialize the thread.
Definition WorkDispatcher.h:230
void init_thread_pool()
Initialize the thread pool.
Definition WorkDispatcher.h:258
WorkDispatcher & operator=(const WorkDispatcher &)=delete
Of run_async(WorkGenerator< Ip > &)
Run the dispatching loop asynchronously.
Definition WorkDispatcher.h:455
WorkDispatcher(const WorkDispatcher &)=delete
WorkScheduler & _scheduler
Reference to the work scheduler.
Definition WorkDispatcher.h:209
std::function< I(Ip &&, Device)> _preprocess
Function to preprocess the work.
Definition WorkDispatcher.h:224
~WorkDispatcher() override
Definition WorkDispatcher.h:181
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
Task queue for the thread pool.
Definition WorkDispatcher.h:249
std::mutex _qmutex
Definition WorkDispatcher.h:237
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
Definition WorkDispatcher.h:134
void stop_thread_pool()
Stop the thread pool.
Definition WorkDispatcher.h:347
void thread_pool_main(const Device &)
Thread pool main function.
Definition WorkDispatcher.h:317
WorkDispatcher & operator=(WorkDispatcher &&)=delete
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
Definition WorkDispatcher.h:152
std::function< Op(O &&)> _postprocess
Function to postprocess the result.
Definition WorkDispatcher.h:227
void validate() const
Helper function to validate that the dispatcher is properly configured.
Definition WorkDispatcher.h:398
WorkDispatcher(WorkDispatcher &&)=delete
Of run_sync(WorkGenerator< Ip > &)
Run the dispatching loop synchronously.
Definition WorkDispatcher.h:418
std::function< O(I &&, Device)> _do_work
Function to perform the work and return the result.
Definition WorkDispatcher.h:218
Of run(WorkGenerator< Ip > &)
Run the dispatching loop (calls run_sync or run_async, depending on the async flag).
Definition WorkDispatcher.h:367
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
Definition WorkDispatcher.h:110
std::function< Of(std::vector< Op > &&)> _reduce
Function to reduce the results.
Definition WorkDispatcher.h:221
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
Definition WorkDispatcher.h:120
bool should_unlock_thread()
Whether the thread should be unlocked.
std::condition_variable _thread_condition
Condition variable for the tasks queue.
Definition WorkDispatcher.h:239
Definition WorkGenerator.h:34
std::pair< std::size_t, T > next(std::size_t n)
Generate the next n batches of work.
Definition WorkGenerator.h:57
virtual bool has_more() const =0
Whether the generator has more work to generate.
Scheduler for work dispatching.
Definition WorkScheduler.h:48
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
constexpr auto kCPU
Definition types.h:53
Definition WorkDispatcher.h:47
T type
Definition WorkDispatcher.h:48