Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/dispatchers/WorkScheduler.h"
26 : #include "neml2/base/TracingInterface.h"
27 :
28 : namespace neml2
29 : {
30 :
31 : OptionSet
32 17 : WorkScheduler::expected_options()
33 : {
34 17 : OptionSet options = NEML2Object::expected_options();
35 17 : options += TracingInterface::expected_options();
36 17 : options.section() = "Schedulers";
37 :
38 17 : return options;
39 0 : }
40 :
41 11 : WorkScheduler::WorkScheduler(const OptionSet & options)
42 : : NEML2Object(options),
43 11 : TracingInterface(options)
44 : {
45 11 : }
46 :
#ifdef NEML2_HAS_JSON
// Build the argument payload attached to scheduler trace events.
//
// Parameters:
//   device     - device the work is associated with; recorded via utils::stringify
//   batch_size - number of items in the batch
// Returns a JSON object with the keys "device" and "batch size".
static json
to_json(const Device & device, const std::size_t & batch_size)
{
  json payload;
  payload["device"] = utils::stringify(device);
  payload["batch size"] = batch_size;
  return payload;
}
#endif
57 :
58 : void
59 46 : WorkScheduler::schedule_work(Device & device, std::size_t & batch_size)
60 : {
61 46 : std::unique_lock<std::mutex> lock(_mutex);
62 :
63 : #ifdef NEML2_HAS_JSON
64 46 : if (event_tracing_enabled())
65 0 : event_trace_writer().trace_duration_begin("schedule work", "WorkScheduler");
66 : #endif
67 :
68 46 : if (schedule_work_impl(device, batch_size))
69 : {
70 : #ifdef NEML2_HAS_JSON
71 27 : if (event_tracing_enabled())
72 0 : event_trace_writer().trace_duration_end(
73 0 : "schedule work", "WorkScheduler", to_json(device, batch_size), 0);
74 : #endif
75 27 : return;
76 : }
77 19 : _condition.wait(lock,
78 38 : [this, &device, &batch_size] { return schedule_work_impl(device, batch_size); });
79 :
80 : #ifdef NEML2_HAS_JSON
81 19 : if (event_tracing_enabled())
82 0 : event_trace_writer().trace_duration_end(
83 0 : "schedule work", "WorkScheduler", to_json(device, batch_size), 0);
84 : #endif
85 46 : }
86 :
87 : void
88 28 : WorkScheduler::dispatched_work(Device device, std::size_t m)
89 : {
90 : #ifdef NEML2_HAS_JSON
91 28 : if (event_tracing_enabled())
92 0 : event_trace_writer().trace_duration_begin("dispatch work", "WorkScheduler");
93 : #endif
94 :
95 28 : std::lock_guard<std::mutex> lock(_mutex);
96 28 : dispatched_work_impl(device, m);
97 :
98 : #ifdef NEML2_HAS_JSON
99 28 : if (event_tracing_enabled())
100 0 : event_trace_writer().trace_duration_end("dispatch work", "WorkScheduler", to_json(device, m));
101 : #endif
102 28 : }
103 :
104 : void
105 27 : WorkScheduler::completed_work(Device device, std::size_t m)
106 : {
107 : #ifdef NEML2_HAS_JSON
108 27 : if (event_tracing_enabled())
109 0 : event_trace_writer().trace_instant("completed work", "WorkScheduler", to_json(device, m));
110 : #endif
111 27 : std::lock_guard<std::mutex> lock(_mutex);
112 27 : completed_work_impl(device, m);
113 27 : _condition.notify_all();
114 27 : }
115 :
116 : void
117 4 : WorkScheduler::wait_for_completion()
118 : {
119 4 : std::unique_lock<std::mutex> lock(_mutex);
120 4 : if (all_work_completed())
121 : {
122 : #ifdef NEML2_HAS_JSON
123 0 : if (event_tracing_enabled())
124 0 : event_trace_writer().trace_instant("all work completed", "WorkScheduler");
125 : #endif
126 0 : return;
127 : }
128 12 : _condition.wait(lock, [this] { return all_work_completed(); });
129 :
130 : #ifdef NEML2_HAS_JSON
131 4 : if (event_tracing_enabled())
132 0 : event_trace_writer().trace_instant("all work completed", "WorkScheduler");
133 : #endif
134 4 : }
135 : }
|