LCOV - code coverage report
Current view: top level - dispatchers - StaticHybridScheduler.cxx (source / functions)
Test: coverage.info
Test Date: 2025-10-02 16:03:03
                 Coverage   Total   Hit
Lines:           72.5 %        91    66
Functions:       62.5 %         8     5

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <set>
      26              : 
      27              : #include "neml2/dispatchers/StaticHybridScheduler.h"
      28              : #include "neml2/misc/assertions.h"
      29              : #include "neml2/base/Registry.h"
      30              : 
      31              : namespace neml2
      32              : {
      33              : register_NEML2_object(StaticHybridScheduler);
      34              : 
      35              : OptionSet
      36            4 : StaticHybridScheduler::expected_options()
      37              : {
      38            4 :   OptionSet options = WorkScheduler::expected_options();
      39            4 :   options.doc() = "Dispatch work to multiple devices based on provided batch sizes and priorities.";
      40              : 
      41            8 :   options.set<std::vector<Device>>("devices");
      42            8 :   options.set("devices").doc() = "List of devices to dispatch work to";
      43              : 
      44            8 :   options.set<std::vector<std::size_t>>("batch_sizes");
      45            4 :   options.set("batch_sizes").doc() = "List of batch sizes for each device";
      46              : 
      47           12 :   options.set<std::vector<std::size_t>>("capacities") = {};
      48            4 :   options.set("capacities").doc() = "List of capacities for each device, default to batch_sizes";
      49              : 
      50           12 :   options.set<std::vector<double>>("priorities") = {};
      51            4 :   options.set("priorities").doc() = "List of priorities for each device";
      52              : 
      53            4 :   return options;
      54            0 : }
      55              : 
      56            2 : StaticHybridScheduler::StaticHybridScheduler(const OptionSet & options)
      57              :   : WorkScheduler(options),
      58            6 :     _available_devices(options.get<std::vector<Device>>("devices"))
      59              : {
      60            2 : }
      61              : 
      62              : void
      63            2 : StaticHybridScheduler::setup()
      64              : {
      65            2 :   WorkScheduler::setup();
      66              : 
      67            4 :   const auto & device_list = input_options().get<std::vector<Device>>("devices");
      68            4 :   const auto & batch_sizes = input_options().get<std::vector<std::size_t>>("batch_sizes");
      69            4 :   const auto & capacities = input_options().get<std::vector<std::size_t>>("capacities");
      70            4 :   const auto & priorities = input_options().get<std::vector<double>>("priorities");
      71              : 
      72              :   // First pass:
      73              :   // - Check if any CPU device is present
      74              :   // - Check if any CUDA device is present
      75              :   // - Make sure no more than one CPU device is present
      76            6 :   for (const auto & device : device_list)
      77            4 :     if (device.is_cpu())
      78              :     {
      79            2 :       neml_assert(!_cpu, "Multiple CPU devices are not allowed");
      80            2 :       _cpu = true;
      81              :     }
      82            2 :     else if (device.is_cuda())
      83            2 :       _cuda = true;
      84              :     else
      85            0 :       neml_assert(false, "Unsupported device type: ", device);
      86              : 
      87              :   // Second pass:
      88              :   // - If multiple CUDA devices are present, make sure each CUDA device has a concrete
      89              :   //   (nonnegative), unique device ID
      90            2 :   bool has_multiple_cuda_devices = _cpu ? device_list.size() > 2 : device_list.size() > 1;
      91            2 :   if (has_multiple_cuda_devices)
      92              :   {
      93            1 :     std::set<DeviceIndex> cuda_device_ids;
      94            4 :     for (const auto & device : device_list)
      95              :     {
      96            3 :       if (device.is_cpu())
      97            1 :         continue;
      98            2 :       auto device_id = device.index();
      99            2 :       neml_assert(device_id >= 0, "Device ID must be nonnegative");
     100            2 :       neml_assert(cuda_device_ids.find(device_id) == cuda_device_ids.end(),
     101              :                   "Device ID must be unique. Found duplicate: ",
     102              :                   device_id);
     103            2 :       cuda_device_ids.insert(device_id);
     104              :     }
     105            1 :   }
     106              : 
     107              :   // Expand batch size if necessary
     108            2 :   auto batch_sizes_expand = batch_sizes;
     109            2 :   if (batch_sizes.size() == 1)
     110            1 :     batch_sizes_expand.resize(device_list.size(), batch_sizes[0]);
     111              :   else
     112            1 :     neml_assert(batch_sizes.size() == device_list.size(),
     113              :                 "Number of batch sizes must either be one or match the number of devices.");
     114              : 
     115              :   // Expand capacity if necessary
     116            2 :   auto capacities_expand = capacities;
     117            2 :   if (capacities.empty())
     118            0 :     capacities_expand = batch_sizes_expand;
     119            2 :   else if (capacities.size() == 1)
     120            1 :     capacities_expand.resize(device_list.size(), capacities[0]);
     121              :   else
     122            1 :     neml_assert(capacities.size() == device_list.size(),
     123              :                 "Number of capacities must either be zero, one, or match the number of devices.");
     124              : 
     125              :   // Expand priorities if necessary
     126            2 :   auto priorities_expand = priorities;
     127            2 :   if (priorities.empty())
     128            0 :     priorities_expand.resize(device_list.size(), 1.0);
     129              :   else
     130            2 :     neml_assert(priorities.size() == device_list.size(),
     131              :                 "Number of priorities must match the number of devices.");
     132              : 
     133              :   // Construct the device status list
     134            6 :   for (std::size_t i = 0; i < device_list.size(); ++i)
     135            4 :     _devices.emplace_back(
     136            4 :         device_list[i], batch_sizes_expand[i], capacities_expand[i], priorities_expand[i]);
     137            2 : }
     138              : 
     139              : void
     140            0 : StaticHybridScheduler::set_availability_calculator(std::function<double(const DeviceStatus &)> f)
     141              : {
     142            0 :   _custom_availability_calculator = std::move(f);
     143            0 : }
     144              : 
     145              : bool
     146            0 : StaticHybridScheduler::schedule_work_impl(Device & device, std::size_t & n) const
     147              : {
     148            0 :   bool available = false;
     149            0 :   double max_availability = std::numeric_limits<double>::lowest();
     150              : 
     151            0 :   for (const auto & i : _devices)
     152            0 :     if ((i.load + i.batch_size) <= i.capacity)
     153              :     {
     154              :       auto availability =
     155            0 :           _custom_availability_calculator ? _custom_availability_calculator(i) : i.priority;
     156            0 :       if (!available || availability > max_availability)
     157              :       {
     158            0 :         available = true;
     159            0 :         device = i.device;
     160            0 :         n = i.batch_size;
     161              :       }
     162              :     }
     163              : 
     164            0 :   return available;
     165              : }
     166              : 
     167              : void
     168            5 : StaticHybridScheduler::dispatched_work_impl(Device device, std::size_t n)
     169              : {
     170            8 :   for (auto & i : _devices)
     171            8 :     if (i.device == device)
     172              :     {
     173            5 :       i.load += n;
     174              :       // TODO: Add an option to allow for oversubscription, maybe?
     175            5 :       neml_assert(i.load <= i.capacity, "Device oversubscribed");
     176            5 :       return;
     177              :     }
     178              : 
     179            0 :   neml_assert(false, "Device not found: ", device);
     180              : }
     181              : 
     182              : void
     183            4 : StaticHybridScheduler::completed_work_impl(Device device, std::size_t n)
     184              : {
     185            7 :   for (auto & i : _devices)
     186            7 :     if (i.device == device)
     187              :     {
     188            4 :       neml_assert(i.load >= n, "Device load underflow");
     189            4 :       i.load -= n;
     190            4 :       return;
     191              :     }
     192              : 
     193            0 :   neml_assert(false, "Device not found: ", device);
     194              : }
     195              : 
     196              : bool
     197            0 : StaticHybridScheduler::all_work_completed() const
     198              : {
     199            0 :   for (const auto & i : _devices)
     200            0 :     if (i.load > 0)
     201            0 :       return false;
     202            0 :   return true;
     203              : }
     204              : } // namespace neml2
        

Generated by: LCOV version 2.0-1