NEML2 2.0.0
WorkDispatcher.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include <ATen/ops/linalg_inv.h>
#include <ATen/ops/ones.h>
#include <functional>
#include <thread>
#include <queue>
#include <mutex>
#include <condition_variable>
#include <unordered_map>

#include <ATen/Parallel.h>

#include "neml2/dispatchers/WorkGenerator.h"
#include "neml2/dispatchers/WorkScheduler.h"
#include "neml2/misc/assertions.h"
#include "neml2/misc/types.h"
#include "neml2/base/TracingInterface.h"

// Pre-C++20 workaround for std::type_identity
// https://en.cppreference.com/w/cpp/types/type_identity
template <class T>
struct type_identity
{
  using type = T;
};

namespace neml2
{
/// The work dispatcher dispatches work to workers and reduces the results.
///
/// Template parameters:
/// - I:  work (input) type consumed by the do-work functor
/// - O:  result (output) type produced by the do-work functor
/// - Of: final output type returned by run() after reduction, defaults to std::vector<O>
/// - Ip: type of the generated work before preprocessing, defaults to I
/// - Op: type of the result after postprocessing, defaults to O
template <typename I,
          typename O,
          typename Of = typename std::vector<O>,
          typename Ip = typename type_identity<I>::type,
          typename Op = typename type_identity<O>::type>
class WorkDispatcher : public TracingInterface
{
public:
  WorkDispatcher(WorkScheduler & scheduler, bool async, std::function<O(I &&, Device)> && do_work)
    : TracingInterface(scheduler),
      _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<O(std::vector<O> &&)> && reduce)
    : TracingInterface(scheduler),
      _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<Of(std::vector<Op> &&)> && reduce,
                 std::function<I(Ip &&, Device)> && preprocess,
                 std::function<Op(O &&)> && postprocess)
    : TracingInterface(scheduler),
      _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce)),
      _preprocess(std::move(preprocess)),
      _postprocess(std::move(postprocess))
  {
    init_thread_pool();
  }

  WorkDispatcher(WorkScheduler & scheduler,
                 bool async,
                 std::function<O(I &&, Device)> && do_work,
                 std::function<Of(std::vector<Op> &&)> && reduce,
                 std::function<I(Ip &&, Device)> && preprocess,
                 std::function<Op(O &&)> && postprocess,
                 std::function<void(Device)> && thread_init)
    : TracingInterface(scheduler),
      _scheduler(scheduler),
      _devices(scheduler.devices()),
      _async(async),
      _do_work(std::move(do_work)),
      _reduce(std::move(reduce)),
      _preprocess(std::move(preprocess)),
      _postprocess(std::move(postprocess)),
      _thread_init(std::move(thread_init))
  {
    if (_thread_init)
      neml_assert_dbg(
          _async, "Custom thread initialization functor is only supported in asynchronous mode");

    init_thread_pool();
  }

  WorkDispatcher() = delete;
  WorkDispatcher(const WorkDispatcher &) = delete;
  WorkDispatcher(WorkDispatcher &&) = delete;
  WorkDispatcher & operator=(const WorkDispatcher &) = delete;
  WorkDispatcher & operator=(WorkDispatcher &&) = delete;
  ~WorkDispatcher() override { stop_thread_pool(); }

  /// Run the dispatching loop (calls run_sync or run_async based on the async flag)
  Of run(WorkGenerator<Ip> &);

protected:
  /// Initialize the thread pool
  void init_thread_pool();

  /// Thread pool main function
  void thread_pool_main(const Device &);

  /// Stop the thread pool
  void stop_thread_pool();

  /// Helper function to validate that the dispatcher is properly configured
  void validate() const;

  /// Run the dispatching loop synchronously
  Of run_sync(WorkGenerator<Ip> &);

  /// Run the dispatching loop asynchronously
  Of run_async(WorkGenerator<Ip> &);

  /// Reference to the work scheduler
  WorkScheduler & _scheduler;

  /// Device pool requested by the scheduler
  const std::vector<Device> _devices;

  /// Flag to enable asynchronous execution
  const bool _async;

  /// Function to perform the work and return the result
  std::function<O(I &&, Device)> _do_work;

  /// Function to reduce the results
  std::function<Of(std::vector<Op> &&)> _reduce;

  /// Function to preprocess the work
  std::function<I(Ip &&, Device)> _preprocess;

  /// Function to postprocess the result
  std::function<Op(O &&)> _postprocess;

  /// Function to initialize the thread
  std::function<void(Device)> _thread_init;

  /// Results to be reduced
  std::vector<Op> _results;

  std::mutex _qmutex;
  /// Condition variable for the tasks queue
  std::condition_variable _thread_condition;
  /// Flag to stop the thread pool
  bool _stop = false;
  // As it turns out it's undefined behavior to initialize this before the mutex and condition
  // variable
  std::vector<std::thread> _thread_pool;
  /// Task queue for the thread pool
  std::unordered_map<Device, std::queue<std::function<void()>>> _tasks;
};

// Implementation

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::init_thread_pool()
{
  if (!_async)
    return;

#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_begin("thread pool", "WorkDispatcher");
#endif

  // Set up the per-device task queues
  for (const auto & device : _devices)
    _tasks[device] = std::queue<std::function<void()>>();

  auto nthread = _devices.size();
  _thread_pool.reserve(nthread);
  for (std::size_t i = 0; i < nthread; ++i)
  {
    // This is necessary to initialize the torch linear algebra library prior to threaded calls
    // See: https://github.com/pytorch/pytorch/issues/90613
    auto res = at::linalg_inv(at::ones({1, 1}));
    _thread_pool.emplace_back([this, i] { thread_pool_main(_devices[i]); });
  }

  // Run the custom thread initialization task on each worker thread
  if (_thread_init)
  {
    for (std::size_t i = 0; i < nthread; ++i)
    {
      auto device = _devices[i];
      auto task = [this, device = device]() mutable
      {
#ifdef NEML2_HAS_JSON
        if (event_tracing_enabled())
          event_trace_writer().trace_duration_begin(
              "thread init", "WorkDispatcher", {{"device", utils::stringify(device)}});
#endif

        _thread_init(device);
        _scheduler.completed_work(device, 1);

#ifdef NEML2_HAS_JSON
        if (event_tracing_enabled())
          event_trace_writer().trace_duration_end("thread init", "WorkDispatcher");
#endif
      };
      _scheduler.dispatched_work(device, 1);
      {
        std::lock_guard<std::mutex> lock(_qmutex);
        _tasks.at(device).push(task);
      }
      _thread_condition.notify_all();
    }
    _scheduler.wait_for_completion();
  }
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::thread_pool_main(const Device & device)
{
#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_begin(
        "thread main", "WorkDispatcher", {{"device", utils::stringify(device)}});
#endif

  while (true)
  {
    std::function<void()> task;
    {
      std::unique_lock<std::mutex> lock(_qmutex);
      _thread_condition.wait(lock, [this, &device] { return _stop || !_tasks.at(device).empty(); });
      if (_stop && _tasks.at(device).empty())
        break;
      task = std::move(_tasks.at(device).front());
      _tasks.at(device).pop();
    }
    task();
  }

#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_end("thread main", "WorkDispatcher");
#endif
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::stop_thread_pool()
{
  if (!_async)
    return;
  {
    std::unique_lock<std::mutex> lock(_qmutex);
    _stop = true;
  }
  _thread_condition.notify_all();
  for (auto & thread : _thread_pool)
    thread.join();

#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_end("thread pool", "WorkDispatcher");
#endif
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run(WorkGenerator<Ip> & generator)
{
#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_begin("run", "WorkDispatcher");
#endif

  if (_async)
  {
    auto result = run_async(generator);

#ifdef NEML2_HAS_JSON
    if (event_tracing_enabled())
      event_trace_writer().trace_duration_end("run", "WorkDispatcher");
#endif

    return result;
  }

  auto result = run_sync(generator);

#ifdef NEML2_HAS_JSON
  if (event_tracing_enabled())
    event_trace_writer().trace_duration_end("run", "WorkDispatcher");
#endif

  return result;
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
void
WorkDispatcher<I, O, Of, Ip, Op>::validate() const
{
  if (!_do_work)
    throw NEMLException("Do-work function is not set");

  if constexpr (!std::is_same_v<I, Ip>)
    if (!_preprocess)
      throw NEMLException("Preprocess function is not set");

  if constexpr (!std::is_same_v<O, Op>)
    if (!_postprocess)
      throw NEMLException("Postprocess function is not set");

  if constexpr (!std::is_same_v<Of, std::vector<Op>>)
    if (!_reduce)
      throw NEMLException("Reduce function is not set");
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run_sync(WorkGenerator<Ip> & generator)
{
  validate();

  Device device = kCPU;
  std::size_t n = 0;
  _results.clear();
  while (generator.has_more())
  {
    _scheduler.schedule_work(device, n);
    if (n <= 0)
      throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
    // Generate work
    auto && [m, work] = generator.next(n);
    // Preprocess
    if (_preprocess)
      work = _preprocess(std::move(work), device);
    // Do work. Since there is no asynchronous execution, we do not notify the scheduler (this also
    // avoids potential parallel communication incurred by the scheduler)
    auto result = _do_work(std::move(work), device);
    // Postprocess
    if (_postprocess)
      result = _postprocess(std::move(result));
    _results.push_back(result);
  }

  if (_reduce)
    return _reduce(std::move(_results));

  if constexpr (std::is_same<Of, std::vector<Op>>::value)
    return _results;

  throw NEMLException("Internal error: unreachable code");
}

template <typename I, typename O, typename Of, typename Ip, typename Op>
Of
WorkDispatcher<I, O, Of, Ip, Op>::run_async(WorkGenerator<Ip> & generator)
{
  validate();

  Device device = kCPU;
  std::size_t n = 0;
  _results.clear();

  // Keep asking the scheduler for an available device
  // - If the generator has no more work, we break out of the loop
  // - If the scheduler schedules work, we dispatch the work and continue with the dispatching loop
  while (generator.has_more())
  {
    _scheduler.schedule_work(device, n);
    if (n <= 0)
      throw NEMLException("Scheduler returned a batch size of " + std::to_string(n));
    // Generate work
    auto && [m, work] = generator.next(n);
    // Reserve space for the result
    _results.resize(_results.size() + 1);
    auto i = _results.size() - 1;
    // Create the task
    auto task = [this, work = std::move(work), device = device, m = m, i = i]() mutable
    {
#ifdef NEML2_HAS_JSON
      if (event_tracing_enabled())
        event_trace_writer().trace_duration_begin(
            "task", "WorkDispatcher", {{"device", utils::stringify(device)}, {"batch size", m}});
#endif
      // Preprocess
      if (_preprocess)
        work = _preprocess(std::move(work), device);
      // Do work
      auto result = _do_work(std::move(work), device);
      // Postprocess
      if (_postprocess)
        result = _postprocess(std::move(result));
      // Collect result
      _results[i] = std::move(result);
      // Tell the scheduler that we have completed m batches
      _scheduler.completed_work(device, m);

#ifdef NEML2_HAS_JSON
      if (event_tracing_enabled())
        event_trace_writer().trace_duration_end("task", "WorkDispatcher");
#endif
    };
    // Tell the scheduler that we have dispatched m batches
    _scheduler.dispatched_work(device, m);
    // Enqueue the task
    {
      std::lock_guard<std::mutex> lock(_qmutex);
      _tasks.at(device).push(task);
    }
    // Notify the thread pool
    // Note: We notify_all instead of notify_one because we want the thread that is bound to the
    // target device to pick up the task
    _thread_condition.notify_all();
  }

  // Wait for all tasks to complete
  _scheduler.wait_for_completion();

  if (_reduce)
    return _reduce(std::move(_results));

  if constexpr (std::is_same<Of, std::vector<Op>>::value)
    return _results;

  throw NEMLException("Internal error: unreachable code");
}
} // namespace neml2
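Usage note (not part of the header above): the sketch below shows one way to drive a WorkDispatcher, assuming the caller already owns a concrete WorkScheduler and a WorkGenerator<at::Tensor>. The tensor work/result types, the body of the do_work lambda, and the helper name dispatch_all are illustrative placeholders rather than NEML2-prescribed API.

#include <vector>

#include <ATen/ATen.h>

#include "neml2/dispatchers/WorkDispatcher.h"
#include "neml2/dispatchers/WorkGenerator.h"
#include "neml2/dispatchers/WorkScheduler.h"

// Dispatch every chunk produced by `generator` across the scheduler's devices and
// return the per-chunk results. With I = O = at::Tensor and the defaults
// Of = std::vector<at::Tensor>, Ip = I, Op = O, only the do-work functor is required,
// so validate() passes without reduce/preprocess/postprocess functors.
std::vector<at::Tensor>
dispatch_all(neml2::WorkScheduler & scheduler, neml2::WorkGenerator<at::Tensor> & generator)
{
  neml2::WorkDispatcher<at::Tensor, at::Tensor> dispatcher(
      scheduler,
      /*async=*/true, // spawn one worker thread per device in the scheduler's device pool
      /*do_work=*/[](at::Tensor && x, neml2::Device device)
      { return (x.to(device) * 2.0).cpu(); }); // placeholder work: move to device, compute, bring back

  // run() validates the configuration, then keeps asking the scheduler for a device and
  // batch size, enqueues one task per generated chunk, and waits for completion.
  return dispatcher.run(generator);
}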