27#include <ATen/ops/linalg_inv.h>
28#include <ATen/ops/ones.h>
33#include <condition_variable>
35#include <ATen/Parallel.h>
37#include "neml2/config.h"
38#include "neml2/dispatchers/WorkGenerator.h"
39#include "neml2/dispatchers/WorkScheduler.h"
40#include "neml2/misc/assertions.h"
41#include "neml2/misc/types.h"
42#include "neml2/base/TracingInterface.h"
105 typename Of =
typename std::vector<O>,
123 std::function<O(I &&,
Device)> && do_work,
124 std::function<O(std::vector<O> &&)> && reduce)
137 std::function<O(I &&,
Device)> && do_work,
138 std::function<Of(std::vector<Op> &&)> && reduce,
139 std::function<I(Ip &&,
Device)> && preprocess,
140 std::function<Op(O &&)> && postprocess)
155 std::function<O(I &&,
Device)> && do_work,
156 std::function<Of(std::vector<Op> &&)> && reduce,
157 std::function<I(Ip &&,
Device)> && preprocess,
158 std::function<Op(O &&)> && postprocess,
159 std::function<
void(
Device)> && thread_init)
172 _async,
"Custom thread initialization functor is only supported in asynchronous mode");
222 std::function<Of(std::vector<Op> &&)>
_reduce;
257template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
265 if (event_tracing_enabled())
266 event_trace_writer().trace_duration_begin(
"thread pool",
"WorkDispatcher");
270 for (
const auto & device : _devices)
271 _tasks[device] = std::queue<std::function<void()>>();
273 auto nthread = _devices.size();
274 _thread_pool.reserve(nthread);
275 for (std::size_t i = 0; i < nthread; ++i)
279 auto res = at::linalg_inv(at::ones({1, 1}));
280 _thread_pool.emplace_back([
this, i] { thread_pool_main(_devices[i]); });
286 for (std::size_t i = 0; i < nthread; ++i)
288 auto device = _devices[i];
289 auto task = [
this, device = device]()
mutable
292 if (event_tracing_enabled())
293 event_trace_writer().trace_duration_begin(
297 _thread_init(device);
298 _scheduler.completed_work(device, 1);
301 if (event_tracing_enabled())
302 event_trace_writer().trace_duration_end(
"thread init",
"WorkDispatcher");
305 _scheduler.dispatched_work(device, 1);
307 std::lock_guard<std::mutex> lock(_qmutex);
308 _tasks.at(device).push(task);
310 _thread_condition.notify_all();
312 _scheduler.wait_for_completion();
316template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
321 if (event_tracing_enabled())
322 event_trace_writer().trace_duration_begin(
328 std::function<void()> task;
330 std::unique_lock<std::mutex> lock(_qmutex);
331 _thread_condition.wait(lock, [
this, &device] {
return _stop || !_tasks.at(device).empty(); });
332 if (_stop && _tasks.at(device).empty())
334 task = std::move(_tasks.at(device).front());
335 _tasks.at(device).pop();
341 if (event_tracing_enabled())
342 event_trace_writer().trace_duration_end(
"thread main",
"WorkDispatcher");
346template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
353 std::unique_lock<std::mutex> lock(_qmutex);
356 _thread_condition.notify_all();
357 for (
auto & thread : _thread_pool)
361 if (event_tracing_enabled())
362 event_trace_writer().trace_duration_end(
"thread pool",
"WorkDispatcher");
366template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
371 if (event_tracing_enabled())
372 event_trace_writer().trace_duration_begin(
"run",
"WorkDispatcher");
377 auto result = run_async(generator);
380 if (event_tracing_enabled())
381 event_trace_writer().trace_duration_end(
"run",
"WorkDispatcher");
387 auto result = run_sync(generator);
390 if (event_tracing_enabled())
391 event_trace_writer().trace_duration_end(
"run",
"WorkDispatcher");
397template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
404 if constexpr (!std::is_same_v<I, Ip>)
408 if constexpr (!std::is_same_v<O, Op>)
412 if constexpr (!std::is_same_v<Of, std::vector<Op>>)
417template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
428 _scheduler.schedule_work(device, n);
430 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
432 auto && [m, work] = generator.
next(n);
435 work = _preprocess(std::move(work), device);
438 auto result = _do_work(std::move(work), device);
441 result = _postprocess(std::move(result));
442 _results.push_back(result);
446 return _reduce(std::move(_results));
448 if constexpr (std::is_same<Of, std::vector<Op>>::value)
454template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
469 _scheduler.schedule_work(device, n);
471 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
473 auto && [m, work] = generator.
next(n);
475 _results.resize(_results.size() + 1);
476 auto i = _results.size() - 1;
478 auto task = [
this, work = std::move(work), device = device, m = m, i = i]()
mutable
481 if (event_tracing_enabled())
482 event_trace_writer().trace_duration_begin(
483 "task",
"WorkDispatcher", {{
"device",
utils::stringify(device)}, {
"batch size", m}});
487 work = _preprocess(std::move(work), device);
489 auto result = _do_work(std::move(work), device);
492 result = _postprocess(std::move(result));
494 _results[i] = std::move(result);
496 _scheduler.completed_work(device, m);
499 if (event_tracing_enabled())
500 event_trace_writer().trace_duration_end(
"task",
"WorkDispatcher");
504 _scheduler.dispatched_work(device, m);
507 std::lock_guard<std::mutex> lock(_qmutex);
508 _tasks.at(device).push(task);
513 _thread_condition.notify_all();
517 _scheduler.wait_for_completion();
520 return _reduce(std::move(_results));
522 if constexpr (std::is_same<Of, std::vector<Op>>::value)
Definition TracingInterface.h:114
The work dispatcher who dispatches work to a worker and reduces the results.
Definition WorkDispatcher.h:109
const std::vector< Device > _devices
Device pool requested by the scheduler.
Definition WorkDispatcher.h:213
std::vector< std::thread > _thread_pool
Definition WorkDispatcher.h:248
const bool _async
Flag to enable asynchronous execution.
Definition WorkDispatcher.h:216
bool _stop
Flag to stop the thread pool.
Definition WorkDispatcher.h:242
std::vector< Op > _results
Results to be reduced.
Definition WorkDispatcher.h:234
std::function< void(Device)> _thread_init
Function to initialize the thread.
Definition WorkDispatcher.h:231
void init_thread_pool()
Initialize the thread pool.
Definition WorkDispatcher.h:259
WorkDispatcher & operator=(const WorkDispatcher &)=delete
Of run_async(WorkGenerator< Ip > &)
Run the dispatching loop asynchronously.
Definition WorkDispatcher.h:456
WorkDispatcher(const WorkDispatcher &)=delete
WorkScheduler & _scheduler
Reference to the work scheduler.
Definition WorkDispatcher.h:210
std::function< I(Ip &&, Device)> _preprocess
Function to preprocess the work.
Definition WorkDispatcher.h:225
~WorkDispatcher() override
Definition WorkDispatcher.h:182
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
Task queue for the thread pool.
Definition WorkDispatcher.h:250
std::mutex _qmutex
Definition WorkDispatcher.h:238
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
Definition WorkDispatcher.h:135
void stop_thread_pool()
Stop the thread pool.
Definition WorkDispatcher.h:348
void thread_pool_main(const Device &)
Thread pool main function.
Definition WorkDispatcher.h:318
WorkDispatcher & operator=(WorkDispatcher &&)=delete
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
Definition WorkDispatcher.h:153
std::function< Op(O &&)> _postprocess
Function to postprocess the result.
Definition WorkDispatcher.h:228
void validate() const
Helper function to validate that the dispatcher is properly configured.
Definition WorkDispatcher.h:399
WorkDispatcher(WorkDispatcher &&)=delete
Of run_sync(WorkGenerator< Ip > &)
Run the dispatching loop synchronously.
Definition WorkDispatcher.h:419
std::function< O(I &&, Device)> _do_work
Function to perform the work and return the result.
Definition WorkDispatcher.h:219
Of run(WorkGenerator< Ip > &)
Run the dispatching loop (calls run_sync or run_async based on the async flag)
Definition WorkDispatcher.h:368
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
Definition WorkDispatcher.h:111
std::function< Of(std::vector< Op > &&)> _reduce
Function to reduce the results.
Definition WorkDispatcher.h:222
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
Definition WorkDispatcher.h:121
bool should_unlock_thread()
Should unlock thread.
std::condition_variable _thread_condition
Condition variable for the tasks queue.
Definition WorkDispatcher.h:240
Definition WorkGenerator.h:34
std::pair< std::size_t, T > next(std::size_t n)
Generate the next n batches of work.
Definition WorkGenerator.h:57
virtual bool has_more() const =0
Whether the generator has more work to generate.
Scheduler for work dispatching.
Definition WorkScheduler.h:48
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:66
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
constexpr auto kCPU
Definition types.h:56
Definition WorkDispatcher.h:48
T type
Definition WorkDispatcher.h:49