Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include <c10/core/InferenceMode.h>
26 :
27 : #include "neml2/misc/assertions.h"
28 : #include "neml2/base/Factory.h"
29 : #include "neml2/base/Settings.h"
30 : #include "neml2/jit/utils.h"
31 : #include "neml2/tensors/functions/jacrev.h"
32 : #include "neml2/tensors/tensors.h"
33 : #include "neml2/tensors/TensorValue.h"
34 : #include "neml2/models/Model.h"
35 : #include "neml2/models/Assembler.h"
36 : #include "neml2/models/map_types_fwd.h"
37 :
38 : namespace neml2
39 : {
40 : std::shared_ptr<Model>
41 12 : load_model(const std::filesystem::path & path, const std::string & mname)
42 : {
43 12 : auto factory = load_input(path);
44 24 : return factory->get_model(mname);
45 12 : }
46 :
47 : bool
48 0 : Model::TraceSchema::operator==(const TraceSchema & other) const
49 : {
50 0 : return batch_dims == other.batch_dims && dispatch_key == other.dispatch_key;
51 : }
52 :
53 : bool
54 11944 : Model::TraceSchema::operator<(const TraceSchema & other) const
55 : {
56 11944 : if (dispatch_key != other.dispatch_key)
57 0 : return dispatch_key < other.dispatch_key;
58 11944 : return batch_dims < other.batch_dims;
59 : }
60 :
61 : OptionSet
62 463 : Model::expected_options()
63 : {
64 463 : OptionSet options = Data::expected_options();
65 463 : options += NonlinearSystem::expected_options();
66 463 : NonlinearSystem::disable_automatic_scaling(options);
67 :
68 463 : options.section() = "Models";
69 :
70 : // Model defaults to defining value and dvalue, but not d2value
71 926 : options.set<bool>("define_values") = true;
72 926 : options.set<bool>("define_derivatives") = true;
73 926 : options.set<bool>("define_second_derivatives") = false;
74 926 : options.set("define_values").suppressed() = true;
75 926 : options.set("define_derivatives").suppressed() = true;
76 926 : options.set("define_second_derivatives").suppressed() = true;
77 :
78 : // Model defaults to _not_ being part of a nonlinear system
79 : // Model::get_model will set this to true if the model is expected to be part of a nonlinear
80 : // system, and additional diagnostics will be performed
81 926 : options.set<bool>("_nonlinear_system") = false;
82 926 : options.set("_nonlinear_system").suppressed() = true;
83 :
84 926 : options.set<bool>("jit") = true;
85 926 : options.set("jit").doc() = "Use JIT compilation for the forward operator";
86 :
87 926 : options.set<bool>("production") = false;
88 926 : options.set("production").doc() =
89 : "Production mode. This option is used to disable features like function graph tracking and "
90 : "tensor version tracking which are useful for training (i.e., calibrating model parameters) "
91 463 : "but are not necessary in production runs.";
92 :
93 463 : return options;
94 0 : }
95 :
96 432 : Model::Model(const OptionSet & options)
97 : : Data(options),
98 : ParameterStore(this),
99 : VariableStore(this),
100 : NonlinearSystem(options),
101 : DiagnosticsInterface(this),
102 864 : _defines_value(options.get<bool>("define_values")),
103 864 : _defines_dvalue(options.get<bool>("define_derivatives")),
104 864 : _defines_d2value(options.get<bool>("define_second_derivatives")),
105 432 : _nonlinear_system(options.get<bool>("_nonlinear_system")),
106 864 : _jit(options.get<bool>("jit")),
107 1728 : _production(options.get<bool>("production"))
108 : {
109 432 : }
110 :
111 : void
112 91 : Model::to(const TensorOptions & options)
113 : {
114 91 : send_buffers_to(options);
115 91 : send_parameters_to(options);
116 91 : send_variables_to(options);
117 :
118 174 : for (auto & submodel : registered_models())
119 83 : submodel->to(options);
120 :
121 92 : for (auto & [name, param] : named_nonlinear_parameters())
122 92 : param.provider->to(options);
123 91 : }
124 :
125 : void
126 432 : Model::setup()
127 : {
128 432 : setup_layout();
129 :
130 432 : if (host() == this)
131 : {
132 183 : link_output_variables();
133 183 : link_input_variables();
134 : }
135 :
136 432 : request_AD();
137 432 : }
138 :
139 : void
140 96 : Model::diagnose() const
141 : {
142 182 : for (auto & submodel : registered_models())
143 86 : neml2::diagnose(*submodel);
144 :
145 : // Make sure variables are defined on the reserved subaxes
146 391 : for (auto && [name, var] : input_variables())
147 295 : diagnostic_check_input_variable(*var);
148 213 : for (auto && [name, var] : output_variables())
149 117 : diagnostic_check_output_variable(*var);
150 :
151 96 : if (is_nonlinear_system())
152 5 : diagnose_nl_sys();
153 96 : }
154 :
155 : void
156 72 : Model::diagnose_nl_sys() const
157 : {
158 139 : for (auto & submodel : registered_models())
159 67 : submodel->diagnose_nl_sys();
160 :
161 : // Check if any input variable is solve-dependent
162 72 : bool input_solve_dep = false;
163 253 : for (auto && [name, var] : input_variables())
164 181 : if (var->is_solve_dependent())
165 109 : input_solve_dep = true;
166 :
167 : // If any input variable is solve-dependent, ALL output variables must be solve-dependent!
168 72 : if (input_solve_dep)
169 145 : for (auto && [name, var] : output_variables())
170 76 : diagnostic_assert(
171 76 : var->is_solve_dependent(),
172 : "This model is part of a nonlinear system. At least one of the input variables is "
173 : "solve-dependent, so all output variables MUST be solve-dependent, i.e., they must be "
174 : "on one of the following sub-axes: state, residual, parameters. However, got output "
175 : "variable ",
176 : name);
177 72 : }
178 :
179 : void
180 26 : Model::diagnostic_assert_state(const VariableBase & v) const
181 : {
182 26 : diagnostic_assert(v.is_state(), "Variable ", v.name(), " must be on the ", STATE, " sub-axis.");
183 26 : }
184 :
185 : void
186 0 : Model::diagnostic_assert_old_state(const VariableBase & v) const
187 : {
188 0 : diagnostic_assert(
189 0 : v.is_old_state(), "Variable ", v.name(), " must be on the ", OLD_STATE, " sub-axis.");
190 0 : }
191 :
192 : void
193 19 : Model::diagnostic_assert_force(const VariableBase & v) const
194 : {
195 19 : diagnostic_assert(v.is_force(), "Variable ", v.name(), " must be on the ", FORCES, " sub-axis.");
196 19 : }
197 :
198 : void
199 0 : Model::diagnostic_assert_old_force(const VariableBase & v) const
200 : {
201 0 : diagnostic_assert(
202 0 : v.is_old_force(), "Variable ", v.name(), " must be on the ", OLD_FORCES, " sub-axis.");
203 0 : }
204 :
205 : void
206 0 : Model::diagnostic_assert_residual(const VariableBase & v) const
207 : {
208 0 : diagnostic_assert(
209 0 : v.is_residual(), "Variable ", v.name(), " must be on the ", RESIDUAL, " sub-axis.");
210 0 : }
211 :
212 : void
213 295 : Model::diagnostic_check_input_variable(const VariableBase & v) const
214 : {
215 753 : diagnostic_assert(v.is_state() || v.is_old_state() || v.is_force() || v.is_old_force() ||
216 458 : v.is_residual() || v.is_parameter(),
217 : "Input variable ",
218 295 : v.name(),
219 : " must be on one of the following sub-axes: ",
220 : STATE,
221 : ", ",
222 : OLD_STATE,
223 : ", ",
224 : FORCES,
225 : ", ",
226 : OLD_FORCES,
227 : ", ",
228 : RESIDUAL,
229 : ", ",
230 : PARAMETERS,
231 : ".");
232 295 : }
233 :
234 : void
235 117 : Model::diagnostic_check_output_variable(const VariableBase & v) const
236 : {
237 117 : diagnostic_assert(v.is_state() || v.is_force() || v.is_residual() || v.is_parameter(),
238 : "Output variable ",
239 117 : v.name(),
240 : " must be on one of the following sub-axes: ",
241 : STATE,
242 : ", ",
243 : FORCES,
244 : ", ",
245 : RESIDUAL,
246 : ", ",
247 : PARAMETERS,
248 : ".");
249 117 : }
250 :
251 : void
252 433 : Model::link_input_variables()
253 : {
254 683 : for (auto & submodel : _registered_models)
255 : {
256 250 : link_input_variables(submodel.get());
257 250 : submodel->link_input_variables();
258 : }
259 433 : }
260 :
261 : void
262 13 : Model::link_input_variables(Model * submodel)
263 : {
264 88 : for (auto && [name, var] : submodel->input_variables())
265 75 : var->ref(input_variable(name), submodel->is_nonlinear_system());
266 13 : }
267 :
268 : void
269 433 : Model::link_output_variables()
270 : {
271 683 : for (auto & submodel : _registered_models)
272 : {
273 250 : link_output_variables(submodel.get());
274 250 : submodel->link_output_variables();
275 : }
276 433 : }
277 :
278 : void
279 13 : Model::link_output_variables(Model * /*submodel*/)
280 : {
281 13 : }
282 :
283 : void
284 14 : Model::request_AD(VariableBase & y, const VariableBase & u)
285 : {
286 14 : neml_assert(_defines_value,
287 : "Model of type '",
288 14 : type(),
289 : "' is requesting automatic differentiation of first derivatives, but it does not "
290 : "define output values.");
291 14 : _defines_dvalue = true;
292 14 : _ad_derivs[&y].insert(&u);
293 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
294 14 : _ad_args.insert(const_cast<VariableBase *>(&u));
295 14 : }
296 :
297 : void
298 48 : Model::request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2)
299 : {
300 48 : neml_assert(_defines_dvalue,
301 : "Model of type '",
302 48 : type(),
303 : "' is requesting automatic differentiation of second derivatives, but it does not "
304 : "define first derivatives.");
305 48 : _defines_d2value = true;
306 48 : _ad_secderivs[&y][&u1].insert(&u2);
307 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
308 48 : _ad_args.insert(const_cast<VariableBase *>(&u2));
309 48 : }
310 :
311 : void
312 8451 : Model::clear_input()
313 : {
314 8451 : VariableStore::clear_input();
315 13636 : for (auto & submodel : _registered_models)
316 5185 : submodel->clear_input();
317 8451 : }
318 :
319 : void
320 8451 : Model::clear_output()
321 : {
322 8451 : VariableStore::clear_output();
323 13636 : for (auto & submodel : _registered_models)
324 5185 : submodel->clear_output();
325 8451 : }
326 :
327 : void
328 8451 : Model::zero_input()
329 : {
330 8451 : VariableStore::zero_input();
331 13636 : for (auto & submodel : _registered_models)
332 5185 : submodel->zero_input();
333 8451 : }
334 :
335 : void
336 8451 : Model::zero_output()
337 : {
338 8451 : VariableStore::zero_output();
339 13636 : for (auto & submodel : _registered_models)
340 5185 : submodel->zero_output();
341 8451 : }
342 :
343 : Model::TraceSchema
344 6366 : Model::compute_trace_schema() const
345 : {
346 6366 : std::vector<Size> batch_dims;
347 42258 : for (auto && [name, var] : input_variables())
348 35892 : batch_dims.push_back(var->batch_dim());
349 21861 : for (auto && [name, param] : host<ParameterStore>()->named_parameters())
350 15495 : batch_dims.push_back(Tensor(*param).batch_dim());
351 :
352 6366 : const auto dispatch_key = variable_options().computeDispatchKey();
353 :
354 12732 : return TraceSchema{batch_dims, dispatch_key};
355 6366 : }
356 :
357 : std::size_t
358 6366 : Model::forward_operator_index(bool out, bool dout, bool d2out) const
359 : {
360 6366 : return (out ? 4 : 0) + (dout ? 2 : 0) + (d2out ? 1 : 0);
361 : }
362 :
363 : void
364 0 : Model::register_callback(const ModelCallback & callback)
365 : {
366 0 : _callbacks.push_back(callback);
367 0 : }
368 :
369 : void
370 0 : Model::register_callback_recursive(const ModelCallback & callback)
371 : {
372 0 : register_callback(callback);
373 :
374 0 : for (auto & submodel : registered_models())
375 0 : submodel->register_callback_recursive(callback);
376 0 : }
377 :
378 : void
379 1113 : Model::forward(bool out, bool dout, bool d2out)
380 : {
381 1113 : neml_assert_dbg(defines_values() || (defines_values() == out),
382 : "Model of type '",
383 1113 : type(),
384 : "' is requested to compute output values, but it does not define them.");
385 1113 : neml_assert_dbg(defines_derivatives() || (defines_derivatives() == dout),
386 : "Model of type '",
387 1113 : type(),
388 : "' is requested to compute first derivatives, but it does not define them.");
389 1113 : neml_assert_dbg(defines_second_derivatives() || (defines_second_derivatives() == d2out),
390 : "Model of type '",
391 1113 : type(),
392 : "' is requested to compute second derivatives, but it does not define them.");
393 :
394 1113 : c10::InferenceMode mode_guard(_production && !jit::tracer::isTracing());
395 :
396 1113 : if (dout || d2out)
397 491 : enable_AD();
398 :
399 1113 : set_value(out || AD_need_value(dout, d2out), dout, d2out);
400 :
401 1113 : if (dout || d2out)
402 491 : extract_AD_derivatives(dout, d2out);
403 :
404 : // Call the callbacks
405 1113 : call_callbacks();
406 :
407 2226 : return;
408 1113 : }
409 :
410 : void
411 7085 : Model::forward_maybe_jit(bool out, bool dout, bool d2out)
412 : {
413 7085 : if (!is_jit_enabled() || jit::tracer::isTracing())
414 : {
415 719 : forward(out, dout, d2out);
416 719 : return;
417 : }
418 :
419 : auto & traced_functions =
420 6366 : currently_solving_nonlinear_system() ? _traced_functions_nl_sys : _traced_functions;
421 :
422 6366 : const auto forward_op_idx = forward_operator_index(out, dout, d2out);
423 6366 : const auto new_schema = compute_trace_schema();
424 6366 : auto traced_schema_and_function = traced_functions[forward_op_idx].find(new_schema);
425 :
426 6366 : if (traced_schema_and_function != traced_functions[forward_op_idx].end())
427 : {
428 5972 : auto & [trace_schema, traced_function] = *traced_schema_and_function;
429 5972 : c10::InferenceMode mode_guard(_production);
430 5972 : auto stack = collect_input_stack();
431 5972 : traced_function->run(stack);
432 5972 : assign_output_stack(stack, out, dout, d2out);
433 5972 : }
434 : else
435 : {
436 : // All other models in the world should wait for this model to finish tracing
437 : // This is not our fault, torch jit tracing is not thread-safe
438 : static std::mutex trace_mutex;
439 394 : trace_mutex.lock();
440 394 : auto forward_wrap = [&](jit::Stack inputs) -> jit::Stack
441 : {
442 394 : assign_input_stack(inputs);
443 394 : forward(out, dout, d2out);
444 394 : return collect_output_stack(out, dout, d2out);
445 394 : };
446 1970 : auto trace = std::get<0>(jit::tracer::trace(
447 788 : collect_input_stack(),
448 : forward_wrap,
449 31019 : [this](const ATensor & var) { return variable_name_lookup(var); },
450 : /*strict=*/false,
451 394 : /*force_outplace=*/false));
452 394 : trace_mutex.unlock();
453 :
454 788 : auto new_function = std::make_unique<jit::GraphFunction>(name() + ".forward",
455 394 : trace->graph,
456 0 : /*function_creator=*/nullptr,
457 788 : jit::ExecutorExecutionMode::PROFILING);
458 394 : traced_functions[forward_op_idx].emplace(new_schema, std::move(new_function));
459 :
460 : // Rerun this method -- this time using the jitted graph (without tracing)
461 394 : forward_maybe_jit(out, dout, d2out);
462 394 : }
463 6366 : }
464 :
465 : std::string
466 147909 : Model::variable_name_lookup(const ATensor & var)
467 : {
468 : // Look for the variable in the input and output variables
469 535920 : for (auto && [ivar, val] : input_variables())
470 389414 : if (val->tensor().data_ptr() == var.data_ptr())
471 1403 : return name() + "::" + utils::stringify(ivar);
472 307991 : for (auto && [ovar, val] : output_variables())
473 161819 : if (val->tensor().data_ptr() == var.data_ptr())
474 334 : return name() + "::" + utils::stringify(ovar);
475 :
476 : // Look for the variable in the parameter and buffer store
477 1069026 : for (auto && [pname, pval] : host<ParameterStore>()->named_parameters())
478 923442 : if (Tensor(*pval).data_ptr() == var.data_ptr())
479 588 : return name() + "::" + utils::stringify(pname);
480 980789 : for (auto && [bname, bval] : host<BufferStore>()->named_buffers())
481 835388 : if (Tensor(*bval).data_ptr() == var.data_ptr())
482 183 : return name() + "::" + utils::stringify(bname);
483 :
484 : // Look for the variable in the registered models
485 262689 : for (auto & submodel : registered_models())
486 : {
487 117678 : auto name = submodel->variable_name_lookup(var);
488 117678 : if (!name.empty())
489 390 : return name;
490 117678 : }
491 :
492 290022 : return "";
493 : }
494 :
495 : void
496 3266 : Model::check_precision() const
497 : {
498 3266 : if (settings().require_double_precision())
499 3266 : neml_assert(
500 6532 : default_tensor_options().dtype() == kFloat64,
501 : "By default, NEML2 requires double precision for all computations. Please set the default "
502 : "dtype to Float64. In Python, this can be done by calling "
503 : "`torch.set_default_dtype(torch.double)`. In C++, this can be done by calling "
504 : "`neml2::set_default_dtype(neml2::kFloat64)`. If other precisions are truly needed, you "
505 : "can disable this error check with Settings/require_double_precision=false.");
506 3266 : }
507 :
508 : ValueMap
509 2353 : Model::value(const ValueMap & in)
510 : {
511 2353 : forward_helper(in, true, false, false);
512 :
513 2353 : auto values = collect_output();
514 2353 : clear_input();
515 2353 : clear_output();
516 2353 : return values;
517 0 : }
518 :
519 : ValueMap
520 2 : Model::value(ValueMap && in)
521 : {
522 2 : forward_helper(std::move(in), true, false, false);
523 :
524 2 : auto values = collect_output();
525 2 : clear_input();
526 2 : clear_output();
527 2 : return values;
528 0 : }
529 :
530 : std::tuple<ValueMap, DerivMap>
531 0 : Model::value_and_dvalue(const ValueMap & in)
532 : {
533 0 : forward_helper(in, true, true, false);
534 :
535 0 : const auto values = collect_output();
536 0 : const auto derivs = collect_output_derivatives();
537 0 : clear_input();
538 0 : clear_output();
539 0 : return {values, derivs};
540 0 : }
541 :
542 : std::tuple<ValueMap, DerivMap>
543 0 : Model::value_and_dvalue(ValueMap && in)
544 : {
545 0 : forward_helper(std::move(in), true, true, false);
546 :
547 0 : const auto values = collect_output();
548 0 : const auto derivs = collect_output_derivatives();
549 0 : clear_input();
550 0 : clear_output();
551 0 : return {values, derivs};
552 0 : }
553 :
554 : DerivMap
555 863 : Model::dvalue(const ValueMap & in)
556 : {
557 863 : forward_helper(in, false, true, false);
558 :
559 863 : auto derivs = collect_output_derivatives();
560 863 : clear_input();
561 863 : clear_output();
562 863 : return derivs;
563 0 : }
564 :
565 : DerivMap
566 0 : Model::dvalue(ValueMap && in)
567 : {
568 0 : forward_helper(std::move(in), false, true, false);
569 :
570 0 : auto derivs = collect_output_derivatives();
571 0 : clear_input();
572 0 : clear_output();
573 0 : return derivs;
574 0 : }
575 :
576 : std::tuple<ValueMap, DerivMap, SecDerivMap>
577 0 : Model::value_and_dvalue_and_d2value(const ValueMap & in)
578 : {
579 0 : forward_helper(in, true, true, true);
580 :
581 0 : const auto values = collect_output();
582 0 : const auto derivs = collect_output_derivatives();
583 0 : const auto secderivs = collect_output_second_derivatives();
584 0 : clear_input();
585 0 : clear_output();
586 0 : return {values, derivs, secderivs};
587 0 : }
588 :
589 : std::tuple<ValueMap, DerivMap, SecDerivMap>
590 0 : Model::value_and_dvalue_and_d2value(ValueMap && in)
591 : {
592 0 : forward_helper(std::move(in), true, true, true);
593 :
594 0 : const auto values = collect_output();
595 0 : const auto derivs = collect_output_derivatives();
596 0 : const auto secderivs = collect_output_second_derivatives();
597 0 : clear_input();
598 0 : clear_output();
599 0 : return {values, derivs, secderivs};
600 0 : }
601 :
602 : std::tuple<DerivMap, SecDerivMap>
603 0 : Model::dvalue_and_d2value(const ValueMap & in)
604 : {
605 0 : forward_helper(in, false, true, true);
606 :
607 0 : const auto derivs = collect_output_derivatives();
608 0 : const auto secderivs = collect_output_second_derivatives();
609 0 : clear_input();
610 0 : clear_output();
611 0 : return {derivs, secderivs};
612 0 : }
613 :
614 : std::tuple<DerivMap, SecDerivMap>
615 0 : Model::dvalue_and_d2value(ValueMap && in)
616 : {
617 0 : forward_helper(std::move(in), false, true, true);
618 :
619 0 : const auto derivs = collect_output_derivatives();
620 0 : const auto secderivs = collect_output_second_derivatives();
621 0 : clear_input();
622 0 : clear_output();
623 0 : return {derivs, secderivs};
624 0 : }
625 :
626 : SecDerivMap
627 48 : Model::d2value(const ValueMap & in)
628 : {
629 48 : forward_helper(in, false, false, true);
630 :
631 48 : auto secderivs = collect_output_second_derivatives();
632 48 : clear_input();
633 48 : clear_output();
634 48 : return secderivs;
635 0 : }
636 :
637 : SecDerivMap
638 0 : Model::d2value(ValueMap && in)
639 : {
640 0 : forward_helper(std::move(in), false, false, true);
641 :
642 0 : auto secderivs = collect_output_second_derivatives();
643 0 : clear_input();
644 0 : clear_output();
645 0 : return secderivs;
646 0 : }
647 :
648 : std::shared_ptr<Model>
649 237 : Model::registered_model(const std::string & name) const
650 : {
651 820 : for (auto & submodel : _registered_models)
652 820 : if (submodel->name() == name)
653 237 : return submodel;
654 :
655 0 : throw NEMLException("There is no registered model named '" + name + "' in '" + this->name() +
656 0 : "'");
657 : }
658 :
659 : void
660 66 : Model::register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param)
661 : {
662 66 : neml_assert(_nl_params.count(pname) == 0,
663 : "Nonlinear parameter named '",
664 : pname,
665 : "' has already been registered.");
666 66 : _nl_params[pname] = param;
667 66 : }
668 :
669 : bool
670 0 : Model::has_nl_param(bool recursive) const
671 : {
672 0 : if (!recursive)
673 0 : return !_nl_params.empty();
674 :
675 0 : for (auto & submodel : registered_models())
676 0 : if (submodel->has_nl_param(true))
677 0 : return true;
678 :
679 0 : return false;
680 : }
681 :
682 : const VariableBase *
683 345 : Model::nl_param(const std::string & name) const
684 : {
685 345 : return _nl_params.count(name) ? _nl_params.at(name).value : nullptr;
686 : }
687 :
688 : std::map<std::string, NonlinearParameter>
689 547 : Model::named_nonlinear_parameters(bool recursive) const
690 : {
691 547 : if (!recursive)
692 79 : return _nl_params;
693 :
694 468 : auto all_nl_params = _nl_params;
695 :
696 600 : for (const auto & [pname, param] : _nl_params)
697 132 : for (auto && [pname, nl_param] : param.provider->named_nonlinear_parameters(true))
698 0 : all_nl_params[param.provider->name() + settings().parameter_name_separator() + pname] =
699 132 : nl_param;
700 :
701 482 : for (auto & submodel : registered_models())
702 14 : for (auto && [pname, nl_param] : submodel->named_nonlinear_parameters(true))
703 14 : all_nl_params[submodel->name() + settings().parameter_name_separator() + pname] = nl_param;
704 :
705 468 : return all_nl_params;
706 468 : }
707 :
708 : std::set<VariableName>
709 237 : Model::consumed_items() const
710 : {
711 237 : auto items = input_axis().variable_names();
712 474 : return {items.begin(), items.end()};
713 237 : }
714 :
715 : std::set<VariableName>
716 237 : Model::provided_items() const
717 : {
718 237 : auto items = output_axis().variable_names();
719 474 : return {items.begin(), items.end()};
720 237 : }
721 :
722 : void
723 394 : Model::assign_input_stack(jit::Stack & stack)
724 : {
725 : #ifndef NDEBUG
726 394 : const auto nstack = input_axis().nvariable() + host<ParameterStore>()->named_parameters().size();
727 394 : neml_assert_dbg(
728 394 : stack.size() == nstack,
729 : "Stack size (",
730 394 : stack.size(),
731 : ") must equal to the number of input variables, parameters, and buffers in the model (",
732 : nstack,
733 : ").");
734 : #endif
735 :
736 394 : assign_parameter_stack(stack);
737 394 : VariableStore::assign_input_stack(stack);
738 394 : }
739 :
740 : jit::Stack
741 6366 : Model::collect_input_stack() const
742 : {
743 6366 : auto stack = VariableStore::collect_input_stack();
744 6366 : const auto param_stack = collect_parameter_stack();
745 :
746 : // Recall stack is first in last out.
747 : // Parameter stack go after (on top of) input variables. This means that in assign_input_stack
748 : // we need to pop parameters first, then input variables.
749 6366 : stack.insert(stack.end(), param_stack.begin(), param_stack.end());
750 12732 : return stack;
751 6366 : }
752 :
753 : void
754 1563 : Model::set_guess(const Sol<false> & x)
755 : {
756 1563 : const auto sol_assember = VectorAssembler(input_axis().subaxis(STATE));
757 1563 : assign_input(sol_assember.split_by_variable(x));
758 1563 : }
759 :
760 : void
761 2807 : Model::assemble(NonlinearSystem::Res<false> * residual, NonlinearSystem::Jac<false> * Jacobian)
762 : {
763 2807 : forward_maybe_jit(residual, Jacobian, false);
764 :
765 2807 : if (residual)
766 : {
767 1563 : const auto res_assembler = VectorAssembler(output_axis().subaxis(RESIDUAL));
768 1563 : *residual = Res<false>(res_assembler.assemble_by_variable(collect_output()));
769 : }
770 2807 : if (Jacobian)
771 : {
772 : const auto jac_assembler =
773 1244 : MatrixAssembler(output_axis().subaxis(RESIDUAL), input_axis().subaxis(STATE));
774 1244 : *Jacobian = Jac<false>(jac_assembler.assemble_by_variable(collect_output_derivatives()));
775 : }
776 2807 : }
777 :
778 : bool
779 229 : Model::AD_need_value(bool dout, bool d2out) const
780 : {
781 229 : if (dout)
782 180 : if (!_ad_derivs.empty())
783 2 : return true;
784 :
785 227 : if (d2out)
786 52 : for (auto && [y, u1u2s] : _ad_secderivs)
787 0 : for (auto && [u1, u2s] : u1u2s)
788 0 : if (_ad_derivs.count(y) && _ad_derivs.at(y).count(u1))
789 0 : return true;
790 :
791 227 : return false;
792 : }
793 :
794 : void
795 491 : Model::enable_AD()
796 : {
797 497 : for (auto * ad_arg : _ad_args)
798 6 : ad_arg->requires_grad_();
799 491 : }
800 :
801 : void
802 491 : Model::extract_AD_derivatives(bool dout, bool d2out)
803 : {
804 491 : neml_assert(dout || d2out, "At least one of the output derivatives must be requested.");
805 :
806 495 : for (auto && [y, us] : _ad_derivs)
807 : {
808 4 : if (!dout && d2out)
809 0 : if (!_ad_secderivs.count(y))
810 0 : continue;
811 :
812 : // Gather all dependent variables
813 4 : std::vector<Tensor> uts;
814 18 : for (const auto * u : us)
815 14 : if (u->is_dependent())
816 14 : uts.push_back(u->tensor());
817 :
818 : // Check if we need to create the graph (i.e., if any of the second derivatives are requested)
819 4 : bool create_graph = false;
820 18 : for (const auto * u : us)
821 14 : if (u->is_dependent())
822 14 : if (!create_graph && !dout && d2out)
823 0 : if (_ad_secderivs.at(y).count(u))
824 0 : create_graph = true;
825 :
826 4 : const auto dy_dus = jacrev(y->tensor(),
827 : uts,
828 : /*retain_graph=*/true,
829 : /*create_graph=*/create_graph,
830 4 : /*allow_unused=*/true);
831 :
832 4 : std::size_t i = 0;
833 18 : for (const auto * u : us)
834 14 : if (u->is_dependent())
835 : {
836 14 : if (dy_dus[i].defined())
837 14 : y->d(*u) = dy_dus[i];
838 14 : i++;
839 : }
840 4 : }
841 :
842 491 : if (d2out)
843 : {
844 100 : for (auto && [y, u1u2s] : _ad_secderivs)
845 0 : for (auto && [u1, u2s] : u1u2s)
846 : {
847 0 : if (!u1->is_dependent())
848 0 : continue;
849 :
850 0 : const auto & dy_du1 = y->derivatives()[u1->name()];
851 :
852 0 : if (!dy_du1.defined() || !dy_du1.requires_grad())
853 0 : continue;
854 :
855 0 : std::vector<Tensor> u2ts;
856 0 : for (const auto * u2 : u2s)
857 0 : if (u2->is_dependent())
858 0 : u2ts.push_back(u2->tensor());
859 :
860 : const auto d2y_du1u2s = jacrev(dy_du1,
861 : u2ts,
862 : /*retain_graph=*/true,
863 : /*create_graph=*/false,
864 0 : /*allow_unused=*/true);
865 :
866 0 : std::size_t i = 0;
867 0 : for (const auto * u2 : u2s)
868 0 : if (u2->is_dependent())
869 : {
870 0 : if (d2y_du1u2s[i].defined())
871 0 : y->d(*u1, *u2) = d2y_du1u2s[i];
872 0 : i++;
873 : }
874 0 : }
875 : }
876 491 : }
877 :
878 : // LCOV_EXCL_START
879 : std::ostream &
880 : operator<<(std::ostream & os, const Model & model)
881 : {
882 : bool first = false;
883 : const std::string tab = " ";
884 :
885 : os << "Name: " << model.name() << '\n';
886 :
887 : if (!model.input_variables().empty())
888 : {
889 : os << "Input: ";
890 : first = true;
891 : for (auto && [name, var] : model.input_variables())
892 : {
893 : os << (first ? "" : tab);
894 : os << name << " [" << var->type() << "]\n";
895 : first = false;
896 : }
897 : }
898 :
899 : if (!model.input_variables().empty())
900 : {
901 : os << "Output: ";
902 : first = true;
903 : for (auto && [name, var] : model.output_variables())
904 : {
905 : os << (first ? "" : tab);
906 : os << name << " [" << var->type() << "]\n";
907 : first = false;
908 : }
909 : }
910 :
911 : if (!model.named_parameters().empty())
912 : {
913 : os << "Parameters: ";
914 : first = true;
915 : for (auto && [name, param] : model.named_parameters())
916 : {
917 : os << (first ? "" : tab);
918 : os << name << " [" << param->type() << "][" << Tensor(*param).scalar_type() << "]["
919 : << Tensor(*param).device() << "]\n";
920 : first = false;
921 : }
922 : }
923 :
924 : if (!model.named_buffers().empty())
925 : {
926 : os << "Buffers: ";
927 : first = true;
928 : for (auto && [name, buffer] : model.named_buffers())
929 : {
930 : os << (first ? "" : tab);
931 : os << name << " [" << buffer->type() << "][" << Tensor(*buffer).scalar_type() << "]["
932 : << Tensor(*buffer).device() << "]\n";
933 : first = false;
934 : }
935 : }
936 :
937 : return os;
938 : }
939 :
940 : void
941 : Model::call_callbacks() const
942 : {
943 : for (const auto & callback : _callbacks)
944 : callback(*this, input_variables(), output_variables());
945 : }
946 :
947 : // LCOV_EXCL_STOP
948 : } // namespace neml2
|