NEML2 2.0.0
WorkDispatcher.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <ATen/ops/linalg_inv.h>
28#include <ATen/ops/ones.h>
29#include <functional>
30#include <thread>
31#include <queue>
32#include <mutex>
33#include <condition_variable>
34
35#include <ATen/Parallel.h>
36
37#include "neml2/config.h"
38#include "neml2/dispatchers/WorkGenerator.h"
39#include "neml2/dispatchers/WorkScheduler.h"
40#include "neml2/misc/assertions.h"
41#include "neml2/misc/types.h"
42#include "neml2/base/TracingInterface.h"
43
44// Pre-C++20 workaround for std::type_identity
45// https://en.cppreference.com/w/cpp/types/type_identity
46template <class T>
47struct type_identity
48{
49 using type = T;
50};
51
52namespace neml2
53{
103template <typename I,
104 typename O,
105 typename Of = typename std::vector<O>,
106 typename Ip = typename type_identity<I>::type,
107 typename Op = typename type_identity<O>::type>
108class WorkDispatcher : public TracingInterface
109{
110public:
111 WorkDispatcher(WorkScheduler & scheduler, bool async, std::function<O(I &&, Device)> && do_work)
112 : TracingInterface(scheduler),
113 _scheduler(scheduler),
114 _devices(scheduler.devices()),
115 _async(async),
116 _do_work(std::move(do_work))
117 {
118 init_thread_pool();
119 }
120
121 WorkDispatcher(WorkScheduler & scheduler,
122 bool async,
123 std::function<O(I &&, Device)> && do_work,
124 std::function<O(std::vector<O> &&)> && reduce)
125 : TracingInterface(scheduler),
126 _scheduler(scheduler),
127 _devices(scheduler.devices()),
128 _async(async),
129 _do_work(std::move(do_work)),
130 _reduce(std::move(reduce))
131 {
132 init_thread_pool();
133 }
134
135 WorkDispatcher(WorkScheduler & scheduler,
136 bool async,
137 std::function<O(I &&, Device)> && do_work,
138 std::function<Of(std::vector<Op> &&)> && reduce,
139 std::function<I(Ip &&, Device)> && preprocess,
140 std::function<Op(O &&)> && postprocess)
141 : TracingInterface(scheduler),
142 _scheduler(scheduler),
143 _devices(scheduler.devices()),
144 _async(async),
145 _do_work(std::move(do_work)),
146 _reduce(std::move(reduce)),
147 _preprocess(std::move(preprocess)),
148 _postprocess(std::move(postprocess))
149 {
150 init_thread_pool();
151 }
152
153 WorkDispatcher(WorkScheduler & scheduler,
154 bool async,
155 std::function<O(I &&, Device)> && do_work,
156 std::function<Of(std::vector<Op> &&)> && reduce,
157 std::function<I(Ip &&, Device)> && preprocess,
158 std::function<Op(O &&)> && postprocess,
159 std::function<void(Device)> && thread_init)
160 : TracingInterface(scheduler),
161 _scheduler(scheduler),
162 _devices(scheduler.devices()),
163 _async(async),
164 _do_work(std::move(do_work)),
165 _reduce(std::move(reduce)),
166 _preprocess(std::move(preprocess)),
167 _postprocess(std::move(postprocess)),
168 _thread_init(std::move(thread_init))
169 {
170 if (_thread_init)
171 neml_assert_dbg(
172 _async, "Custom thread initialization functor is only supported in asynchronous mode");
173
174 init_thread_pool();
175 }
176
177 WorkDispatcher() = delete;
178 WorkDispatcher(const WorkDispatcher &) = delete;
179 WorkDispatcher(WorkDispatcher &&) = delete;
180 WorkDispatcher & operator=(const WorkDispatcher &) = delete;
181 WorkDispatcher & operator=(WorkDispatcher &&) = delete;
182 ~WorkDispatcher() override { stop_thread_pool(); }
183
184 /// Run the dispatching loop (calls run_sync or run_async based on the async flag)
185 Of run(WorkGenerator<Ip> &);
186
187protected:
188 /// Initialize the thread pool.
189 void init_thread_pool();
190
191 /// Thread pool main function.
192 void thread_pool_main(const Device &);
193
196
197 /// Stop the thread pool.
198 void stop_thread_pool();
199
200 /// Helper function to validate that the dispatcher is properly configured.
201 void validate() const;
202
203 /// Run the dispatching loop synchronously.
204 Of run_sync(WorkGenerator<Ip> &);
205
206 /// Run the dispatching loop asynchronously.
207 Of run_async(WorkGenerator<Ip> &);
208
209 /// Reference to the work scheduler.
210 WorkScheduler & _scheduler;
211
212 /// Device pool requested by the scheduler.
213 const std::vector<Device> _devices;
214
215 /// Flag to enable asynchronous execution.
216 const bool _async;
217
218 /// Function to perform the work and return the result.
219 std::function<O(I &&, Device)> _do_work;
220
221 /// Function to reduce the results.
222 std::function<Of(std::vector<Op> &&)> _reduce;
223
224 /// Function to preprocess the work.
225 std::function<I(Ip &&, Device)> _preprocess;
226
227 /// Function to postprocess the result.
228 std::function<Op(O &&)> _postprocess;
229
230 /// Function to initialize the thread.
231 std::function<void(Device)> _thread_init;
232
233 /// Results to be reduced.
234 std::vector<Op> _results;
235
238 std::mutex _qmutex;
239 /// Condition variable for the tasks queue.
240 std::condition_variable _thread_condition;
241 /// Flag to stop the thread pool.
242 bool _stop = false;
246 // As it turns out it's undefined behavior to initialize this before the mutex and condition
247 // variable
248 std::vector<std::thread> _thread_pool;
249 /// Task queue for the thread pool.
250 std::unordered_map<Device, std::queue<std::function<void()>>> _tasks;
252};
253
255// Implementation
257template <typename I, typename O, typename Of, typename Ip, typename Op>
258void
259WorkDispatcher<I, O, Of, Ip, Op>::init_thread_pool()
260{
261 if (!_async)
262 return;
263
264#ifdef NEML2_JSON
265 if (event_tracing_enabled())
266 event_trace_writer().trace_duration_begin("thread pool", "WorkDispatcher");
267#endif
268
269 // Set up the task queue for each device
270 for (const auto & device : _devices)
271 _tasks[device] = std::queue<std::function<void()>>();
272
273 auto nthread = _devices.size();
274 _thread_pool.reserve(nthread);
275 for (std::size_t i = 0; i < nthread; ++i)
276 {
277 // This is necessary to initialize the torch linear algebra library prior to threaded calls
278 // See: https://github.com/pytorch/pytorch/issues/90613
279 auto res = at::linalg_inv(at::ones({1, 1}));
280 _thread_pool.emplace_back([this, i] { thread_pool_main(_devices[i]); });
281 }
282
283 // Initialize each worker thread using the custom thread initialization functor
284 if (_thread_init)
285 {
286 for (std::size_t i = 0; i < nthread; ++i)
287 {
288 auto device = _devices[i];
289 auto task = [this, device = device]() mutable
290 {
291#ifdef NEML2_JSON
292 if (event_tracing_enabled())
293 event_trace_writer().trace_duration_begin(
294 "thread init", "WorkDispatcher", {{"device", utils::stringify(device)}});
295#endif
296
297 _thread_init(device);
298 _scheduler.completed_work(device, 1);
299
300#ifdef NEML2_JSON
301 if (event_tracing_enabled())
302 event_trace_writer().trace_duration_end("thread init", "WorkDispatcher");
303#endif
304 };
305 _scheduler.dispatched_work(device, 1);
306 {
307 std::lock_guard<std::mutex> lock(_qmutex);
308 _tasks.at(device).push(task);
309 }
310 _thread_condition.notify_all();
311 }
312 _scheduler.wait_for_completion();
313 }
314}
315
316template <typename I, typename O, typename Of, typename Ip, typename Op>
317void
318WorkDispatcher<I, O, Of, Ip, Op>::thread_pool_main(const Device & device)
319{
320#ifdef NEML2_JSON
321 if (event_tracing_enabled())
322 event_trace_writer().trace_duration_begin(
323 "thread main", "WorkDispatcher", {{"device", utils::stringify(device)}});
324#endif
325
326 while (true)
327 {
328 std::function<void()> task;
329 {
330 std::unique_lock<std::mutex> lock(_qmutex);
331 _thread_condition.wait(lock, [this, &device] { return _stop || !_tasks.at(device).empty(); });
332 if (_stop && _tasks.at(device).empty())
333 break;
334 task = std::move(_tasks.at(device).front());
335 _tasks.at(device).pop();
336 }
337 task();
338 }
339
340#ifdef NEML2_JSON
341 if (event_tracing_enabled())
342 event_trace_writer().trace_duration_end("thread main", "WorkDispatcher");
343#endif
344}
345
346template <typename I, typename O, typename Of, typename Ip, typename Op>
347void
348WorkDispatcher<I, O, Of, Ip, Op>::stop_thread_pool()
349{
350 if (!_async)
351 return;
352 {
353 std::unique_lock<std::mutex> lock(_qmutex);
354 _stop = true;
355 }
356 _thread_condition.notify_all();
357 for (auto & thread : _thread_pool)
358 thread.join();
359
360#ifdef NEML2_JSON
361 if (event_tracing_enabled())
362 event_trace_writer().trace_duration_end("thread pool", "WorkDispatcher");
363#endif
364}
365
366template <typename I, typename O, typename Of, typename Ip, typename Op>
367Of
368WorkDispatcher<I, O, Of, Ip, Op>::run(WorkGenerator<Ip> & generator)
369{
370#ifdef NEML2_JSON
371 if (event_tracing_enabled())
372 event_trace_writer().trace_duration_begin("run", "WorkDispatcher");
373#endif
374
375 if (_async)
376 {
377 auto result = run_async(generator);
378
379#ifdef NEML2_JSON
380 if (event_tracing_enabled())
381 event_trace_writer().trace_duration_end("run", "WorkDispatcher");
382#endif
383
384 return result;
385 }
386
387 auto result = run_sync(generator);
388
389#ifdef NEML2_JSON
390 if (event_tracing_enabled())
391 event_trace_writer().trace_duration_end("run", "WorkDispatcher");
392#endif
393
394 return result;
395}
396
397template <typename I, typename O, typename Of, typename Ip, typename Op>
398void
399WorkDispatcher<I, O, Of, Ip, Op>::validate() const
400{
401 if (!_do_work)
402 throw NEMLException("Do-work function is not set");
403
404 if constexpr (!std::is_same_v<I, Ip>)
405 if (!_preprocess)
406 throw NEMLException("Preprocess function is not set");
407
408 if constexpr (!std::is_same_v<O, Op>)
409 if (!_postprocess)
410 throw NEMLException("Postprocess function is not set");
411
412 if constexpr (!std::is_same_v<Of, std::vector<Op>>)
413 if (!_reduce)
414 throw NEMLException("Reduce function is not set");
415}
416
417template <typename I, typename O, typename Of, typename Ip, typename Op>
418Of
419WorkDispatcher<I, O, Of, Ip, Op>::run_sync(WorkGenerator<Ip> & generator)
420{
421 validate();
422
423 Device device = kCPU;
424 std::size_t n = 0;
425 _results.clear();
426 while (generator.has_more())
427 {
428 _scheduler.schedule_work(device, n);
429 if (n <= 0)
430 throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
431 // Generate work
432 auto && [m, work] = generator.next(n);
433 // Preprocess
434 if (_preprocess)
435 work = _preprocess(std::move(work), device);
436 // Do work. Since there is no asynchronous execution, we do not notify the scheduler (this also
437 // avoids potential parallel communication incurred by the scheduler)
438 auto result = _do_work(std::move(work), device);
439 // Postprocess
440 if (_postprocess)
441 result = _postprocess(std::move(result));
442 _results.push_back(result);
443 }
444
445 if (_reduce)
446 return _reduce(std::move(_results));
447
448 if constexpr (std::is_same<Of, std::vector<Op>>::value)
449 return _results;
450
451 throw NEMLException("Internal error: unreachable code");
452}
453
454template <typename I, typename O, typename Of, typename Ip, typename Op>
455Of
456WorkDispatcher<I, O, Of, Ip, Op>::run_async(WorkGenerator<Ip> & generator)
457{
458 validate();
459
460 Device device = kCPU;
461 std::size_t n = 0;
462 _results.clear();
463
464 // Keep asking the scheduler for an available device
465 // - If the generator has no more work, we break out of the loop
466 // - If the scheduler schedules work, we dispatch the work and continue with the dispatching loop
467 while (generator.has_more())
468 {
469 _scheduler.schedule_work(device, n);
470 if (n <= 0)
471 throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
472 // Generate work
473 auto && [m, work] = generator.next(n);
474 // Reserve space for the result
475 _results.resize(_results.size() + 1);
476 auto i = _results.size() - 1;
477 // Create the task
478 auto task = [this, work = std::move(work), device = device, m = m, i = i]() mutable
479 {
480#ifdef NEML2_JSON
481 if (event_tracing_enabled())
482 event_trace_writer().trace_duration_begin(
483 "task", "WorkDispatcher", {{"device", utils::stringify(device)}, {"batch size", m}});
484#endif
485 // Preprocess
486 if (_preprocess)
487 work = _preprocess(std::move(work), device);
488 // Do work
489 auto result = _do_work(std::move(work), device);
490 // Postprocess
491 if (_postprocess)
492 result = _postprocess(std::move(result));
493 // Collect result
494 _results[i] = std::move(result);
495 // Tell the scheduler that we have completed m batches
496 _scheduler.completed_work(device, m);
497
498#ifdef NEML2_JSON
499 if (event_tracing_enabled())
500 event_trace_writer().trace_duration_end("task", "WorkDispatcher");
501#endif
502 };
503 // Tell the scheduler that we have dispatched m batches
504 _scheduler.dispatched_work(device, m);
505 // Enqueue the task
506 {
507 std::lock_guard<std::mutex> lock(_qmutex);
508 _tasks.at(device).push(task);
509 }
510 // Notify the thread pool
511 // Note: We notify_all instead of notify_one because we want the thread that's bound to the
512 // target device to pick up the task
513 _thread_condition.notify_all();
514 }
515
516 // Wait for all tasks to complete
517 _scheduler.wait_for_completion();
518
519 if (_reduce)
520 return _reduce(std::move(_results));
521
522 if constexpr (std::is_same<Of, std::vector<Op>>::value)
523 return _results;
524
525 throw NEMLException("Internal error: unreachable code");
526}
527} // namespace neml2
NEMLException
Definition errors.h:34
TracingInterface
Definition TracingInterface.h:114
The work dispatcher that dispatches work to worker threads and reduces the results.
Definition WorkDispatcher.h:109
const std::vector< Device > _devices
Device pool requested by the scheduler.
Definition WorkDispatcher.h:213
std::vector< std::thread > _thread_pool
Definition WorkDispatcher.h:248
const bool _async
Flag to enable asynchronous execution.
Definition WorkDispatcher.h:216
bool _stop
Flag to stop the thread pool.
Definition WorkDispatcher.h:242
std::vector< Op > _results
Results to be reduced.
Definition WorkDispatcher.h:234
std::function< void(Device)> _thread_init
Function to initialize the thread.
Definition WorkDispatcher.h:231
void init_thread_pool()
Initialize the thread pool.
Definition WorkDispatcher.h:259
WorkDispatcher & operator=(const WorkDispatcher &)=delete
Of run_async(WorkGenerator< Ip > &)
Run the dispatching loop asynchronously.
Definition WorkDispatcher.h:456
WorkDispatcher(const WorkDispatcher &)=delete
WorkScheduler & _scheduler
Reference to the work scheduler.
Definition WorkDispatcher.h:210
std::function< I(Ip &&, Device)> _preprocess
Function to preprocess the work.
Definition WorkDispatcher.h:225
~WorkDispatcher() override
Definition WorkDispatcher.h:182
std::unordered_map< Device, std::queue< std::function< void()> > > _tasks
Task queue for the thread pool.
Definition WorkDispatcher.h:250
std::mutex _qmutex
Definition WorkDispatcher.h:238
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess)
Definition WorkDispatcher.h:135
void stop_thread_pool()
Stop the thread pool.
Definition WorkDispatcher.h:348
void thread_pool_main(const Device &)
Thread pool main function.
Definition WorkDispatcher.h:318
WorkDispatcher & operator=(WorkDispatcher &&)=delete
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< Of(std::vector< Op > &&)> &&reduce, std::function< I(Ip &&, Device)> &&preprocess, std::function< Op(O &&)> &&postprocess, std::function< void(Device)> &&thread_init)
Definition WorkDispatcher.h:153
std::function< Op(O &&)> _postprocess
Function to postprocess the result.
Definition WorkDispatcher.h:228
void validate() const
Helper function to validate that the dispatcher is properly configured.
Definition WorkDispatcher.h:399
WorkDispatcher(WorkDispatcher &&)=delete
Of run_sync(WorkGenerator< Ip > &)
Run the dispatching loop synchronously.
Definition WorkDispatcher.h:419
std::function< O(I &&, Device)> _do_work
Function to perform the work and return the result.
Definition WorkDispatcher.h:219
Of run(WorkGenerator< Ip > &)
Run the dispatching loop (calls run_sync or run_async based on the async flag)
Definition WorkDispatcher.h:368
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work)
Definition WorkDispatcher.h:111
std::function< Of(std::vector< Op > &&)> _reduce
Function to reduce the results.
Definition WorkDispatcher.h:222
WorkDispatcher(WorkScheduler &scheduler, bool async, std::function< O(I &&, Device)> &&do_work, std::function< O(std::vector< O > &&)> &&reduce)
Definition WorkDispatcher.h:121
bool should_unlock_thread()
Should unlock thread.
std::condition_variable _thread_condition
Condition variable for the tasks queue.
Definition WorkDispatcher.h:240
Definition WorkGenerator.h:34
std::pair< std::size_t, T > next(std::size_t n)
Generate the next n batches of work.
Definition WorkGenerator.h:57
virtual bool has_more() const =0
Whether the generator has more work to generate.
Scheduler for work dispatching.
Definition WorkScheduler.h:48
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:66
void neml_assert_dbg(bool assertion, Args &&... args)
Definition assertions.h:60
constexpr auto kCPU
Definition types.h:56
Definition WorkDispatcher.h:48
T type
Definition WorkDispatcher.h:49
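
The listing above is reproduced verbatim from the header. For orientation, the following is a minimal usage sketch, not part of NEML2 itself: it assumes that concrete WorkScheduler and WorkGenerator<at::Tensor> implementations are available as scheduler and generator (configured elsewhere), and the function name dispatch_example, the lambdas, and the doubling operation are illustrative placeholders rather than NEML2 API.

#include <ATen/ATen.h>

#include "neml2/dispatchers/WorkDispatcher.h"
#include "neml2/dispatchers/WorkGenerator.h"
#include "neml2/dispatchers/WorkScheduler.h"

using namespace neml2;

// Hypothetical driver: `scheduler` and `generator` are assumed to be concrete
// implementations provided elsewhere.
at::Tensor
dispatch_example(WorkScheduler & scheduler, WorkGenerator<at::Tensor> & generator)
{
  // do_work: move a chunk of work to the assigned device, operate on it, and bring it back
  std::function<at::Tensor(at::Tensor &&, Device)> do_work =
      [](at::Tensor && x, Device device) { return (x.to(device) * 2.0).cpu(); };

  // reduce: concatenate the per-chunk results along the batch dimension
  std::function<at::Tensor(std::vector<at::Tensor> &&)> reduce =
      [](std::vector<at::Tensor> && results) { return at::cat(results, 0); };

  // I = O = Of = at::Tensor; Ip and Op default to I and O through type_identity
  WorkDispatcher<at::Tensor, at::Tensor, at::Tensor> dispatcher(
      scheduler, /*async=*/true, std::move(do_work), std::move(reduce));
  return dispatcher.run(generator);
}

If Of is left at its default of std::vector<Op>, the reduce functor may be omitted (using the first constructor) and run returns the per-chunk results as-is; preprocess and postprocess are only required when Ip differs from I or Op differs from O, which is exactly what validate() enforces.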