31#include <condition_variable>
33#include <ATen/Parallel.h>
35#include "neml2/dispatchers/WorkGenerator.h"
36#include "neml2/dispatchers/WorkScheduler.h"
37#include "neml2/misc/assertions.h"
38#include "neml2/misc/types.h"
39#include "neml2/base/TracingInterface.h"
41#include "neml2/tensors/R2.h"
104 typename Of =
typename std::vector<O>,
122 std::function<O(I &&,
Device)> && do_work,
123 std::function<O(std::vector<O> &&)> && reduce)
136 std::function<O(I &&,
Device)> && do_work,
137 std::function<Of(std::vector<Op> &&)> && reduce,
138 std::function<I(Ip &&,
Device)> && preprocess,
139 std::function<Op(O &&)> && postprocess)
154 std::function<O(I &&,
Device)> && do_work,
155 std::function<Of(std::vector<Op> &&)> && reduce,
156 std::function<I(Ip &&,
Device)> && preprocess,
157 std::function<Op(O &&)> && postprocess,
158 std::function<
void(
Device)> && thread_init)
171 _async,
"Custom thread initialization functor is only supported in asynchronous mode");
221 std::function<Of(std::vector<Op> &&)>
_reduce;
256template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
269 for (
const auto & device :
_devices)
270 _tasks[device] = std::queue<std::function<void()>>();
274 for (std::size_t i = 0; i < nthread; ++i)
285 for (std::size_t i = 0; i < nthread; ++i)
288 auto task = [
this, device = device]()
mutable
306 std::lock_guard<std::mutex> lock(
_qmutex);
307 _tasks.at(device).push(task);
315template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
327 std::function<void()> task;
329 std::unique_lock<std::mutex> lock(
_qmutex);
333 task = std::move(
_tasks.at(device).front());
345template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
352 std::unique_lock<std::mutex> lock(
_qmutex);
365template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
396template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
403 if constexpr (!std::is_same_v<I, Ip>)
407 if constexpr (!std::is_same_v<O, Op>)
411 if constexpr (!std::is_same_v<Of, std::vector<Op>>)
416template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
429 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
431 auto && [m, work] = generator.
next(n);
437 auto result =
_do_work(std::move(work), device);
447 if constexpr (std::is_same<Of, std::vector<Op>>::value)
453template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
470 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
472 auto && [m, work] = generator.
next(n);
477 auto task = [
this, work = std::move(work), device = device, m = m, i = i]()
mutable
482 "task",
"WorkDispatcher", {{
"device",
utils::stringify(device)}, {
"batch size", m}});
488 auto result =
_do_work(std::move(work), device);
506 std::lock_guard<std::mutex> lock(
_qmutex);
507 _tasks.at(device).push(task);
521 if constexpr (std::is_same<Of, std::vector<Op>>::value)
Derived inverse() const
Inversion.
Definition R2Base.cxx:240
static R2 identity(const TensorOptions &options=default_tensor_options())
Definition R2Base.cxx:167
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:154
TracingInterface(std::string)
Definition TracingInterface.cxx:151
TraceWriter & event_trace_writer() const
Get the event trace writer.
Definition TracingInterface.cxx:186
bool event_tracing_enabled() const
Definition TracingInterface.h:125
const std::vector< Device > _devices
Device pool requested by the scheduler.
Definition WorkDispatcher.h:212
std::vector< std::thread > _thread_pool
Definition WorkDispatcher.h:247
const bool _async
Flag to enable asynchronous execution.
Definition WorkDispatcher.h:215
bool _stop
Flag to stop the thread pool.
Definition WorkDispatcher.h:241
std::vector< Op > _results
Results to be reduced.
Definition WorkDispatcher.h:233
std::function< void(Device)> _thread_init
Function to initialize the thread.
Definition WorkDispatcher.h:230
void init_thread_pool()
Initialize the thread pool.
Definition WorkDispatcher.h:258
WorkDispatcher & operator=(const WorkDispatcher &)=delete
Of run_async(WorkGenerator< Ip > &)
Run the dispatching loop asynchronously.
Definition WorkDispatcher.h:455
WorkDispatcher(const WorkDispatcher &)=delete
WorkScheduler & _scheduler
Reference to the work scheduler.
Definition WorkDispatcher.h:209
std::function< I(Ip &&, Device)> _preprocess
Function to preprocess the work.
Definition WorkDispatcher.h:224
~WorkDispatcher() override
Definition WorkDispatcher.h:181
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
Task queue for the thread pool.
Definition WorkDispatcher.h:249
std::mutex _qmutex
Definition WorkDispatcher.h:237
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
Definition WorkDispatcher.h:134
void stop_thread_pool()
Stop the thread pool.
Definition WorkDispatcher.h:347
void thread_pool_main(const Device &)
Thread pool main function.
Definition WorkDispatcher.h:317
WorkDispatcher & operator=(WorkDispatcher &&)=delete
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
Definition WorkDispatcher.h:152
std::function< Op(O &&)> _postprocess
Function to postprocess the result.
Definition WorkDispatcher.h:227
void validate() const
Helper function to validate that the dispatcher is properly configured.
Definition WorkDispatcher.h:398
WorkDispatcher(WorkDispatcher &&)=delete
Of run_sync(WorkGenerator< Ip > &)
Run the dispatching loop synchronously.
Definition WorkDispatcher.h:418
std::function< O(I &&, Device)> _do_work
Function to perform the work and return the result.
Definition WorkDispatcher.h:218
Of run(WorkGenerator< Ip > &)
Run the dispatching loop (calls run_sync or run_async based on the async flag)
Definition WorkDispatcher.h:367
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
Definition WorkDispatcher.h:110
std::function< Of(std::vector< Op > &&)> _reduce
Function to reduce the results.
Definition WorkDispatcher.h:221
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
Definition WorkDispatcher.h:120
bool should_unlock_thread()
Should unlock thread.
std::condition_variable _thread_condition
Condition variable for the tasks queue.
Definition WorkDispatcher.h:239
Definition WorkGenerator.h:34
std::pair< std::size_t, T > next(std::size_t n)
Generate the next n batches of work.
Definition WorkGenerator.h:57
virtual bool has_more() const =0
Whether the generator has more work to generate.
Scheduler for work dispatching.
Definition WorkScheduler.h:48
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.cxx:29
c10::Device Device
Definition types.h:63
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
constexpr auto kCPU
Definition types.h:53
Definition WorkDispatcher.h:47
T type
Definition WorkDispatcher.h:48