NEML2 2.0.0
WorkDispatcher< I, O, Of, Ip, Op > Class Template Reference

The work dispatcher that dispatches work to a worker and reduces the results.

Detailed Description

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
class neml2::WorkDispatcher< I, O, Of, Ip, Op >


Warning
The dispatcher is designed to be thread safe, but we are currently seeing some issues when the dispatcher interacts with torch::jit::tracer. We do not recommend using the dispatcher with torch::jit::tracer until the issue is resolved.

The work dispatcher coordinates with WorkGenerator and WorkScheduler to dispatch work. The work is generated/loaded by the WorkGenerator; the dispatch is scheduled by a WorkScheduler; the dispatching loop is managed by the WorkDispatcher.

The dispatcher also takes care of preprocessing, postprocessing, and reducing the work. In general, each work dispatch involves four steps:

  1. Work generation: The work generator generates the next n batches of work.
  2. Preprocessing: The dispatcher preprocesses the work.
  3. Do work: The worker performs the work.
  4. Postprocessing: The dispatcher postprocesses the result.

Once all the work has been completed and results have been collected, the dispatcher reduces the results to obtain the final result.
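The four steps and the final reduction can be pictured as a plain synchronous loop. This is a minimal sketch, not NEML2's run_sync: it assumes the generator has already produced every batch into a std::vector, that everything runs on a single device, and it uses a hypothetical ToyDeviceIndex alias (an int) in place of NEML2's Device type. The function name toy_run_sync is likewise made up for illustration.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for NEML2's Device type.
using ToyDeviceIndex = int;

// Minimal synchronous dispatch loop (illustrative only, not NEML2's implementation).
// Template parameters follow the order and meaning of WorkDispatcher<I, O, Of, Ip, Op>.
template <typename I, typename O, typename Of, typename Ip, typename Op>
Of toy_run_sync(const std::vector<Ip> & batches, // stand-in for WorkGenerator<Ip>
                ToyDeviceIndex device,
                const std::function<I(Ip &&, ToyDeviceIndex)> & preprocess,
                const std::function<O(I &&, ToyDeviceIndex)> & do_work,
                const std::function<Op(O &&)> & postprocess,
                const std::function<Of(std::vector<Op> &&)> & reduce)
{
  std::vector<Op> results;
  for (Ip batch : batches) // 1. work generation: take the next batch (copied)
  {
    I work = preprocess(std::move(batch), device);     // 2. preprocessing
    O result = do_work(std::move(work), device);       // 3. do work
    results.push_back(postprocess(std::move(result))); // 4. postprocessing
  }
  assert(results.size() == batches.size());
  return reduce(std::move(results)); // reduce only after all results are collected
}
```

For example, with batches {1, 2, 3}, preprocessing x + 1, work x * 10, identity postprocessing, and a summing reduction, the result is 20 + 30 + 40 = 90.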

Notes on threading: The dispatcher can run in synchronous or asynchronous mode.

  • In synchronous mode, the dispatcher runs in the main thread and dispatches work sequentially. No additional threads are created.
  • In asynchronous mode, the dispatcher creates a thread pool in which each thread continuously monitors the task queue. The main thread adds work to the task queue, and the threads in the pool pick up the tasks and execute them. The main thread waits for all the work to complete before reducing the results.
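The asynchronous mode follows the classic thread-pool pattern. The sketch below shows that pattern using only the standard library: worker threads block on a condition variable until a task is queued, and the main thread enqueues tasks and then waits until all of them have finished, at which point it could reduce the results. ToyThreadPool and its members are hypothetical names. For brevity the sketch uses a single shared queue and ignores devices, unlike the dispatcher, which keeps one queue per device.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

// Minimal thread pool with one shared task queue (hypothetical, for illustration).
class ToyThreadPool
{
public:
  explicit ToyThreadPool(std::size_t nthreads)
  {
    for (std::size_t i = 0; i < nthreads; ++i)
      _pool.emplace_back([this] { worker_main(); });
  }

  // Signal shutdown, let workers finish the remaining tasks, then join them.
  ~ToyThreadPool()
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      _stop = true;
    }
    _task_cv.notify_all();
    for (auto & t : _pool)
      t.join();
  }

  // Main thread: enqueue a task and wake one waiting worker.
  void submit(std::function<void()> task)
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      assert(!_stop && "submit() called after shutdown");
      _tasks.push(std::move(task));
      ++_pending;
    }
    _task_cv.notify_one();
  }

  // Main thread: block until every submitted task has finished.
  void wait_all()
  {
    std::unique_lock<std::mutex> lock(_mutex);
    _done_cv.wait(lock, [this] { return _pending == 0; });
  }

private:
  // Each worker loops: wait for a task (or shutdown), pop it, run it, report completion.
  void worker_main()
  {
    while (true)
    {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(_mutex);
        _task_cv.wait(lock, [this] { return _stop || !_tasks.empty(); });
        if (_tasks.empty()) // woken by shutdown with nothing left to do
          return;
        task = std::move(_tasks.front());
        _tasks.pop();
      }
      task(); // run outside the lock so other workers can pick up tasks meanwhile
      {
        std::lock_guard<std::mutex> lock(_mutex);
        --_pending;
      }
      _done_cv.notify_all();
    }
  }

  std::vector<std::thread> _pool;
  std::queue<std::function<void()>> _tasks;
  std::mutex _mutex;
  std::condition_variable _task_cv; // signals "a task is available or stop requested"
  std::condition_variable _done_cv; // signals "a task has finished"
  std::size_t _pending = 0;         // tasks submitted but not yet finished
  bool _stop = false;
};
```

Running each task outside the lock lets other workers dequeue tasks while one is executing; the pending counter is what tells the main thread that all work is done and reduction can start.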

Notes on coordination with the scheduler: The dispatcher communicates with the scheduler to schedule work and to notify the scheduler when work has been dispatched and completed.

  • In synchronous mode, the dispatcher does not notify the scheduler when work has been dispatched (since the work is dispatched sequentially).
  • In asynchronous mode, the dispatcher notifies the scheduler when work has been dispatched (i.e., when a task is added to the task queue). When a task is completed, the worker notifies the scheduler of the completion.
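One way to picture these notifications is as load bookkeeping on the scheduler side. In the sketch below, which is not the WorkScheduler API, schedule() picks the least-loaded device that still has spare capacity, dispatched() is what the main thread would call after adding a task to the queue, and completed() is what a worker would call when it finishes. The counters are mutex-protected because, in asynchronous mode, the two notifications arrive from different threads. ToyScheduler, the capacity limit, and the least-loaded policy are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <vector>

// Hypothetical scheduler bookkeeping: NOT the WorkScheduler API, only a sketch of
// the dispatched/completed notifications described above.
class ToyScheduler
{
public:
  ToyScheduler(std::size_t ndevices, std::size_t capacity)
    : _load(ndevices, 0), _capacity(capacity)
  {
  }

  // Pick the least-loaded device whose load is below capacity; false if all are full.
  bool schedule(std::size_t & device)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    bool found = false;
    for (std::size_t d = 0; d < _load.size(); ++d)
      if (_load[d] < _capacity && (!found || _load[d] < _load[device]))
      {
        device = d;
        found = true;
      }
    return found;
  }

  // Main thread (async mode): a batch of size n was added to the device's task queue.
  void dispatched(std::size_t device, std::size_t n)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    _load[device] += n;
  }

  // Worker thread (async mode): a batch of size n finished on the device.
  void completed(std::size_t device, std::size_t n)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    assert(_load[device] >= n && "completed more work than was dispatched");
    _load[device] -= n;
  }

  std::size_t load(std::size_t device)
  {
    std::lock_guard<std::mutex> lock(_mutex);
    return _load[device];
  }

private:
  std::vector<std::size_t> _load; // in-flight work per device
  std::size_t _capacity;
  std::mutex _mutex;
};
```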

Notes on thread-device binding: Currently, the implementation assumes that each thread in the thread pool is bound to one device. Under this assumption, dispatching work to a device is equivalent to dispatching work to a thread, which greatly simplifies the communication between the threads and the task queue. This assumption could be relaxed in the future if profiling shows that having multiple threads dispatch work to the same device is advantageous.
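Binding each thread to one device means the task queue can be keyed by device, as the type of the _tasks member (a map from Device to a queue of tasks) suggests. The following sketch illustrates that idea with standard-library types only: a hypothetical ToyDevice string stands in for Device, each thread first runs a per-device initialization callback (in the role of thread_init) and then serves only its own queue, and a stop flag lets the destructor drain the queues and join the threads. It is an illustration under these assumptions, not NEML2's thread_pool_main.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a device name such as "cpu" or "cuda:0".
using ToyDevice = std::string;

// Sketch: one thread per device, each serving only its own device's queue.
class ToyDevicePool
{
public:
  ToyDevicePool(const std::vector<ToyDevice> & devices,
                std::function<void(const ToyDevice &)> thread_init)
  {
    for (const auto & d : devices)
      _tasks[d]; // create an empty queue per device before any thread starts
    for (const auto & d : devices)
      _threads.emplace_back([this, d, thread_init] {
        thread_init(d); // per-thread setup, executed on the bound thread
        serve(d);
      });
  }

  // Request stop; each thread drains its own queue and exits, then we join.
  ~ToyDevicePool()
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      _stop = true;
    }
    _cv.notify_all();
    for (auto & t : _threads)
      t.join();
  }

  // Dispatching to a device == pushing onto the queue that only its thread reads.
  void dispatch(const ToyDevice & device, std::function<void()> task)
  {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      assert(_tasks.count(device) == 1 && "unknown device");
      _tasks.at(device).push(std::move(task));
    }
    _cv.notify_all(); // shared condition variable: wake all, only one finds work
  }

private:
  void serve(const ToyDevice & device)
  {
    while (true)
    {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(_mutex);
        auto & q = _tasks.at(device); // map is never modified after construction
        _cv.wait(lock, [&] { return _stop || !q.empty(); });
        if (q.empty()) // only reachable once stop has been requested
          return;
        task = std::move(q.front());
        q.pop();
      }
      task();
    }
  }

  std::unordered_map<ToyDevice, std::queue<std::function<void()>>> _tasks;
  std::vector<std::thread> _threads;
  std::mutex _mutex;
  std::condition_variable _cv;
  bool _stop = false;
};
```

Because all threads share one condition variable, dispatch() wakes every thread; each rechecks only its own queue, so threads bound to other devices simply go back to waiting.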

Template Parameters
  I    Input type of the work consumed by the worker (after preprocessing)
  O    Output type of the result returned by the worker
  Of   Output type of the final result (after reduction)
  Ip   Input type of the work produced by the generator (before preprocessing)
  Op   Output type of the result after postprocessing
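The default template arguments determine how types flow through the pipeline: preprocessing turns Ip into I, the worker turns I into O, postprocessing turns O into Op, and the reduction turns a vector of Op into Of. The sketch below mirrors only the template parameter list to show which types result when just I and O are specified. ToyDispatcherTypes and toy_type_identity are hypothetical names, and int stands in for Device.

```cpp
#include <functional>
#include <string>
#include <type_traits>
#include <vector>

// Plays the role of NEML2's type_identity (same idea as C++20 std::type_identity).
template <typename T>
struct toy_type_identity
{
  using type = T;
};

// Mirrors WorkDispatcher's template parameter list; holds no state or behavior.
template <typename I,
          typename O,
          typename Of = std::vector<O>,
          typename Ip = typename toy_type_identity<I>::type,
          typename Op = typename toy_type_identity<O>::type>
struct ToyDispatcherTypes
{
  // Ip --preprocess--> I --do_work--> O --postprocess--> Op --reduce--> Of
  using Preprocess = std::function<I(Ip &&, int /* device stand-in */)>;
  using DoWork = std::function<O(I &&, int /* device stand-in */)>;
  using Postprocess = std::function<Op(O &&)>;
  using Reduce = std::function<Of(std::vector<Op> &&)>;
  using Raw = Ip;
  using Post = Op;
  using Final = Of;
};

// Only I and O given: Ip defaults to I, Op to O, and Of to std::vector<O>.
using Defaults = ToyDispatcherTypes<int, double>;
static_assert(std::is_same_v<Defaults::Raw, int>);
static_assert(std::is_same_v<Defaults::Post, double>);
static_assert(std::is_same_v<Defaults::Final, std::vector<double>>);

// All five given: e.g. string batches parsed into ints, reduced into a single double.
using Custom = ToyDispatcherTypes<int, double, double, std::string, double>;
static_assert(std::is_same_v<Custom::Preprocess, std::function<int(std::string &&, int)>>);
```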

#include <WorkDispatcher.h>

Public Member Functions

 WorkDispatcher (WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
 
 WorkDispatcher (WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
 
 WorkDispatcher (WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
 
 WorkDispatcher (WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
 
 WorkDispatcher ()=delete
 
 WorkDispatcher (WorkDispatcher &&)=delete
 
 WorkDispatcher (const WorkDispatcher &)=delete
 
WorkDispatcher & operator= (WorkDispatcher &&)=delete
 
WorkDispatcher & operator= (const WorkDispatcher &)=delete
 
 ~WorkDispatcher ()
 
Of run (WorkGenerator< Ip > &)
 Run the dispatching loop (calls run_sync or run_async based on the async flag).
 

Protected Member Functions

void init_thread_pool ()
 Initialize the thread pool.
 
void thread_pool_main (const Device &)
 Thread pool main function.
 
bool should_unlock_thread ()
 Check whether a waiting thread should be unlocked.
 
void stop_thread_pool ()
 Stop the thread pool.
 
void validate () const
 Helper function to validate that the dispatcher is properly configured.
 
Of run_sync (WorkGenerator< Ip > &)
 Run the dispatching loop synchronously.
 
Of run_async (WorkGenerator< Ip > &)
 Run the dispatching loop asynchronously.
 

Protected Attributes

WorkScheduler & _scheduler
 Reference to the work scheduler.
 
const std::vector< Device > _devices
 Device pool requested by the scheduler.
 
const bool _async
 Flag to enable asynchronous execution.
 
std::function< O(I &&, Device)> _do_work
 Function to perform the work and return the result.
 
std::function< Of(std::vector< Op > &&)> _reduce
 Function to reduce the results.
 
std::function< I(Ip &&, Device)> _preprocess
 Function to preprocess the work.
 
std::function< Op(O &&)> _postprocess
 Function to postprocess the result.
 
std::function< void(Device)> _thread_init
 Function to initialize the thread.
 
std::vector< Op > _results
 Results to be reduced.
 
std::mutex _qmutex
 
std::condition_variable _thread_condition
 Condition variable for the tasks queue.
 
bool _stop = false
 Flag to stop the thread pool.
 
std::vector< std::thread > _thread_pool
 
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
 Task queue for the thread pool.
 

Constructor & Destructor Documentation

◆ WorkDispatcher() [1/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( WorkScheduler & scheduler,
bool async,
std::function< O(I &&, Device)> && do_work )
inline

◆ WorkDispatcher() [2/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( WorkScheduler & scheduler,
bool async,
std::function< O(I &&, Device)> && do_work,
std::function< O(std::vector< O > &&)> && reduce )
inline

◆ WorkDispatcher() [3/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( WorkScheduler & scheduler,
bool async,
std::function< O(I &&, Device)> && do_work,
std::function< Of(std::vector< Op > &&)> && reduce,
std::function< I(Ip &&, Device)> && preprocess,
std::function< Op(O &&)> && postprocess )
inline

◆ WorkDispatcher() [4/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( WorkScheduler & scheduler,
bool async,
std::function< O(I &&, Device)> && do_work,
std::function< Of(std::vector< Op > &&)> && reduce,
std::function< I(Ip &&, Device)> && preprocess,
std::function< Op(O &&)> && postprocess,
std::function< void(Device)> && thread_init )
inline

◆ WorkDispatcher() [5/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( )
delete

◆ WorkDispatcher() [6/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( WorkDispatcher< I, O, Of, Ip, Op > && )
delete

◆ WorkDispatcher() [7/7]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher ( const WorkDispatcher< I, O, Of, Ip, Op > & )
delete

◆ ~WorkDispatcher()

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
~WorkDispatcher ( )
inline

Member Function Documentation

◆ init_thread_pool()

template<typename I, typename O, typename Of, typename Ip, typename Op>
void init_thread_pool ( )
protected

Initialize the thread pool.

◆ operator=() [1/2]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher & operator= ( const WorkDispatcher< I, O, Of, Ip, Op > & )
delete

◆ operator=() [2/2]

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkDispatcher & operator= ( WorkDispatcher< I, O, Of, Ip, Op > && )
delete

◆ run()

template<typename I, typename O, typename Of, typename Ip, typename Op>
Of run ( WorkGenerator< Ip > & generator)

Run the dispatching loop (calls run_sync or run_async based on the async flag).
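The branch that run() performs on the async flag can be sketched as follows. ToyRunner is hypothetical, batches are plain ints, and the asynchronous path launches one std::async task per batch for brevity, whereas WorkDispatcher keeps a persistent thread pool. Both paths collect every result before reducing, so they return the same value.

```cpp
#include <cassert>
#include <functional>
#include <future>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical runner that branches on an async flag, in the spirit of run().
class ToyRunner
{
public:
  ToyRunner(bool async, std::function<int(int)> do_work)
    : _async(async), _do_work(std::move(do_work))
  {
    assert(_do_work && "do_work must be provided"); // basic configuration check
  }

  int run(const std::vector<int> & batches) const
  {
    return _async ? run_async(batches) : run_sync(batches);
  }

private:
  // Sequential: each batch is processed on the calling thread.
  int run_sync(const std::vector<int> & batches) const
  {
    std::vector<int> results;
    for (int b : batches)
      results.push_back(_do_work(b));
    return reduce(results);
  }

  // Concurrent: one std::async task per batch (a simplification of a thread pool).
  int run_async(const std::vector<int> & batches) const
  {
    std::vector<std::future<int>> futures;
    for (int b : batches)
      futures.push_back(std::async(std::launch::async, _do_work, b));
    std::vector<int> results;
    for (auto & f : futures)
      results.push_back(f.get()); // wait for every task before reducing
    return reduce(results);
  }

  // Example reduction: sum of all per-batch results.
  static int reduce(const std::vector<int> & results)
  {
    return std::accumulate(results.begin(), results.end(), 0);
  }

  const bool _async;
  std::function<int(int)> _do_work;
};
```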

◆ run_async()

template<typename I, typename O, typename Of, typename Ip, typename Op>
Of run_async ( WorkGenerator< Ip > & generator)
protected

Run the dispatching loop asynchronously.

◆ run_sync()

template<typename I, typename O, typename Of, typename Ip, typename Op>
Of run_sync ( WorkGenerator< Ip > & generator)
protected

Run the dispatching loop synchronously.

◆ should_unlock_thread()

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
bool should_unlock_thread ( )
protected

Check whether a waiting thread should be unlocked.

◆ stop_thread_pool()

template<typename I, typename O, typename Of, typename Ip, typename Op>
void stop_thread_pool ( )
protected

Stop the thread pool.

◆ thread_pool_main()

template<typename I, typename O, typename Of, typename Ip, typename Op>
void thread_pool_main ( const Device & device)
protected

Thread pool main function.

◆ validate()

template<typename I, typename O, typename Of, typename Ip, typename Op>
void validate ( ) const
protected

Helper function to validate that the dispatcher is properly configured.

Member Data Documentation

◆ _async

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
const bool _async
protected

Flag to enable asynchronous execution.

◆ _devices

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
const std::vector<Device> _devices
protected

Device pool requested by the scheduler.

◆ _do_work

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::function<O(I &&, Device)> _do_work
protected

Function to perform the work and return the result.

◆ _postprocess

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::function<Op(O &&)> _postprocess
protected

Function to postprocess the result.

◆ _preprocess

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::function<I(Ip &&, Device)> _preprocess
protected

Function to preprocess the work.

◆ _qmutex

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::mutex _qmutex
protected

Mutex used by the thread pool when picking up tasks from the task queue.

◆ _reduce

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::function<Of(std::vector<Op> &&)> _reduce
protected

Function to reduce the results.

◆ _results

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::vector<Op> _results
protected

Results to be reduced.

◆ _scheduler

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
WorkScheduler& _scheduler
protected

Reference to the work scheduler.

◆ _stop

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
bool _stop = false
protected

Flag to stop the thread pool.

◆ _tasks

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::unordered_map<Device, std::queue<std::function<void()> > > _tasks
protected

Task queue for the thread pool.

◆ _thread_condition

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::condition_variable _thread_condition
protected

Condition variable for the tasks queue.

◆ _thread_init

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::function<void(Device)> _thread_init
protected

Function to initialize the thread.

◆ _thread_pool

template<typename I, typename O, typename Of = typename std::vector<O>, typename Ip = typename type_identity<I>::type, typename Op = typename type_identity<O>::type>
std::vector<std::thread> _thread_pool
protected

Thread pool for asynchronous execution.

TODO: We currently assume that each thread is responsible for one device. This may not be true or optimal in the future.