Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include <set>
26 :
27 : #include "neml2/dispatchers/StaticHybridScheduler.h"
28 : #include "neml2/misc/assertions.h"
29 : #include "neml2/base/Registry.h"
30 :
31 : namespace neml2
32 : {
33 : register_NEML2_object(StaticHybridScheduler);
34 :
35 : OptionSet
36 4 : StaticHybridScheduler::expected_options()
37 : {
38 4 : OptionSet options = WorkScheduler::expected_options();
39 4 : options.doc() = "Dispatch work to multiple devices based on provided batch sizes and priorities.";
40 :
41 8 : options.set<std::vector<Device>>("devices");
42 8 : options.set("devices").doc() = "List of devices to dispatch work to";
43 :
44 8 : options.set<std::vector<std::size_t>>("batch_sizes");
45 4 : options.set("batch_sizes").doc() = "List of batch sizes for each device";
46 :
47 12 : options.set<std::vector<std::size_t>>("capacities") = {};
48 4 : options.set("capacities").doc() = "List of capacities for each device, default to batch_sizes";
49 :
50 12 : options.set<std::vector<double>>("priorities") = {};
51 4 : options.set("priorities").doc() = "List of priorities for each device";
52 :
53 4 : return options;
54 0 : }
55 :
56 2 : StaticHybridScheduler::StaticHybridScheduler(const OptionSet & options)
57 : : WorkScheduler(options),
58 6 : _available_devices(options.get<std::vector<Device>>("devices"))
59 : {
60 2 : }
61 :
62 : void
63 2 : StaticHybridScheduler::setup()
64 : {
65 2 : WorkScheduler::setup();
66 :
67 4 : const auto & device_list = input_options().get<std::vector<Device>>("devices");
68 4 : const auto & batch_sizes = input_options().get<std::vector<std::size_t>>("batch_sizes");
69 4 : const auto & capacities = input_options().get<std::vector<std::size_t>>("capacities");
70 4 : const auto & priorities = input_options().get<std::vector<double>>("priorities");
71 :
72 : // First pass:
73 : // - Check if any CPU device is present
74 : // - Check if any CUDA device is present
75 : // - Make sure no more than one CPU device is present
76 6 : for (const auto & device : device_list)
77 4 : if (device.is_cpu())
78 : {
79 2 : neml_assert(!_cpu, "Multiple CPU devices are not allowed");
80 2 : _cpu = true;
81 : }
82 2 : else if (device.is_cuda())
83 2 : _cuda = true;
84 : else
85 0 : neml_assert(false, "Unsupported device type: ", device);
86 :
87 : // Second pass:
88 : // - If multiple CUDA devices are present, make sure each CUDA device has a concrete
89 : // (nonnegative), unique device ID
90 2 : bool has_multiple_cuda_devices = _cpu ? device_list.size() > 2 : device_list.size() > 1;
91 2 : if (has_multiple_cuda_devices)
92 : {
93 1 : std::set<DeviceIndex> cuda_device_ids;
94 4 : for (const auto & device : device_list)
95 : {
96 3 : if (device.is_cpu())
97 1 : continue;
98 2 : auto device_id = device.index();
99 2 : neml_assert(device_id >= 0, "Device ID must be nonnegative");
100 2 : neml_assert(cuda_device_ids.find(device_id) == cuda_device_ids.end(),
101 : "Device ID must be unique. Found duplicate: ",
102 : device_id);
103 2 : cuda_device_ids.insert(device_id);
104 : }
105 1 : }
106 :
107 : // Expand batch size if necessary
108 2 : auto batch_sizes_expand = batch_sizes;
109 2 : if (batch_sizes.size() == 1)
110 1 : batch_sizes_expand.resize(device_list.size(), batch_sizes[0]);
111 : else
112 1 : neml_assert(batch_sizes.size() == device_list.size(),
113 : "Number of batch sizes must either be one or match the number of devices.");
114 :
115 : // Expand capacity if necessary
116 2 : auto capacities_expand = capacities;
117 2 : if (capacities.empty())
118 0 : capacities_expand = batch_sizes_expand;
119 2 : else if (capacities.size() == 1)
120 1 : capacities_expand.resize(device_list.size(), capacities[0]);
121 : else
122 1 : neml_assert(capacities.size() == device_list.size(),
123 : "Number of capacities must either be zero, one, or match the number of devices.");
124 :
125 : // Expand priorities if necessary
126 2 : auto priorities_expand = priorities;
127 2 : if (priorities.empty())
128 0 : priorities_expand.resize(device_list.size(), 1.0);
129 : else
130 2 : neml_assert(priorities.size() == device_list.size(),
131 : "Number of priorities must match the number of devices.");
132 :
133 : // Construct the device status list
134 6 : for (std::size_t i = 0; i < device_list.size(); ++i)
135 4 : _devices.emplace_back(
136 4 : device_list[i], batch_sizes_expand[i], capacities_expand[i], priorities_expand[i]);
137 2 : }
138 :
139 : void
140 0 : StaticHybridScheduler::set_availability_calculator(std::function<double(const DeviceStatus &)> f)
141 : {
142 0 : _custom_availability_calculator = std::move(f);
143 0 : }
144 :
145 : bool
146 0 : StaticHybridScheduler::schedule_work_impl(Device & device, std::size_t & n) const
147 : {
148 0 : bool available = false;
149 0 : double max_availability = std::numeric_limits<double>::lowest();
150 :
151 0 : for (const auto & i : _devices)
152 0 : if ((i.load + i.batch_size) <= i.capacity)
153 : {
154 : auto availability =
155 0 : _custom_availability_calculator ? _custom_availability_calculator(i) : i.priority;
156 0 : if (!available || availability > max_availability)
157 : {
158 0 : available = true;
159 0 : device = i.device;
160 0 : n = i.batch_size;
161 : }
162 : }
163 :
164 0 : return available;
165 : }
166 :
167 : void
168 5 : StaticHybridScheduler::dispatched_work_impl(Device device, std::size_t n)
169 : {
170 8 : for (auto & i : _devices)
171 8 : if (i.device == device)
172 : {
173 5 : i.load += n;
174 : // TODO: Add an option to allow for oversubscription, maybe?
175 5 : neml_assert(i.load <= i.capacity, "Device oversubscribed");
176 5 : return;
177 : }
178 :
179 0 : neml_assert(false, "Device not found: ", device);
180 : }
181 :
182 : void
183 4 : StaticHybridScheduler::completed_work_impl(Device device, std::size_t n)
184 : {
185 7 : for (auto & i : _devices)
186 7 : if (i.device == device)
187 : {
188 4 : neml_assert(i.load >= n, "Device load underflow");
189 4 : i.load -= n;
190 4 : return;
191 : }
192 :
193 0 : neml_assert(false, "Device not found: ", device);
194 : }
195 :
196 : bool
197 0 : StaticHybridScheduler::all_work_completed() const
198 : {
199 0 : for (const auto & i : _devices)
200 0 : if (i.load > 0)
201 0 : return false;
202 0 : return true;
203 : }
204 : } // namespace neml2
|