Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
#include "neml2/drivers/ModelDriver.h"
#include "neml2/misc/assertions.h"
#include "neml2/models/Model.h"

#include <mutex>
#include <thread>

#ifdef NEML2_HAS_DISPATCHER
#include "neml2/dispatchers/valuemap_helpers.h"
#endif
32 :
33 : namespace neml2
34 : {
35 : void
36 0 : details_callback(const Model & model,
37 : const std::map<VariableName, std::unique_ptr<VariableBase>> & input,
38 : const std::map<VariableName, std::unique_ptr<VariableBase>> & output)
39 : {
40 0 : std::cout << model.name() << std::endl;
41 0 : std::cout << "\tInput" << std::endl;
42 0 : for (const auto & pair : input)
43 : {
44 0 : std::cout << "\t\t" << pair.first.str() << " (" << pair.second->sizes() << ") -> "
45 0 : << at::norm(pair.second->tensor()).cpu().item<double>() << std::endl;
46 : }
47 0 : std::cout << "\tOutput" << std::endl;
48 0 : for (const auto & pair : output)
49 : {
50 0 : std::cout << "\t\t" << pair.first.str() << " (" << pair.second->sizes() << ") -> "
51 0 : << at::norm(pair.second->tensor()).cpu().item<double>() << std::endl;
52 : }
53 0 : }
54 :
55 : OptionSet
56 8 : ModelDriver::expected_options()
57 : {
58 8 : OptionSet options = Driver::expected_options();
59 :
60 16 : options.set<std::string>("model");
61 16 : options.set("model").doc() = "The material model to be updated by the driver";
62 :
63 16 : options.set<std::string>("device") = "cpu";
64 24 : options.set("device").doc() =
65 : "Device on which to evaluate the material model. The string supplied must follow the "
66 : "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, "
67 : "and :<device-index> optionally specifies a device index. For example, device='cpu' sets the "
68 : "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
69 8 : "CUDA with device ID 1.";
70 :
71 16 : options.set<bool>("show_parameters") = false;
72 16 : options.set("show_parameters").doc() = "Whether to show model parameters at the beginning";
73 16 : options.set<bool>("show_input_axis") = false;
74 16 : options.set("show_input_axis").doc() = "Whether to show model input axis at the beginning";
75 16 : options.set<bool>("show_output_axis") = false;
76 16 : options.set("show_output_axis").doc() = "Whether to show model output axis at the beginning";
77 :
78 16 : options.set<bool>("log_details") = false;
79 24 : options.set("log_details").doc() =
80 8 : "If true attach a callback which outputs lots of information on the model execution";
81 :
82 : #ifdef NEML2_HAS_DISPATCHER
83 16 : options.set<std::string>("scheduler");
84 16 : options.set("scheduler").doc() = "The work scheduler to use";
85 16 : options.set<bool>("async_dispatch") = true;
86 8 : options.set("async_dispatch").doc() = "Whether to dispatch work asynchronously";
87 : #endif
88 :
89 8 : return options;
90 0 : }
91 :
92 7 : ModelDriver::ModelDriver(const OptionSet & options)
93 : : Driver(options),
94 7 : _model(get_model("model")),
95 21 : _device(options.get<std::string>("device")),
96 14 : _show_params(options.get<bool>("show_parameters")),
97 14 : _show_input(options.get<bool>("show_input_axis")),
98 14 : _show_output(options.get<bool>("show_output_axis")),
99 14 : _log_details(options.get<bool>("log_details"))
100 : #ifdef NEML2_HAS_DISPATCHER
101 : ,
102 7 : _scheduler(options.get("scheduler").user_specified() ? get_scheduler("scheduler") : nullptr),
103 28 : _async_dispatch(options.get<bool>("async_dispatch"))
104 : #endif
105 : {
106 7 : }
107 :
108 : void
109 7 : ModelDriver::setup()
110 : {
111 7 : Driver::setup();
112 :
113 : // Send model parameters and buffers to device
114 7 : _model->to(_device);
115 :
116 7 : if (_log_details)
117 0 : _model->register_callback_recursive(details_callback);
118 :
119 : #ifdef NEML2_HAS_DISPATCHER
120 7 : if (_scheduler)
121 : {
122 0 : auto red = [](std::vector<ValueMap> && results) -> ValueMap
123 0 : { return valuemap_cat_reduce(std::move(results), 0); };
124 :
125 0 : auto post = [this](ValueMap && x) -> ValueMap
126 0 : { return valuemap_move_device(std::move(x), _device); };
127 :
128 0 : auto thread_init = [this](Device device) -> void
129 : {
130 0 : auto new_factory = Factory(_model->factory()->input_file());
131 0 : auto new_model = new_factory.get_model(_model->name());
132 0 : new_model->to(device);
133 0 : _models[std::this_thread::get_id()] = std::move(new_model);
134 0 : };
135 :
136 0 : _dispatcher = std::make_unique<DispatcherType>(
137 0 : *_scheduler,
138 0 : _async_dispatch,
139 0 : [&](ValueMap && x, Device device) -> ValueMap
140 : {
141 0 : auto & model = _async_dispatch ? _models[std::this_thread::get_id()] : _model;
142 :
143 : // If this is not an async dispatch, we need to move the model to the target device
144 : // _every_ time before evaluation
145 0 : if (!_async_dispatch)
146 0 : model->to(device);
147 :
148 0 : neml_assert_dbg(model->variable_options().device() == device);
149 0 : return model->value(std::move(x));
150 : },
151 : red,
152 0 : &valuemap_move_device,
153 : post,
154 0 : _async_dispatch ? thread_init : std::function<void(Device)>());
155 : }
156 : #endif
157 :
158 : // LCOV_EXCL_START
159 : if (_show_input)
160 : std::cout << _model->name() << "'s input axis:\n" << _model->input_axis() << std::endl;
161 :
162 : if (_show_output)
163 : std::cout << _model->name() << "'s output axis:\n" << _model->output_axis() << std::endl;
164 :
165 : if (_show_params)
166 : {
167 : std::cout << _model->name() << "'s parameters:" << std::endl;
168 : for (auto && [pname, pval] : _model->named_parameters())
169 : std::cout << " " << pname << std::endl;
170 : }
171 : // LCOV_EXCL_STOP
172 7 : }
173 :
174 : void
175 7 : ModelDriver::diagnose() const
176 : {
177 7 : Driver::diagnose();
178 7 : neml2::diagnose(*_model);
179 7 : }
180 :
181 : void
182 0 : ModelDriver::to(Device dev)
183 : {
184 0 : _device = dev;
185 0 : setup();
186 0 : }
187 :
188 : } // namespace neml2
|