29#include "neml2/dispatchers/WorkScheduler.h"
30#include "neml2/base/Registry.h"
31#include "neml2/base/NEML2Object.h"
32#include "neml2/base/Factory.h"
64 std::vector<Device>
devices()
const override {
return {_available_devices[_device_index]}; }
66 virtual MPI_Comm &
comm() {
return _comm; }
78 void determine_my_device();
81 std::vector<Device> _available_devices;
84 std::vector<std::size_t> _batch_sizes;
87 std::vector<std::size_t> _capacities;
93 std::size_t _device_index = 0;
96 std::size_t _load = 0;
99#define neml2_call_mpi(call) \
103 if (err != MPI_SUCCESS) \
105 char err_string[MPI_MAX_ERROR_STRING]; \
107 MPI_Error_string(err, err_string, &len); \
108 throw NEMLException(std::string("MPI error: ") + err_string); \
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:52
bool schedule_work_impl(Device &, std::size_t &) const override
Implementation of the work scheduling.
void completed_work_impl(Device, std::size_t) override
Update the scheduler with the completion of the last batch.
virtual void set_comm(MPI_Comm comm)
Definition SimpleMPIScheduler.h:68
void dispatched_work_impl(Device, std::size_t) override
Update the scheduler with the dispatch of the last batch.
std::vector< Device > devices() const override
Device options.
Definition SimpleMPIScheduler.h:64
virtual MPI_Comm & comm()
Definition SimpleMPIScheduler.h:66
void setup() override
Check the device list and coordinate this rank's device.
static OptionSet expected_options()
Options for the scheduler.
SimpleMPIScheduler(const OptionSet &options)
Construct a new SimpleMPIScheduler object.
bool all_work_completed() const override
Check if all work has been completed.
WorkScheduler(const OptionSet &options)
Construct a new WorkScheduler object.
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69