NEML2 2.0.0
WorkDispatcher.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include <functional>
#include <future>
#include <thread>
#include <queue>
#include <mutex>
#include <condition_variable>

#include "neml2/dispatchers/WorkGenerator.h"
#include "neml2/dispatchers/WorkScheduler.h"
#include "neml2/misc/assertions.h"
#include "neml2/misc/types.h"

#include "neml2/tensors/R2.h"

// Pre-C++20 workaround for std::type_identity
// https://en.cppreference.com/w/cpp/types/type_identity
template <class T>
struct type_identity
{
  using type = T;
};
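
// For reference, `typename type_identity<T>::type` simply names T, and because the nested
// ::type is a non-deduced context it is the usual way to exclude a parameter from template
// argument deduction (the role std::type_identity plays in C++20). A minimal sanity check:
//
//   static_assert(std::is_same_v<typename type_identity<int>::type, int>);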

namespace neml2
{
template <typename I,
          typename O,
          typename Of = typename std::vector<O>,
          typename Ip = typename type_identity<I>::type,
          typename Op = typename type_identity<O>::type>
class WorkDispatcher
{
public:
  WorkDispatcher(WorkScheduler & scheduler, bool async, std::function<O(I &&, Device)> && do_work)
    : _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<O(std::vector<O> &&)> && reduce)
    : _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<Of(std::vector<Op> &&)> && reduce,
                 std::function<I(Ip &&, Device)> && preprocess,
                 std::function<Op(O &&)> && postprocess)
    : _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce)),
      _preprocess(std::move(preprocess)),
      _postprocess(std::move(postprocess))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<Of(std::vector<Op> &&)> && reduce,
                 std::function<I(Ip &&, Device)> && preprocess,
                 std::function<Op(O &&)> && postprocess,
                 std::function<void(Device)> && thread_init)
    : _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce)),
      _preprocess(std::move(preprocess)),
      _postprocess(std::move(postprocess)),
      _thread_init(std::move(thread_init))
  {
    if (_thread_init)
      neml_assert_dbg(
          _async, "Custom thread initialization functor is only supported in asynchronous mode");

    init_thread_pool();
  }

  WorkDispatcher() = delete;
  WorkDispatcher(const WorkDispatcher &) = delete;
  WorkDispatcher(WorkDispatcher &&) = delete;
  WorkDispatcher & operator=(const WorkDispatcher &) = delete;
  WorkDispatcher & operator=(WorkDispatcher &&) = delete;
  ~WorkDispatcher() { stop_thread_pool(); }

  /// Run the dispatching loop (calls run_sync or run_async based on the async flag)
  Of run(WorkGenerator<Ip> &);
protected:
  /// Initialize the thread pool
  void init_thread_pool();

  /// Thread pool main function
  void thread_pool_main(const Device &);

  /// Run the dispatching loop synchronously
  Of run_sync(WorkGenerator<Ip> &);

  /// Run the dispatching loop asynchronously
  Of run_async(WorkGenerator<Ip> &);

  /// Stop the thread pool
  void stop_thread_pool();

  /// Helper function to validate that the dispatcher is properly configured
  void validate() const;

  /// Reference to the work scheduler
  WorkScheduler & _scheduler;

  /// Device pool requested by the scheduler
  const std::vector<Device> _devices;

  /// Flag to enable asynchronous execution
  const bool _async;

  /// Function to perform the work and return the result
  std::function<O(I &&, Device)> _do_work;

  /// Function to reduce the results
  std::function<Of(std::vector<Op> &&)> _reduce;

  /// Function to preprocess the work
  std::function<I(Ip &&, Device)> _preprocess;

  /// Function to postprocess the result
  std::function<Op(O &&)> _postprocess;

  /// Function to initialize the thread
  std::function<void(Device)> _thread_init;

  /// Results to be reduced
  std::vector<Op> _results;

  std::mutex _qmutex;
  /// Condition variable for the tasks queue
  std::condition_variable _thread_condition;
  /// Flag to stop the thread pool
  bool _stop = false;

  // As it turns out it's undefined behavior to initialize this before the mutex and condition
  // variable
  std::vector<std::thread> _thread_pool;
  /// Task queue for the thread pool
  std::unordered_map<Device, std::queue<std::function<void()>>> _tasks;
};

// Implementation

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::init_thread_pool()
{
  if (!_async)
    return;

  // Setup the task queue
  for (const auto & device : _devices)
    _tasks[device] = std::queue<std::function<void()>>();

  auto nthread = _devices.size();
  _thread_pool.reserve(nthread);
  for (std::size_t i = 0; i < nthread; ++i)
  {
    // This is necessary to initialize the torch linear algebra library prior to threaded calls
    // See: https://github.com/pytorch/pytorch/issues/90613
    auto res = R2::identity().to(_devices[i]).inverse();
    _thread_pool.emplace_back([this, i] { thread_pool_main(_devices[i]); });
  }

  // Initialize the thread
  if (_thread_init)
  {
    for (std::size_t i = 0; i < nthread; ++i)
    {
      auto device = _devices[i];
      auto task = [this, device = device]() mutable
      {
        _thread_init(device);
        _scheduler.completed_work(device, 1);
      };
      _scheduler.dispatched_work(device, 1);
      {
        std::lock_guard<std::mutex> lock(_qmutex);
        _tasks.at(device).push(task);
      }
      _thread_condition.notify_all();
    }
    _scheduler.wait_for_completion();
  }
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::thread_pool_main(const Device & device)
{
  while (true)
  {
    std::function<void()> task;
    {
      std::unique_lock<std::mutex> lock(_qmutex);
      _thread_condition.wait(lock, [this, &device] { return _stop || !_tasks.at(device).empty(); });
      if (_stop && _tasks.at(device).empty())
        break;
      task = std::move(_tasks.at(device).front());
      _tasks.at(device).pop();
    }
    task();
  }
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::stop_thread_pool()
{
  if (!_async)
    return;
  {
    std::unique_lock<std::mutex> lock(_qmutex);
    _stop = true;
  }
  _thread_condition.notify_all();
  for (auto & thread : _thread_pool)
    thread.join();
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run(WorkGenerator<Ip> & generator)
{
  if (_async)
    return run_async(generator);
  return run_sync(generator);
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::validate() const
{
  if (!_do_work)
    throw NEMLException("Do-work function is not set");

  if constexpr (!std::is_same_v<I, Ip>)
    if (!_preprocess)
      throw NEMLException("Preprocess function is not set");

  if constexpr (!std::is_same_v<O, Op>)
    if (!_postprocess)
      throw NEMLException("Postprocess function is not set");

  if constexpr (!std::is_same_v<Of, std::vector<Op>>)
    if (!_reduce)
      throw NEMLException("Reduce function is not set");
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run_sync(WorkGenerator<Ip> & generator)
{
  validate();

  Device device = kCPU;
  std::size_t n = 0;
  _results.clear();
  while (generator.has_more())
  {
    _scheduler.schedule_work(device, n);
    if (n <= 0)
      throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
    // Generate work
    auto && [m, work] = generator.next(n);
    // Preprocess
    if (_preprocess)
      work = _preprocess(std::move(work), device);
    // Do work. Since there is no asynchronous execution, we do not notify the scheduler (this also
    // avoids potential parallel communication incurred by the scheduler)
    auto result = _do_work(std::move(work), device);
    // Postprocess
    if (_postprocess)
      result = _postprocess(std::move(result));
    _results.push_back(result);
  }

  if (_reduce)
    return _reduce(std::move(_results));

  if constexpr (std::is_same<Of, std::vector<Op>>::value)
    return _results;

  throw NEMLException("Internal error: unreachable code");
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run_async(WorkGenerator<Ip> & generator)
{
  validate();

  Device device = kCPU;
  std::size_t n = 0;
  _results.clear();

  // Keep asking the scheduler for an available device
  // - If the generator has no more work, we break out of the loop
  // - If the scheduler schedules work, we dispatch the work and continue with the dispatching loop
  while (generator.has_more())
  {
    _scheduler.schedule_work(device, n);
    if (n <= 0)
      throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
    // Generate work
    auto && [m, work] = generator.next(n);
    // Reserve space for the result
    _results.resize(_results.size() + 1);
    auto i = _results.size() - 1;
    // Create the task
    auto task = [this, work = std::move(work), device = device, m = m, i = i]() mutable
    {
      // Preprocess
      if (_preprocess)
        work = _preprocess(std::move(work), device);
      // Do work
      auto result = _do_work(std::move(work), device);
      // Postprocess
      if (_postprocess)
        result = _postprocess(std::move(result));
      // Collect result
      _results[i] = std::move(result);
      // Tell the scheduler that we have completed m batches
      _scheduler.completed_work(device, m);
    };
    // Tell the scheduler that we have dispatched m batches
    _scheduler.dispatched_work(device, m);
    // Enqueue the task
    {
      std::lock_guard<std::mutex> lock(_qmutex);
      _tasks.at(device).push(task);
    }
    // Notify the thread pool
    // Note: We notify_all instead of notify_one because we want the thread that's bound to the
    // target device to pick up the task
    _thread_condition.notify_all();
  }

  // Wait for all tasks to complete
  _scheduler.wait_for_completion();

  if (_reduce)
    return _reduce(std::move(_results));

  if constexpr (std::is_same<Of, std::vector<Op>>::value)
    return _results;

  throw NEMLException("Internal error: unreachable code");
}
} // namespace neml2
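
A minimal usage sketch, assuming a concrete WorkScheduler and a WorkGenerator<Tensor> have been
set up elsewhere by the application; the lambda body is a placeholder for the real per-batch
work, and Of is left at its default of std::vector<O>, so no reduce functor is required:

#include "neml2/dispatchers/WorkDispatcher.h"
#include "neml2/tensors/Tensor.h"

using namespace neml2;

std::vector<Tensor>
dispatch_example(WorkScheduler & scheduler, WorkGenerator<Tensor> & generator)
{
  // I = O = Tensor; asynchronous dispatch over the scheduler's device pool
  WorkDispatcher<Tensor, Tensor> dispatcher(
      scheduler,
      /*async=*/true,
      [](Tensor && work, Device device)
      {
        // ... evaluate one batch of work on the assigned device ...
        return work.to(device);
      });

  // Each entry of the returned vector is the result of one dispatched batch
  return dispatcher.run(generator);
}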