32 #include <condition_variable>
34 #include "neml2/dispatchers/WorkGenerator.h"
35 #include "neml2/dispatchers/WorkScheduler.h"
36 #include "neml2/misc/assertions.h"
37 #include "neml2/misc/types.h"
39 #include "neml2/tensors/R2.h"
102 typename Of =
typename std::vector<O>,
119 std::function<O(I &&,
Device)> && do_work,
120 std::function<O(std::vector<O> &&)> && reduce)
132 std::function<O(I &&,
Device)> && do_work,
133 std::function<Of(std::vector<Op> &&)> && reduce,
134 std::function<I(Ip &&,
Device)> && preprocess,
135 std::function<Op(O &&)> && postprocess)
149 std::function<O(I &&,
Device)> && do_work,
150 std::function<Of(std::vector<Op> &&)> && reduce,
151 std::function<I(Ip &&,
Device)> && preprocess,
152 std::function<Op(O &&)> && postprocess,
153 std::function<
void(
Device)> && thread_init)
165 _async,
"Custom thread initialization functor is only supported in asynchronous mode");
215 std::function<Of(std::vector<Op> &&)>
_reduce;
250 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
258 for (
const auto & device :
_devices)
259 _tasks[device] = std::queue<std::function<void()>>();
263 for (std::size_t i = 0; i < nthread; ++i)
274 for (std::size_t i = 0; i < nthread; ++i)
277 auto task = [
this, device = device]()
mutable
284 std::lock_guard<std::mutex> lock(
_qmutex);
285 _tasks.at(device).push(task);
293 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
299 std::function<void()> task;
301 std::unique_lock<std::mutex> lock(
_qmutex);
305 task = std::move(
_tasks.at(device).front());
312 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
319 std::unique_lock<std::mutex> lock(
_qmutex);
327 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
336 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
343 if constexpr (!std::is_same_v<I, Ip>)
347 if constexpr (!std::is_same_v<O, Op>)
351 if constexpr (!std::is_same_v<Of, std::vector<Op>>)
356 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
369 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
371 auto && [m, work] = generator.
next(n);
377 auto result =
_do_work(std::move(work), device);
387 if constexpr (std::is_same<Of, std::vector<Op>>::value)
393 template <
typename I,
typename O,
typename Of,
typename Ip,
typename Op>
410 throw NEMLException(
"Scheduler returned a batch size of " + std::to_string(n));
412 auto && [m, work] = generator.
next(n);
417 auto task = [
this, work = std::move(work), device = device, m = m, i = i]()
mutable
423 auto result =
_do_work(std::move(work), device);
436 std::lock_guard<std::mutex> lock(
_qmutex);
437 _tasks.at(device).push(task);
451 if constexpr (std::is_same<Of, std::vector<Op>>::value)
Derived inverse() const
Inversion.
Definition R2Base.cxx:214
static R2 identity(const TensorOptions &options=default_tensor_options())
Definition R2Base.cxx:166
Derived to(const TensorOptions &options) const
Change tensor options.
Definition TensorBaseImpl.h:150
const std::vector< Device > _devices
Device pool requested by the scheduler.
Definition WorkDispatcher.h:206
std::vector< std::thread > _thread_pool
Definition WorkDispatcher.h:241
const bool _async
Flag to enable asynchronous execution.
Definition WorkDispatcher.h:209
bool _stop
Flag to stop the thread pool.
Definition WorkDispatcher.h:235
std::vector< Op > _results
Results to be reduced.
Definition WorkDispatcher.h:227
std::function< void(Device)> _thread_init
Function to initialize the thread.
Definition WorkDispatcher.h:224
void init_thread_pool()
Initialize the thread pool.
Definition WorkDispatcher.h:252
~WorkDispatcher()
Definition WorkDispatcher.h:175
WorkDispatcher & operator=(const WorkDispatcher &)=delete
Of run_async(WorkGenerator< Ip > &)
Run the dispatching loop asynchronously.
Definition WorkDispatcher.h:395
WorkDispatcher(const WorkDispatcher &)=delete
WorkScheduler & _scheduler
Reference to the work scheduler.
Definition WorkDispatcher.h:203
std::function< I(Ip &&, Device)> _preprocess
Function to preprocess the work.
Definition WorkDispatcher.h:218
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
Task queue for the thread pool.
Definition WorkDispatcher.h:243
std::mutex _qmutex
Definition WorkDispatcher.h:231
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
Definition WorkDispatcher.h:130
void stop_thread_pool()
Stop the thread pool.
Definition WorkDispatcher.h:314
void thread_pool_main(const Device &)
Thread pool main function.
Definition WorkDispatcher.h:295
WorkDispatcher & operator=(WorkDispatcher &&)=delete
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
Definition WorkDispatcher.h:147
std::function< Op(O &&)> _postprocess
Function to postprocess the result.
Definition WorkDispatcher.h:221
void validate() const
Helper function to validate that the dispatcher is properly configured.
Definition WorkDispatcher.h:338
WorkDispatcher(WorkDispatcher &&)=delete
Of run_sync(WorkGenerator< Ip > &)
Run the dispatching loop synchronously.
Definition WorkDispatcher.h:358
std::function< O(I &&, Device)> _do_work
Function to perform the work and return the result.
Definition WorkDispatcher.h:212
Of run(WorkGenerator< Ip > &)
Run the dispatching loop (calls run_sync or run_async based on the async flag).
Definition WorkDispatcher.h:329
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
Definition WorkDispatcher.h:108
std::function< Of(std::vector< Op > &&)> _reduce
Function to reduce the results.
Definition WorkDispatcher.h:215
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
Definition WorkDispatcher.h:117
bool should_unlock_thread()
Whether the thread should be unlocked.
std::condition_variable _thread_condition
Condition variable for the task queue.
Definition WorkDispatcher.h:233
Definition WorkGenerator.h:34
std::pair< std::size_t, T > next(std::size_t n)
Generate the next n batches of work.
Definition WorkGenerator.h:57
virtual bool has_more() const =0
Whether the generator has more work to generate.
Scheduler for work dispatching.
Definition WorkScheduler.h:47
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:66
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
constexpr auto kCPU
Definition types.h:56
Definition WorkDispatcher.h:45
T type
Definition WorkDispatcher.h:46