Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include <c10/core/InferenceMode.h>
26 : #include <torch/csrc/jit/frontend/tracer.h>
27 :
28 : #include "neml2/misc/assertions.h"
29 : #include "neml2/base/guards.h"
30 : #include "neml2/base/Factory.h"
31 : #include "neml2/base/Settings.h"
32 : #include "neml2/jit/utils.h"
33 : #include "neml2/tensors/functions/jacrev.h"
34 : #include "neml2/tensors/tensors.h"
35 : #include "neml2/tensors/TensorValue.h"
36 : #include "neml2/models/Model.h"
37 : #include "neml2/models/Assembler.h"
38 : #include "neml2/models/map_types_fwd.h"
39 :
40 : namespace neml2
41 : {
42 : std::shared_ptr<Model>
43 15 : load_model(const std::filesystem::path & path, const std::string & mname)
44 : {
45 15 : auto factory = load_input(path);
46 30 : return factory->get_model(mname);
47 15 : }
48 :
49 : bool
50 0 : Model::TraceSchema::operator==(const TraceSchema & other) const
51 : {
52 0 : return batch_dims == other.batch_dims && dispatch_key == other.dispatch_key;
53 : }
54 :
55 : bool
56 13028 : Model::TraceSchema::operator<(const TraceSchema & other) const
57 : {
58 13028 : if (dispatch_key != other.dispatch_key)
59 0 : return dispatch_key < other.dispatch_key;
60 13028 : return batch_dims < other.batch_dims;
61 : }
62 :
63 : OptionSet
64 481 : Model::expected_options()
65 : {
66 481 : OptionSet options = Data::expected_options();
67 481 : options += NonlinearSystem::expected_options();
68 481 : NonlinearSystem::disable_automatic_scaling(options);
69 :
70 481 : options.section() = "Models";
71 :
72 : // Model defaults to defining value and dvalue, but not d2value
73 962 : options.set<bool>("define_values") = true;
74 962 : options.set<bool>("define_derivatives") = true;
75 962 : options.set<bool>("define_second_derivatives") = false;
76 962 : options.set("define_values").suppressed() = true;
77 962 : options.set("define_derivatives").suppressed() = true;
78 962 : options.set("define_second_derivatives").suppressed() = true;
79 :
80 : // Model defaults to _not_ being part of a nonlinear system
81 : // Model::get_model will set this to true if the model is expected to be part of a nonlinear
82 : // system, and additional diagnostics will be performed
83 962 : options.set<bool>("_nonlinear_system") = false;
84 962 : options.set("_nonlinear_system").suppressed() = true;
85 :
86 962 : options.set<bool>("jit") = true;
87 962 : options.set("jit").doc() = "Use JIT compilation for the forward operator";
88 :
89 962 : options.set<bool>("production") = false;
90 962 : options.set("production").doc() =
91 : "Production mode. This option is used to disable features like function graph tracking and "
92 : "tensor version tracking which are useful for training (i.e., calibrating model parameters) "
93 481 : "but are not necessary in production runs.";
94 :
95 481 : return options;
96 0 : }
97 :
98 481 : Model::Model(const OptionSet & options)
99 : : Data(options),
100 : ParameterStore(this),
101 : VariableStore(this),
102 : NonlinearSystem(options),
103 : DiagnosticsInterface(this),
104 962 : _defines_value(options.get<bool>("define_values")),
105 962 : _defines_dvalue(options.get<bool>("define_derivatives")),
106 962 : _defines_d2value(options.get<bool>("define_second_derivatives")),
107 481 : _nonlinear_system(options.get<bool>("_nonlinear_system")),
108 1431 : _jit(settings().disable_jit() ? false : options.get<bool>("jit")),
109 1924 : _production(options.get<bool>("production"))
110 : {
111 481 : }
112 :
113 : void
114 91 : Model::to(const TensorOptions & options)
115 : {
116 91 : send_buffers_to(options);
117 91 : send_parameters_to(options);
118 91 : send_variables_to(options);
119 :
120 174 : for (auto & submodel : registered_models())
121 83 : submodel->to(options);
122 :
123 92 : for (auto & [name, param] : named_nonlinear_parameters())
124 92 : param.provider->to(options);
125 91 : }
126 :
127 : void
128 481 : Model::setup()
129 : {
130 481 : setup_layout();
131 :
132 481 : if (host() == this)
133 : {
134 208 : link_output_variables();
135 208 : link_input_variables();
136 : }
137 :
138 481 : request_AD();
139 481 : }
140 :
141 : void
142 99 : Model::diagnose() const
143 : {
144 187 : for (auto & submodel : registered_models())
145 88 : neml2::diagnose(*submodel);
146 :
147 : // Make sure variables are defined on the reserved subaxes
148 401 : for (auto && [name, var] : input_variables())
149 302 : diagnostic_check_input_variable(*var);
150 219 : for (auto && [name, var] : output_variables())
151 120 : diagnostic_check_output_variable(*var);
152 :
153 99 : if (is_nonlinear_system())
154 5 : diagnose_nl_sys();
155 :
156 99 : if (settings().disable_jit())
157 9 : if (input_options().user_specified("jit"))
158 3 : diagnostic_assert(!input_options().get<bool>("jit"),
159 : "JIT compilation is disabled globally by Settings/disable_jit=true, and it "
160 : "is an error to explicitly set jit to true for any model.");
161 99 : }
162 :
163 : void
164 72 : Model::diagnose_nl_sys() const
165 : {
166 139 : for (auto & submodel : registered_models())
167 67 : submodel->diagnose_nl_sys();
168 :
169 : // Check if any input variable is solve-dependent
170 72 : bool input_solve_dep = false;
171 253 : for (auto && [name, var] : input_variables())
172 181 : if (var->is_solve_dependent())
173 109 : input_solve_dep = true;
174 :
175 : // If any input variable is solve-dependent, ALL output variables must be solve-dependent!
176 72 : if (input_solve_dep)
177 145 : for (auto && [name, var] : output_variables())
178 76 : diagnostic_assert(
179 76 : var->is_solve_dependent(),
180 : "This model is part of a nonlinear system. At least one of the input variables is "
181 : "solve-dependent, so all output variables MUST be solve-dependent, i.e., they must be "
182 : "on one of the following sub-axes: state, residual, parameters. However, got output "
183 : "variable ",
184 : name);
185 72 : }
186 :
187 : void
188 26 : Model::diagnostic_assert_state(const VariableBase & v) const
189 : {
190 26 : diagnostic_assert(v.is_state(), "Variable ", v.name(), " must be on the ", STATE, " sub-axis.");
191 26 : }
192 :
193 : void
194 0 : Model::diagnostic_assert_old_state(const VariableBase & v) const
195 : {
196 0 : diagnostic_assert(
197 0 : v.is_old_state(), "Variable ", v.name(), " must be on the ", OLD_STATE, " sub-axis.");
198 0 : }
199 :
200 : void
201 19 : Model::diagnostic_assert_force(const VariableBase & v) const
202 : {
203 19 : diagnostic_assert(v.is_force(), "Variable ", v.name(), " must be on the ", FORCES, " sub-axis.");
204 19 : }
205 :
206 : void
207 0 : Model::diagnostic_assert_old_force(const VariableBase & v) const
208 : {
209 0 : diagnostic_assert(
210 0 : v.is_old_force(), "Variable ", v.name(), " must be on the ", OLD_FORCES, " sub-axis.");
211 0 : }
212 :
213 : void
214 0 : Model::diagnostic_assert_residual(const VariableBase & v) const
215 : {
216 0 : diagnostic_assert(
217 0 : v.is_residual(), "Variable ", v.name(), " must be on the ", RESIDUAL, " sub-axis.");
218 0 : }
219 :
220 : void
221 302 : Model::diagnostic_check_input_variable(const VariableBase & v) const
222 : {
223 770 : diagnostic_assert(v.is_state() || v.is_old_state() || v.is_force() || v.is_old_force() ||
224 468 : v.is_residual() || v.is_parameter(),
225 : "Input variable ",
226 302 : v.name(),
227 : " must be on one of the following sub-axes: ",
228 : STATE,
229 : ", ",
230 : OLD_STATE,
231 : ", ",
232 : FORCES,
233 : ", ",
234 : OLD_FORCES,
235 : ", ",
236 : RESIDUAL,
237 : ", ",
238 : PARAMETERS,
239 : ".");
240 302 : }
241 :
242 : void
243 120 : Model::diagnostic_check_output_variable(const VariableBase & v) const
244 : {
245 120 : diagnostic_assert(v.is_state() || v.is_force() || v.is_residual() || v.is_parameter(),
246 : "Output variable ",
247 120 : v.name(),
248 : " must be on one of the following sub-axes: ",
249 : STATE,
250 : ", ",
251 : FORCES,
252 : ", ",
253 : RESIDUAL,
254 : ", ",
255 : PARAMETERS,
256 : ".");
257 120 : }
258 :
259 : void
260 482 : Model::link_input_variables()
261 : {
262 756 : for (auto & submodel : _registered_models)
263 : {
264 274 : link_input_variables(submodel.get());
265 274 : submodel->link_input_variables();
266 : }
267 482 : }
268 :
269 : void
270 14 : Model::link_input_variables(Model * submodel)
271 : {
272 98 : for (auto && [name, var] : submodel->input_variables())
273 84 : var->ref(input_variable(name), submodel->is_nonlinear_system());
274 14 : }
275 :
276 : void
277 482 : Model::link_output_variables()
278 : {
279 756 : for (auto & submodel : _registered_models)
280 : {
281 274 : link_output_variables(submodel.get());
282 274 : submodel->link_output_variables();
283 : }
284 482 : }
285 :
286 : void
287 14 : Model::link_output_variables(Model * /*submodel*/)
288 : {
289 14 : }
290 :
291 : void
292 14 : Model::request_AD(VariableBase & y, const VariableBase & u)
293 : {
294 14 : neml_assert(_defines_value,
295 : "Model of type '",
296 14 : type(),
297 : "' is requesting automatic differentiation of first derivatives, but it does not "
298 : "define output values.");
299 14 : _defines_dvalue = true;
300 14 : _ad_derivs[&y].insert(&u);
301 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
302 14 : _ad_args.insert(const_cast<VariableBase *>(&u));
303 14 : }
304 :
305 : void
306 48 : Model::request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2)
307 : {
308 48 : neml_assert(_defines_dvalue,
309 : "Model of type '",
310 48 : type(),
311 : "' is requesting automatic differentiation of second derivatives, but it does not "
312 : "define first derivatives.");
313 48 : _defines_d2value = true;
314 48 : _ad_secderivs[&y][&u1].insert(&u2);
315 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
316 48 : _ad_args.insert(const_cast<VariableBase *>(&u2));
317 48 : }
318 :
319 : void
320 9219 : Model::clear_input()
321 : {
322 9219 : VariableStore::clear_input();
323 14646 : for (auto & submodel : _registered_models)
324 5427 : submodel->clear_input();
325 9219 : }
326 :
327 : void
328 9219 : Model::clear_output()
329 : {
330 9219 : VariableStore::clear_output();
331 14646 : for (auto & submodel : _registered_models)
332 5427 : submodel->clear_output();
333 9219 : }
334 :
335 : void
336 9221 : Model::zero_input()
337 : {
338 9221 : VariableStore::zero_input();
339 14648 : for (auto & submodel : _registered_models)
340 5427 : submodel->zero_input();
341 9221 : }
342 :
343 : void
344 9221 : Model::zero_output()
345 : {
346 9221 : VariableStore::zero_output();
347 14648 : for (auto & submodel : _registered_models)
348 5427 : submodel->zero_output();
349 9221 : }
350 :
351 : Model::TraceSchema
352 6968 : Model::compute_trace_schema() const
353 : {
354 6968 : std::vector<Size> batch_dims;
355 43921 : for (auto && [name, var] : input_variables())
356 36953 : batch_dims.push_back(var->batch_dim());
357 23609 : for (auto && [name, param] : host<ParameterStore>()->named_parameters())
358 16641 : batch_dims.push_back(Tensor(*param).batch_dim());
359 :
360 6968 : const auto dispatch_key = variable_options().computeDispatchKey();
361 :
362 13936 : return TraceSchema{batch_dims, dispatch_key};
363 6968 : }
364 :
365 : std::size_t
366 6968 : Model::forward_operator_index(bool out, bool dout, bool d2out) const
367 : {
368 6968 : return (out ? 4 : 0) + (dout ? 2 : 0) + (d2out ? 1 : 0);
369 : }
370 :
371 : void
372 1213 : Model::forward(bool out, bool dout, bool d2out)
373 : {
374 1213 : neml_assert_dbg(defines_values() || (defines_values() == out),
375 : "Model of type '",
376 1213 : type(),
377 : "' is requested to compute output values, but it does not define them.");
378 1213 : neml_assert_dbg(defines_derivatives() || (defines_derivatives() == dout),
379 : "Model of type '",
380 1213 : type(),
381 : "' is requested to compute first derivatives, but it does not define them.");
382 1213 : neml_assert_dbg(defines_second_derivatives() || (defines_second_derivatives() == d2out),
383 : "Model of type '",
384 1213 : type(),
385 : "' is requested to compute second derivatives, but it does not define them.");
386 :
387 1213 : if (dout || d2out)
388 547 : clear_derivatives();
389 :
390 1213 : c10::InferenceMode mode_guard(_production && !jit::tracer::isTracing());
391 :
392 1213 : if (dout || d2out)
393 547 : enable_AD();
394 :
395 1213 : set_value(out || AD_need_value(dout, d2out), dout, d2out);
396 :
397 1213 : if (dout || d2out)
398 547 : extract_AD_derivatives(dout, d2out);
399 :
400 2426 : return;
401 1213 : }
402 :
403 : void
404 7727 : Model::forward_maybe_jit(bool out, bool dout, bool d2out)
405 : {
406 7727 : if (!is_jit_enabled() || jit::tracer::isTracing())
407 : {
408 759 : forward(out, dout, d2out);
409 759 : return;
410 : }
411 :
412 : auto & traced_functions =
413 6968 : currently_solving_nonlinear_system() ? _traced_functions_nl_sys : _traced_functions;
414 :
415 6968 : const auto forward_op_idx = forward_operator_index(out, dout, d2out);
416 6968 : const auto new_schema = compute_trace_schema();
417 6968 : auto traced_schema_and_function = traced_functions[forward_op_idx].find(new_schema);
418 :
419 6968 : if (traced_schema_and_function != traced_functions[forward_op_idx].end())
420 : {
421 6514 : auto & [trace_schema, traced_function] = *traced_schema_and_function;
422 6514 : c10::InferenceMode mode_guard(_production);
423 6514 : auto stack = collect_input_stack();
424 6514 : traced_function->run(stack);
425 6512 : assign_output_stack(stack, out, dout, d2out);
426 6516 : }
427 : else
428 : {
429 : // All other models in the world should wait for this model to finish tracing
430 : // This is not our fault, torch jit tracing is not thread-safe
431 454 : std::shared_ptr<jit::tracer::TracingState> trace;
432 : static std::mutex trace_mutex;
433 454 : trace_mutex.lock();
434 : try
435 : {
436 454 : auto forward_wrap = [&](jit::Stack inputs) -> jit::Stack
437 : {
438 454 : assign_input_stack(inputs);
439 454 : forward(out, dout, d2out);
440 454 : return collect_output_stack(out, dout, d2out);
441 454 : };
442 2270 : trace = std::get<0>(jit::tracer::trace(
443 908 : collect_input_stack(),
444 : forward_wrap,
445 42801 : [this](const ATensor & var) { return variable_name_lookup(var); },
446 : /*strict=*/false,
447 454 : /*force_outplace=*/false));
448 : }
449 0 : catch (const std::exception & e)
450 : {
451 0 : trace_mutex.unlock();
452 0 : throw NEMLException("Failed to trace model '" + name() + "': " + e.what());
453 0 : }
454 454 : trace_mutex.unlock();
455 :
456 908 : auto new_function = std::make_unique<jit::GraphFunction>(name() + ".forward",
457 454 : trace->graph,
458 0 : /*function_creator=*/nullptr,
459 908 : jit::ExecutorExecutionMode::PROFILING);
460 454 : traced_functions[forward_op_idx].emplace(new_schema, std::move(new_function));
461 :
462 : // Rerun this method -- this time using the jitted graph (without tracing)
463 454 : forward_maybe_jit(out, dout, d2out);
464 454 : }
465 6968 : }
466 :
467 : std::string
468 160272 : Model::variable_name_lookup(const ATensor & var)
469 : {
470 : // Look for the variable in the input and output variables
471 566448 : for (auto && [ivar, val] : input_variables())
472 407830 : if (val->tensor().data_ptr() == var.data_ptr())
473 1654 : return name() + "::" + utils::stringify(ivar);
474 340697 : for (auto && [ovar, val] : output_variables())
475 182420 : if (val->tensor().data_ptr() == var.data_ptr())
476 341 : return name() + "::" + utils::stringify(ovar);
477 :
478 : // Look for the variable in the parameter and buffer store
479 1101077 : for (auto && [pname, pval] : host<ParameterStore>()->named_parameters())
480 943646 : if (Tensor(*pval).data_ptr() == var.data_ptr())
481 846 : return name() + "::" + utils::stringify(pname);
482 858975 : for (auto && [bname, bval] : host<BufferStore>()->named_buffers())
483 701705 : if (Tensor(*bval).data_ptr() == var.data_ptr())
484 161 : return name() + "::" + utils::stringify(bname);
485 :
486 : // Look for the variable in the registered models
487 275253 : for (auto & submodel : registered_models())
488 : {
489 118379 : auto name = submodel->variable_name_lookup(var);
490 118379 : if (!name.empty())
491 396 : return name;
492 118379 : }
493 :
494 313748 : return "";
495 : }
496 :
497 : void
498 3794 : Model::check_precision() const
499 : {
500 3794 : if (settings().require_double_precision())
501 3794 : neml_assert(
502 7588 : default_tensor_options().dtype() == kFloat64,
503 : "By default, NEML2 requires double precision for all computations. Please set the default "
504 : "dtype to Float64. In Python, this can be done by calling "
505 : "`torch.set_default_dtype(torch.double)`. In C++, this can be done by calling "
506 : "`neml2::set_default_dtype(neml2::kFloat64)`. If other precisions are truly needed, you "
507 : "can disable this error check with Settings/require_double_precision=false.");
508 3794 : }
509 :
510 : ValueMap
511 2678 : Model::value(const ValueMap & in)
512 : {
513 2678 : forward_helper(in, true, false, false);
514 :
515 2676 : auto values = collect_output();
516 2676 : clear_input();
517 2676 : clear_output();
518 2676 : return values;
519 0 : }
520 :
521 : ValueMap
522 2 : Model::value(ValueMap && in)
523 : {
524 2 : forward_helper(std::move(in), true, false, false);
525 :
526 2 : auto values = collect_output();
527 2 : clear_input();
528 2 : clear_output();
529 2 : return values;
530 0 : }
531 :
532 : std::tuple<ValueMap, DerivMap>
533 0 : Model::value_and_dvalue(const ValueMap & in)
534 : {
535 0 : forward_helper(in, true, true, false);
536 :
537 0 : const auto values = collect_output();
538 0 : const auto derivs = collect_output_derivatives();
539 0 : clear_input();
540 0 : clear_output();
541 0 : return {values, derivs};
542 0 : }
543 :
544 : std::tuple<ValueMap, DerivMap>
545 0 : Model::value_and_dvalue(ValueMap && in)
546 : {
547 0 : forward_helper(std::move(in), true, true, false);
548 :
549 0 : const auto values = collect_output();
550 0 : const auto derivs = collect_output_derivatives();
551 0 : clear_input();
552 0 : clear_output();
553 0 : return {values, derivs};
554 0 : }
555 :
556 : DerivMap
557 1052 : Model::dvalue(const ValueMap & in)
558 : {
559 1052 : forward_helper(in, false, true, false);
560 :
561 1052 : auto derivs = collect_output_derivatives();
562 1052 : clear_input();
563 1052 : clear_output();
564 1052 : return derivs;
565 0 : }
566 :
567 : DerivMap
568 0 : Model::dvalue(ValueMap && in)
569 : {
570 0 : forward_helper(std::move(in), false, true, false);
571 :
572 0 : auto derivs = collect_output_derivatives();
573 0 : clear_input();
574 0 : clear_output();
575 0 : return derivs;
576 0 : }
577 :
578 : std::tuple<ValueMap, DerivMap, SecDerivMap>
579 0 : Model::value_and_dvalue_and_d2value(const ValueMap & in)
580 : {
581 0 : forward_helper(in, true, true, true);
582 :
583 0 : const auto values = collect_output();
584 0 : const auto derivs = collect_output_derivatives();
585 0 : const auto secderivs = collect_output_second_derivatives();
586 0 : clear_input();
587 0 : clear_output();
588 0 : return {values, derivs, secderivs};
589 0 : }
590 :
591 : std::tuple<ValueMap, DerivMap, SecDerivMap>
592 0 : Model::value_and_dvalue_and_d2value(ValueMap && in)
593 : {
594 0 : forward_helper(std::move(in), true, true, true);
595 :
596 0 : const auto values = collect_output();
597 0 : const auto derivs = collect_output_derivatives();
598 0 : const auto secderivs = collect_output_second_derivatives();
599 0 : clear_input();
600 0 : clear_output();
601 0 : return {values, derivs, secderivs};
602 0 : }
603 :
604 : std::tuple<DerivMap, SecDerivMap>
605 0 : Model::dvalue_and_d2value(const ValueMap & in)
606 : {
607 0 : forward_helper(in, false, true, true);
608 :
609 0 : const auto derivs = collect_output_derivatives();
610 0 : const auto secderivs = collect_output_second_derivatives();
611 0 : clear_input();
612 0 : clear_output();
613 0 : return {derivs, secderivs};
614 0 : }
615 :
616 : std::tuple<DerivMap, SecDerivMap>
617 0 : Model::dvalue_and_d2value(ValueMap && in)
618 : {
619 0 : forward_helper(std::move(in), false, true, true);
620 :
621 0 : const auto derivs = collect_output_derivatives();
622 0 : const auto secderivs = collect_output_second_derivatives();
623 0 : clear_input();
624 0 : clear_output();
625 0 : return {derivs, secderivs};
626 0 : }
627 :
628 : SecDerivMap
629 62 : Model::d2value(const ValueMap & in)
630 : {
631 62 : forward_helper(in, false, false, true);
632 :
633 62 : auto secderivs = collect_output_second_derivatives();
634 62 : clear_input();
635 62 : clear_output();
636 62 : return secderivs;
637 0 : }
638 :
639 : SecDerivMap
640 0 : Model::d2value(ValueMap && in)
641 : {
642 0 : forward_helper(std::move(in), false, false, true);
643 :
644 0 : auto secderivs = collect_output_second_derivatives();
645 0 : clear_input();
646 0 : clear_output();
647 0 : return secderivs;
648 0 : }
649 :
650 : std::shared_ptr<Model>
651 260 : Model::registered_model(const std::string & name) const
652 : {
653 861 : for (auto & submodel : _registered_models)
654 861 : if (submodel->name() == name)
655 260 : return submodel;
656 :
657 0 : throw NEMLException("There is no registered model named '" + name + "' in '" + this->name() +
658 0 : "'");
659 : }
660 :
661 : void
662 78 : Model::register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param)
663 : {
664 78 : neml_assert(_nl_params.count(pname) == 0,
665 : "Nonlinear parameter named '",
666 : pname,
667 : "' has already been registered.");
668 78 : _nl_params[pname] = param;
669 78 : }
670 :
671 : bool
672 0 : Model::has_nl_param(bool recursive) const
673 : {
674 0 : if (!recursive)
675 0 : return !_nl_params.empty();
676 :
677 0 : for (auto & submodel : registered_models())
678 0 : if (submodel->has_nl_param(true))
679 0 : return true;
680 :
681 0 : return false;
682 : }
683 :
684 : const VariableBase *
685 385 : Model::nl_param(const std::string & name) const
686 : {
687 385 : return _nl_params.count(name) ? _nl_params.at(name).value : nullptr;
688 : }
689 :
690 : std::map<std::string, NonlinearParameter>
691 597 : Model::named_nonlinear_parameters(bool recursive) const
692 : {
693 597 : if (!recursive)
694 79 : return _nl_params;
695 :
696 518 : auto all_nl_params = _nl_params;
697 :
698 674 : for (const auto & [pname, param] : _nl_params)
699 156 : for (auto && [pname, nl_param] : param.provider->named_nonlinear_parameters(true))
700 0 : all_nl_params[param.provider->name() + settings().parameter_name_separator() + pname] =
701 156 : nl_param;
702 :
703 532 : for (auto & submodel : registered_models())
704 14 : for (auto && [pname, nl_param] : submodel->named_nonlinear_parameters(true))
705 14 : all_nl_params[submodel->name() + settings().parameter_name_separator() + pname] = nl_param;
706 :
707 518 : return all_nl_params;
708 518 : }
709 :
710 : std::set<VariableName>
711 260 : Model::consumed_items() const
712 : {
713 260 : auto items = input_axis().variable_names();
714 520 : return {items.begin(), items.end()};
715 260 : }
716 :
717 : std::set<VariableName>
718 260 : Model::provided_items() const
719 : {
720 260 : auto items = output_axis().variable_names();
721 520 : return {items.begin(), items.end()};
722 260 : }
723 :
724 : void
725 454 : Model::assign_input_stack(jit::Stack & stack)
726 : {
727 : #ifndef NDEBUG
728 454 : const auto nstack = input_axis().nvariable() + host<ParameterStore>()->named_parameters().size();
729 454 : neml_assert_dbg(
730 454 : stack.size() == nstack,
731 : "Stack size (",
732 454 : stack.size(),
733 : ") must equal to the number of input variables, parameters, and buffers in the model (",
734 : nstack,
735 : ").");
736 : #endif
737 :
738 454 : assign_parameter_stack(stack);
739 454 : VariableStore::assign_input_stack(stack);
740 454 : }
741 :
742 : jit::Stack
743 6968 : Model::collect_input_stack() const
744 : {
745 6968 : auto stack = VariableStore::collect_input_stack();
746 6968 : const auto param_stack = collect_parameter_stack();
747 :
748 : // Recall stack is first in last out.
749 : // Parameter stack go after (on top of) input variables. This means that in assign_input_stack
750 : // we need to pop parameters first, then input variables.
751 6968 : stack.insert(stack.end(), param_stack.begin(), param_stack.end());
752 13936 : return stack;
753 6968 : }
754 :
755 : void
756 1572 : Model::set_guess(const Sol<false> & x)
757 : {
758 1572 : const auto sol_assember = VectorAssembler(input_axis().subaxis(STATE));
759 1572 : assign_input(sol_assember.split_by_variable(x));
760 1572 : }
761 :
762 : void
763 2823 : Model::assemble(NonlinearSystem::Res<false> * residual, NonlinearSystem::Jac<false> * Jacobian)
764 : {
765 2823 : forward_maybe_jit(residual, Jacobian, false);
766 :
767 2823 : if (residual)
768 : {
769 1572 : const auto res_assembler = VectorAssembler(output_axis().subaxis(RESIDUAL));
770 1572 : *residual = Res<false>(res_assembler.assemble_by_variable(collect_output()));
771 : }
772 2823 : if (Jacobian)
773 : {
774 : const auto jac_assembler =
775 1251 : MatrixAssembler(output_axis().subaxis(RESIDUAL), input_axis().subaxis(STATE));
776 1251 : *Jacobian = Jac<false>(jac_assembler.assemble_by_variable(collect_output_derivatives()));
777 : }
778 2823 : }
779 :
780 : bool
781 266 : Model::AD_need_value(bool dout, bool d2out) const
782 : {
783 266 : if (dout)
784 203 : if (!_ad_derivs.empty())
785 2 : return true;
786 :
787 264 : if (d2out)
788 66 : for (auto && [y, u1u2s] : _ad_secderivs)
789 0 : for (auto && [u1, u2s] : u1u2s)
790 0 : if (_ad_derivs.count(y) && _ad_derivs.at(y).count(u1))
791 0 : return true;
792 :
793 264 : return false;
794 : }
795 :
796 : void
797 547 : Model::enable_AD()
798 : {
799 553 : for (auto * ad_arg : _ad_args)
800 6 : ad_arg->requires_grad_();
801 547 : }
802 :
803 : void
804 547 : Model::extract_AD_derivatives(bool dout, bool d2out)
805 : {
806 547 : neml_assert(dout || d2out, "At least one of the output derivatives must be requested.");
807 :
808 551 : for (auto && [y, us] : _ad_derivs)
809 : {
810 4 : if (!dout && d2out)
811 0 : if (!_ad_secderivs.count(y))
812 0 : continue;
813 :
814 : // Gather all dependent variables
815 4 : std::vector<Tensor> uts;
816 18 : for (const auto * u : us)
817 14 : if (u->is_dependent())
818 14 : uts.push_back(u->tensor());
819 :
820 : // Check if we need to create the graph (i.e., if any of the second derivatives are requested)
821 4 : bool create_graph = false;
822 18 : for (const auto * u : us)
823 14 : if (u->is_dependent())
824 14 : if (!create_graph && !dout && d2out)
825 0 : if (_ad_secderivs.at(y).count(u))
826 0 : create_graph = true;
827 :
828 4 : const auto dy_dus = jacrev(y->tensor(),
829 : uts,
830 : /*retain_graph=*/true,
831 : /*create_graph=*/create_graph,
832 4 : /*allow_unused=*/true);
833 :
834 4 : std::size_t i = 0;
835 18 : for (const auto * u : us)
836 14 : if (u->is_dependent())
837 : {
838 14 : if (dy_dus[i].defined())
839 14 : y->d(*u) = dy_dus[i];
840 14 : i++;
841 : }
842 4 : }
843 :
844 547 : if (d2out)
845 : {
846 114 : for (auto && [y, u1u2s] : _ad_secderivs)
847 0 : for (auto && [u1, u2s] : u1u2s)
848 : {
849 0 : if (!u1->is_dependent())
850 0 : continue;
851 :
852 0 : const auto & dy_du1 = y->derivatives()[u1->name()];
853 :
854 0 : if (!dy_du1.defined() || !dy_du1.requires_grad())
855 0 : continue;
856 :
857 0 : std::vector<Tensor> u2ts;
858 0 : for (const auto * u2 : u2s)
859 0 : if (u2->is_dependent())
860 0 : u2ts.push_back(u2->tensor());
861 :
862 : const auto d2y_du1u2s = jacrev(dy_du1,
863 : u2ts,
864 : /*retain_graph=*/true,
865 : /*create_graph=*/false,
866 0 : /*allow_unused=*/true);
867 :
868 0 : std::size_t i = 0;
869 0 : for (const auto * u2 : u2s)
870 0 : if (u2->is_dependent())
871 : {
872 0 : if (d2y_du1u2s[i].defined())
873 0 : y->d(*u1, *u2) = d2y_du1u2s[i];
874 0 : i++;
875 : }
876 0 : }
877 : }
878 547 : }
879 :
880 : // LCOV_EXCL_START
881 : std::ostream &
882 : operator<<(std::ostream & os, const Model & model)
883 : {
884 : bool first = false;
885 : const std::string tab = " ";
886 :
887 : os << "Name: " << model.name() << '\n';
888 :
889 : if (!model.input_variables().empty())
890 : {
891 : os << "Input: ";
892 : first = true;
893 : for (auto && [name, var] : model.input_variables())
894 : {
895 : os << (first ? "" : tab);
896 : os << name << " [" << var->type() << "]\n";
897 : first = false;
898 : }
899 : }
900 :
901 : if (!model.input_variables().empty())
902 : {
903 : os << "Output: ";
904 : first = true;
905 : for (auto && [name, var] : model.output_variables())
906 : {
907 : os << (first ? "" : tab);
908 : os << name << " [" << var->type() << "]\n";
909 : first = false;
910 : }
911 : }
912 :
913 : if (!model.named_parameters().empty())
914 : {
915 : os << "Parameters: ";
916 : first = true;
917 : for (auto && [name, param] : model.named_parameters())
918 : {
919 : os << (first ? "" : tab);
920 : os << name << " [" << param->type() << "][" << Tensor(*param).scalar_type() << "]["
921 : << Tensor(*param).device() << "]\n";
922 : first = false;
923 : }
924 : }
925 :
926 : if (!model.named_buffers().empty())
927 : {
928 : os << "Buffers: ";
929 : first = true;
930 : for (auto && [name, buffer] : model.named_buffers())
931 : {
932 : os << (first ? "" : tab);
933 : os << name << " [" << buffer->type() << "][" << Tensor(*buffer).scalar_type() << "]["
934 : << Tensor(*buffer).device() << "]\n";
935 : first = false;
936 : }
937 : }
938 :
939 : return os;
940 : }
941 : // LCOV_EXCL_STOP
942 : } // namespace neml2
|