LCOV - code coverage report
Current view: top level - models - Model.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 72.5 % 465 337
Test Date: 2025-06-29 01:25:44 Functions: 75.8 % 62 47

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <c10/core/InferenceMode.h>
      26              : 
      27              : #include "neml2/misc/assertions.h"
      28              : #include "neml2/base/Factory.h"
      29              : #include "neml2/base/Settings.h"
      30              : #include "neml2/jit/utils.h"
      31              : #include "neml2/tensors/functions/jacrev.h"
      32              : #include "neml2/tensors/tensors.h"
      33              : #include "neml2/tensors/TensorValue.h"
      34              : #include "neml2/models/Model.h"
      35              : #include "neml2/models/Assembler.h"
      36              : #include "neml2/models/map_types_fwd.h"
      37              : 
      38              : namespace neml2
      39              : {
      40              : std::shared_ptr<Model>
      41           12 : load_model(const std::filesystem::path & path, const std::string & mname)
      42              : {
      43           12 :   auto factory = load_input(path);
      44           24 :   return factory->get_model(mname);
      45           12 : }
      46              : 
      47              : bool
      48            0 : Model::TraceSchema::operator==(const TraceSchema & other) const
      49              : {
      50            0 :   return batch_dims == other.batch_dims && dispatch_key == other.dispatch_key;
      51              : }
      52              : 
      53              : bool
      54        11944 : Model::TraceSchema::operator<(const TraceSchema & other) const
      55              : {
      56        11944 :   if (dispatch_key != other.dispatch_key)
      57            0 :     return dispatch_key < other.dispatch_key;
      58        11944 :   return batch_dims < other.batch_dims;
      59              : }
      60              : 
      61              : OptionSet
      62          463 : Model::expected_options()
      63              : {
      64          463 :   OptionSet options = Data::expected_options();
      65          463 :   options += NonlinearSystem::expected_options();
      66          463 :   NonlinearSystem::disable_automatic_scaling(options);
      67              : 
      68          463 :   options.section() = "Models";
      69              : 
      70              :   // Model defaults to defining value and dvalue, but not d2value
      71          926 :   options.set<bool>("define_values") = true;
      72          926 :   options.set<bool>("define_derivatives") = true;
      73          926 :   options.set<bool>("define_second_derivatives") = false;
      74          926 :   options.set("define_values").suppressed() = true;
      75          926 :   options.set("define_derivatives").suppressed() = true;
      76          926 :   options.set("define_second_derivatives").suppressed() = true;
      77              : 
      78              :   // Model defaults to _not_ being part of a nonlinear system
      79              :   // Model::get_model will set this to true if the model is expected to be part of a nonlinear
      80              :   // system, and additional diagnostics will be performed
      81          926 :   options.set<bool>("_nonlinear_system") = false;
      82          926 :   options.set("_nonlinear_system").suppressed() = true;
      83              : 
      84          926 :   options.set<bool>("jit") = true;
      85          926 :   options.set("jit").doc() = "Use JIT compilation for the forward operator";
      86              : 
      87          926 :   options.set<bool>("production") = false;
      88          926 :   options.set("production").doc() =
      89              :       "Production mode. This option is used to disable features like function graph tracking and "
      90              :       "tensor version tracking which are useful for training (i.e., calibrating model parameters) "
      91          463 :       "but are not necessary in production runs.";
      92              : 
      93          463 :   return options;
      94            0 : }
      95              : 
      96          432 : Model::Model(const OptionSet & options)
      97              :   : Data(options),
      98              :     ParameterStore(this),
      99              :     VariableStore(this),
     100              :     NonlinearSystem(options),
     101              :     DiagnosticsInterface(this),
     102          864 :     _defines_value(options.get<bool>("define_values")),
     103          864 :     _defines_dvalue(options.get<bool>("define_derivatives")),
     104          864 :     _defines_d2value(options.get<bool>("define_second_derivatives")),
     105          432 :     _nonlinear_system(options.get<bool>("_nonlinear_system")),
     106          864 :     _jit(options.get<bool>("jit")),
     107         1728 :     _production(options.get<bool>("production"))
     108              : {
     109          432 : }
     110              : 
     111              : void
     112           91 : Model::to(const TensorOptions & options)
     113              : {
     114           91 :   send_buffers_to(options);
     115           91 :   send_parameters_to(options);
     116           91 :   send_variables_to(options);
     117              : 
     118          174 :   for (auto & submodel : registered_models())
     119           83 :     submodel->to(options);
     120              : 
     121           92 :   for (auto & [name, param] : named_nonlinear_parameters())
     122           92 :     param.provider->to(options);
     123           91 : }
     124              : 
     125              : void
     126          432 : Model::setup()
     127              : {
     128          432 :   setup_layout();
     129              : 
     130          432 :   if (host() == this)
     131              :   {
     132          183 :     link_output_variables();
     133          183 :     link_input_variables();
     134              :   }
     135              : 
     136          432 :   request_AD();
     137          432 : }
     138              : 
     139              : void
     140           96 : Model::diagnose() const
     141              : {
     142          182 :   for (auto & submodel : registered_models())
     143           86 :     neml2::diagnose(*submodel);
     144              : 
     145              :   // Make sure variables are defined on the reserved subaxes
     146          391 :   for (auto && [name, var] : input_variables())
     147          295 :     diagnostic_check_input_variable(*var);
     148          213 :   for (auto && [name, var] : output_variables())
     149          117 :     diagnostic_check_output_variable(*var);
     150              : 
     151           96 :   if (is_nonlinear_system())
     152            5 :     diagnose_nl_sys();
     153           96 : }
     154              : 
     155              : void
     156           72 : Model::diagnose_nl_sys() const
     157              : {
     158          139 :   for (auto & submodel : registered_models())
     159           67 :     submodel->diagnose_nl_sys();
     160              : 
     161              :   // Check if any input variable is solve-dependent
     162           72 :   bool input_solve_dep = false;
     163          253 :   for (auto && [name, var] : input_variables())
     164          181 :     if (var->is_solve_dependent())
     165          109 :       input_solve_dep = true;
     166              : 
     167              :   // If any input variable is solve-dependent, ALL output variables must be solve-dependent!
     168           72 :   if (input_solve_dep)
     169          145 :     for (auto && [name, var] : output_variables())
     170           76 :       diagnostic_assert(
     171           76 :           var->is_solve_dependent(),
     172              :           "This model is part of a nonlinear system. At least one of the input variables is "
     173              :           "solve-dependent, so all output variables MUST be solve-dependent, i.e., they must be "
     174              :           "on one of the following sub-axes: state, residual, parameters. However, got output "
     175              :           "variable ",
     176              :           name);
     177           72 : }
     178              : 
     179              : void
     180           26 : Model::diagnostic_assert_state(const VariableBase & v) const
     181              : {
     182           26 :   diagnostic_assert(v.is_state(), "Variable ", v.name(), " must be on the ", STATE, " sub-axis.");
     183           26 : }
     184              : 
     185              : void
     186            0 : Model::diagnostic_assert_old_state(const VariableBase & v) const
     187              : {
     188            0 :   diagnostic_assert(
     189            0 :       v.is_old_state(), "Variable ", v.name(), " must be on the ", OLD_STATE, " sub-axis.");
     190            0 : }
     191              : 
     192              : void
     193           19 : Model::diagnostic_assert_force(const VariableBase & v) const
     194              : {
     195           19 :   diagnostic_assert(v.is_force(), "Variable ", v.name(), " must be on the ", FORCES, " sub-axis.");
     196           19 : }
     197              : 
     198              : void
     199            0 : Model::diagnostic_assert_old_force(const VariableBase & v) const
     200              : {
     201            0 :   diagnostic_assert(
     202            0 :       v.is_old_force(), "Variable ", v.name(), " must be on the ", OLD_FORCES, " sub-axis.");
     203            0 : }
     204              : 
     205              : void
     206            0 : Model::diagnostic_assert_residual(const VariableBase & v) const
     207              : {
     208            0 :   diagnostic_assert(
     209            0 :       v.is_residual(), "Variable ", v.name(), " must be on the ", RESIDUAL, " sub-axis.");
     210            0 : }
     211              : 
     212              : void
     213          295 : Model::diagnostic_check_input_variable(const VariableBase & v) const
     214              : {
     215          753 :   diagnostic_assert(v.is_state() || v.is_old_state() || v.is_force() || v.is_old_force() ||
     216          458 :                         v.is_residual() || v.is_parameter(),
     217              :                     "Input variable ",
     218          295 :                     v.name(),
     219              :                     " must be on one of the following sub-axes: ",
     220              :                     STATE,
     221              :                     ", ",
     222              :                     OLD_STATE,
     223              :                     ", ",
     224              :                     FORCES,
     225              :                     ", ",
     226              :                     OLD_FORCES,
     227              :                     ", ",
     228              :                     RESIDUAL,
     229              :                     ", ",
     230              :                     PARAMETERS,
     231              :                     ".");
     232          295 : }
     233              : 
     234              : void
     235          117 : Model::diagnostic_check_output_variable(const VariableBase & v) const
     236              : {
     237          117 :   diagnostic_assert(v.is_state() || v.is_force() || v.is_residual() || v.is_parameter(),
     238              :                     "Output variable ",
     239          117 :                     v.name(),
     240              :                     " must be on one of the following sub-axes: ",
     241              :                     STATE,
     242              :                     ", ",
     243              :                     FORCES,
     244              :                     ", ",
     245              :                     RESIDUAL,
     246              :                     ", ",
     247              :                     PARAMETERS,
     248              :                     ".");
     249          117 : }
     250              : 
     251              : void
     252          433 : Model::link_input_variables()
     253              : {
     254          683 :   for (auto & submodel : _registered_models)
     255              :   {
     256          250 :     link_input_variables(submodel.get());
     257          250 :     submodel->link_input_variables();
     258              :   }
     259          433 : }
     260              : 
     261              : void
     262           13 : Model::link_input_variables(Model * submodel)
     263              : {
     264           88 :   for (auto && [name, var] : submodel->input_variables())
     265           75 :     var->ref(input_variable(name), submodel->is_nonlinear_system());
     266           13 : }
     267              : 
     268              : void
     269          433 : Model::link_output_variables()
     270              : {
     271          683 :   for (auto & submodel : _registered_models)
     272              :   {
     273          250 :     link_output_variables(submodel.get());
     274          250 :     submodel->link_output_variables();
     275              :   }
     276          433 : }
     277              : 
     278              : void
     279           13 : Model::link_output_variables(Model * /*submodel*/)
     280              : {
     281           13 : }
     282              : 
     283              : void
     284           14 : Model::request_AD(VariableBase & y, const VariableBase & u)
     285              : {
     286           14 :   neml_assert(_defines_value,
     287              :               "Model of type '",
     288           14 :               type(),
     289              :               "' is requesting automatic differentiation of first derivatives, but it does not "
     290              :               "define output values.");
     291           14 :   _defines_dvalue = true;
     292           14 :   _ad_derivs[&y].insert(&u);
     293              :   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     294           14 :   _ad_args.insert(const_cast<VariableBase *>(&u));
     295           14 : }
     296              : 
     297              : void
     298           48 : Model::request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2)
     299              : {
     300           48 :   neml_assert(_defines_dvalue,
     301              :               "Model of type '",
     302           48 :               type(),
     303              :               "' is requesting automatic differentiation of second derivatives, but it does not "
     304              :               "define first derivatives.");
     305           48 :   _defines_d2value = true;
     306           48 :   _ad_secderivs[&y][&u1].insert(&u2);
     307              :   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     308           48 :   _ad_args.insert(const_cast<VariableBase *>(&u2));
     309           48 : }
     310              : 
     311              : void
     312         8451 : Model::clear_input()
     313              : {
     314         8451 :   VariableStore::clear_input();
     315        13636 :   for (auto & submodel : _registered_models)
     316         5185 :     submodel->clear_input();
     317         8451 : }
     318              : 
     319              : void
     320         8451 : Model::clear_output()
     321              : {
     322         8451 :   VariableStore::clear_output();
     323        13636 :   for (auto & submodel : _registered_models)
     324         5185 :     submodel->clear_output();
     325         8451 : }
     326              : 
     327              : void
     328         8451 : Model::zero_input()
     329              : {
     330         8451 :   VariableStore::zero_input();
     331        13636 :   for (auto & submodel : _registered_models)
     332         5185 :     submodel->zero_input();
     333         8451 : }
     334              : 
     335              : void
     336         8451 : Model::zero_output()
     337              : {
     338         8451 :   VariableStore::zero_output();
     339        13636 :   for (auto & submodel : _registered_models)
     340         5185 :     submodel->zero_output();
     341         8451 : }
     342              : 
     343              : Model::TraceSchema
     344         6366 : Model::compute_trace_schema() const
     345              : {
     346         6366 :   std::vector<Size> batch_dims;
     347        42258 :   for (auto && [name, var] : input_variables())
     348        35892 :     batch_dims.push_back(var->batch_dim());
     349        21861 :   for (auto && [name, param] : host<ParameterStore>()->named_parameters())
     350        15495 :     batch_dims.push_back(Tensor(*param).batch_dim());
     351              : 
     352         6366 :   const auto dispatch_key = variable_options().computeDispatchKey();
     353              : 
     354        12732 :   return TraceSchema{batch_dims, dispatch_key};
     355         6366 : }
     356              : 
     357              : std::size_t
     358         6366 : Model::forward_operator_index(bool out, bool dout, bool d2out) const
     359              : {
     360         6366 :   return (out ? 4 : 0) + (dout ? 2 : 0) + (d2out ? 1 : 0);
     361              : }
     362              : 
     363              : void
     364            0 : Model::register_callback(const ModelCallback & callback)
     365              : {
     366            0 :   _callbacks.push_back(callback);
     367            0 : }
     368              : 
     369              : void
     370            0 : Model::register_callback_recursive(const ModelCallback & callback)
     371              : {
     372            0 :   register_callback(callback);
     373              : 
     374            0 :   for (auto & submodel : registered_models())
     375            0 :     submodel->register_callback_recursive(callback);
     376            0 : }
     377              : 
     378              : void
     379         1113 : Model::forward(bool out, bool dout, bool d2out)
     380              : {
     381         1113 :   neml_assert_dbg(defines_values() || (defines_values() == out),
     382              :                   "Model of type '",
     383         1113 :                   type(),
     384              :                   "' is requested to compute output values, but it does not define them.");
     385         1113 :   neml_assert_dbg(defines_derivatives() || (defines_derivatives() == dout),
     386              :                   "Model of type '",
     387         1113 :                   type(),
     388              :                   "' is requested to compute first derivatives, but it does not define them.");
     389         1113 :   neml_assert_dbg(defines_second_derivatives() || (defines_second_derivatives() == d2out),
     390              :                   "Model of type '",
     391         1113 :                   type(),
     392              :                   "' is requested to compute second derivatives, but it does not define them.");
     393              : 
     394         1113 :   c10::InferenceMode mode_guard(_production && !jit::tracer::isTracing());
     395              : 
     396         1113 :   if (dout || d2out)
     397          491 :     enable_AD();
     398              : 
     399         1113 :   set_value(out || AD_need_value(dout, d2out), dout, d2out);
     400              : 
     401         1113 :   if (dout || d2out)
     402          491 :     extract_AD_derivatives(dout, d2out);
     403              : 
     404              :   // Call the callbacks
     405         1113 :   call_callbacks();
     406              : 
     407         2226 :   return;
     408         1113 : }
     409              : 
     410              : void
     411         7085 : Model::forward_maybe_jit(bool out, bool dout, bool d2out)
     412              : {
     413         7085 :   if (!is_jit_enabled() || jit::tracer::isTracing())
     414              :   {
     415          719 :     forward(out, dout, d2out);
     416          719 :     return;
     417              :   }
     418              : 
     419              :   auto & traced_functions =
     420         6366 :       currently_solving_nonlinear_system() ? _traced_functions_nl_sys : _traced_functions;
     421              : 
     422         6366 :   const auto forward_op_idx = forward_operator_index(out, dout, d2out);
     423         6366 :   const auto new_schema = compute_trace_schema();
     424         6366 :   auto traced_schema_and_function = traced_functions[forward_op_idx].find(new_schema);
     425              : 
     426         6366 :   if (traced_schema_and_function != traced_functions[forward_op_idx].end())
     427              :   {
     428         5972 :     auto & [trace_schema, traced_function] = *traced_schema_and_function;
     429         5972 :     c10::InferenceMode mode_guard(_production);
     430         5972 :     auto stack = collect_input_stack();
     431         5972 :     traced_function->run(stack);
     432         5972 :     assign_output_stack(stack, out, dout, d2out);
     433         5972 :   }
     434              :   else
     435              :   {
     436              :     // All other models in the world should wait for this model to finish tracing
     437              :     // This is not our fault, torch jit tracing is not thread-safe
     438              :     static std::mutex trace_mutex;
     439          394 :     trace_mutex.lock();
     440          394 :     auto forward_wrap = [&](jit::Stack inputs) -> jit::Stack
     441              :     {
     442          394 :       assign_input_stack(inputs);
     443          394 :       forward(out, dout, d2out);
     444          394 :       return collect_output_stack(out, dout, d2out);
     445          394 :     };
     446         1970 :     auto trace = std::get<0>(jit::tracer::trace(
     447          788 :         collect_input_stack(),
     448              :         forward_wrap,
     449        31019 :         [this](const ATensor & var) { return variable_name_lookup(var); },
     450              :         /*strict=*/false,
     451          394 :         /*force_outplace=*/false));
     452          394 :     trace_mutex.unlock();
     453              : 
     454          788 :     auto new_function = std::make_unique<jit::GraphFunction>(name() + ".forward",
     455          394 :                                                              trace->graph,
     456            0 :                                                              /*function_creator=*/nullptr,
     457          788 :                                                              jit::ExecutorExecutionMode::PROFILING);
     458          394 :     traced_functions[forward_op_idx].emplace(new_schema, std::move(new_function));
     459              : 
     460              :     // Rerun this method -- this time using the jitted graph (without tracing)
     461          394 :     forward_maybe_jit(out, dout, d2out);
     462          394 :   }
     463         6366 : }
     464              : 
     465              : std::string
     466       147909 : Model::variable_name_lookup(const ATensor & var)
     467              : {
     468              :   // Look for the variable in the input and output variables
     469       535920 :   for (auto && [ivar, val] : input_variables())
     470       389414 :     if (val->tensor().data_ptr() == var.data_ptr())
     471         1403 :       return name() + "::" + utils::stringify(ivar);
     472       307991 :   for (auto && [ovar, val] : output_variables())
     473       161819 :     if (val->tensor().data_ptr() == var.data_ptr())
     474          334 :       return name() + "::" + utils::stringify(ovar);
     475              : 
     476              :   // Look for the variable in the parameter and buffer store
     477      1069026 :   for (auto && [pname, pval] : host<ParameterStore>()->named_parameters())
     478       923442 :     if (Tensor(*pval).data_ptr() == var.data_ptr())
     479          588 :       return name() + "::" + utils::stringify(pname);
     480       980789 :   for (auto && [bname, bval] : host<BufferStore>()->named_buffers())
     481       835388 :     if (Tensor(*bval).data_ptr() == var.data_ptr())
     482          183 :       return name() + "::" + utils::stringify(bname);
     483              : 
     484              :   // Look for the variable in the registered models
     485       262689 :   for (auto & submodel : registered_models())
     486              :   {
     487       117678 :     auto name = submodel->variable_name_lookup(var);
     488       117678 :     if (!name.empty())
     489          390 :       return name;
     490       117678 :   }
     491              : 
     492       290022 :   return "";
     493              : }
     494              : 
     495              : void
     496         3266 : Model::check_precision() const
     497              : {
     498         3266 :   if (settings().require_double_precision())
     499         3266 :     neml_assert(
     500         6532 :         default_tensor_options().dtype() == kFloat64,
     501              :         "By default, NEML2 requires double precision for all computations. Please set the default "
     502              :         "dtype to Float64. In Python, this can be done by calling "
     503              :         "`torch.set_default_dtype(torch.double)`. In C++, this can be done by calling "
     504              :         "`neml2::set_default_dtype(neml2::kFloat64)`. If other precisions are truly needed, you "
     505              :         "can disable this error check with Settings/require_double_precision=false.");
     506         3266 : }
     507              : 
     508              : ValueMap
     509         2353 : Model::value(const ValueMap & in)
     510              : {
     511         2353 :   forward_helper(in, true, false, false);
     512              : 
     513         2353 :   auto values = collect_output();
     514         2353 :   clear_input();
     515         2353 :   clear_output();
     516         2353 :   return values;
     517            0 : }
     518              : 
     519              : ValueMap
     520            2 : Model::value(ValueMap && in)
     521              : {
     522            2 :   forward_helper(std::move(in), true, false, false);
     523              : 
     524            2 :   auto values = collect_output();
     525            2 :   clear_input();
     526            2 :   clear_output();
     527            2 :   return values;
     528            0 : }
     529              : 
     530              : std::tuple<ValueMap, DerivMap>
     531            0 : Model::value_and_dvalue(const ValueMap & in)
     532              : {
     533            0 :   forward_helper(in, true, true, false);
     534              : 
     535            0 :   const auto values = collect_output();
     536            0 :   const auto derivs = collect_output_derivatives();
     537            0 :   clear_input();
     538            0 :   clear_output();
     539            0 :   return {values, derivs};
     540            0 : }
     541              : 
     542              : std::tuple<ValueMap, DerivMap>
     543            0 : Model::value_and_dvalue(ValueMap && in)
     544              : {
     545            0 :   forward_helper(std::move(in), true, true, false);
     546              : 
     547            0 :   const auto values = collect_output();
     548            0 :   const auto derivs = collect_output_derivatives();
     549            0 :   clear_input();
     550            0 :   clear_output();
     551            0 :   return {values, derivs};
     552            0 : }
     553              : 
     554              : DerivMap
     555          863 : Model::dvalue(const ValueMap & in)
     556              : {
     557          863 :   forward_helper(in, false, true, false);
     558              : 
     559          863 :   auto derivs = collect_output_derivatives();
     560          863 :   clear_input();
     561          863 :   clear_output();
     562          863 :   return derivs;
     563            0 : }
     564              : 
     565              : DerivMap
     566            0 : Model::dvalue(ValueMap && in)
     567              : {
     568            0 :   forward_helper(std::move(in), false, true, false);
     569              : 
     570            0 :   auto derivs = collect_output_derivatives();
     571            0 :   clear_input();
     572            0 :   clear_output();
     573            0 :   return derivs;
     574            0 : }
     575              : 
     576              : std::tuple<ValueMap, DerivMap, SecDerivMap>
     577            0 : Model::value_and_dvalue_and_d2value(const ValueMap & in)
     578              : {
     579            0 :   forward_helper(in, true, true, true);
     580              : 
     581            0 :   const auto values = collect_output();
     582            0 :   const auto derivs = collect_output_derivatives();
     583            0 :   const auto secderivs = collect_output_second_derivatives();
     584            0 :   clear_input();
     585            0 :   clear_output();
     586            0 :   return {values, derivs, secderivs};
     587            0 : }
     588              : 
     589              : std::tuple<ValueMap, DerivMap, SecDerivMap>
     590            0 : Model::value_and_dvalue_and_d2value(ValueMap && in)
     591              : {
     592            0 :   forward_helper(std::move(in), true, true, true);
     593              : 
     594            0 :   const auto values = collect_output();
     595            0 :   const auto derivs = collect_output_derivatives();
     596            0 :   const auto secderivs = collect_output_second_derivatives();
     597            0 :   clear_input();
     598            0 :   clear_output();
     599            0 :   return {values, derivs, secderivs};
     600            0 : }
     601              : 
     602              : std::tuple<DerivMap, SecDerivMap>
     603            0 : Model::dvalue_and_d2value(const ValueMap & in)
     604              : {
     605            0 :   forward_helper(in, false, true, true);
     606              : 
     607            0 :   const auto derivs = collect_output_derivatives();
     608            0 :   const auto secderivs = collect_output_second_derivatives();
     609            0 :   clear_input();
     610            0 :   clear_output();
     611            0 :   return {derivs, secderivs};
     612            0 : }
     613              : 
     614              : std::tuple<DerivMap, SecDerivMap>
     615            0 : Model::dvalue_and_d2value(ValueMap && in)
     616              : {
     617            0 :   forward_helper(std::move(in), false, true, true);
     618              : 
     619            0 :   const auto derivs = collect_output_derivatives();
     620            0 :   const auto secderivs = collect_output_second_derivatives();
     621            0 :   clear_input();
     622            0 :   clear_output();
     623            0 :   return {derivs, secderivs};
     624            0 : }
     625              : 
     626              : SecDerivMap
     627           48 : Model::d2value(const ValueMap & in)
     628              : {
     629           48 :   forward_helper(in, false, false, true);
     630              : 
     631           48 :   auto secderivs = collect_output_second_derivatives();
     632           48 :   clear_input();
     633           48 :   clear_output();
     634           48 :   return secderivs;
     635            0 : }
     636              : 
     637              : SecDerivMap
     638            0 : Model::d2value(ValueMap && in)
     639              : {
     640            0 :   forward_helper(std::move(in), false, false, true);
     641              : 
     642            0 :   auto secderivs = collect_output_second_derivatives();
     643            0 :   clear_input();
     644            0 :   clear_output();
     645            0 :   return secderivs;
     646            0 : }
     647              : 
     648              : std::shared_ptr<Model>
     649          237 : Model::registered_model(const std::string & name) const
     650              : {
     651          820 :   for (auto & submodel : _registered_models)
     652          820 :     if (submodel->name() == name)
     653          237 :       return submodel;
     654              : 
     655            0 :   throw NEMLException("There is no registered model named '" + name + "' in '" + this->name() +
     656            0 :                       "'");
     657              : }
     658              : 
     659              : void
     660           66 : Model::register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param)
     661              : {
     662           66 :   neml_assert(_nl_params.count(pname) == 0,
     663              :               "Nonlinear parameter named '",
     664              :               pname,
     665              :               "' has already been registered.");
     666           66 :   _nl_params[pname] = param;
     667           66 : }
     668              : 
     669              : bool
     670            0 : Model::has_nl_param(bool recursive) const
     671              : {
     672            0 :   if (!recursive)
     673            0 :     return !_nl_params.empty();
     674              : 
     675            0 :   for (auto & submodel : registered_models())
     676            0 :     if (submodel->has_nl_param(true))
     677            0 :       return true;
     678              : 
     679            0 :   return false;
     680              : }
     681              : 
     682              : const VariableBase *
     683          345 : Model::nl_param(const std::string & name) const
     684              : {
     685          345 :   return _nl_params.count(name) ? _nl_params.at(name).value : nullptr;
     686              : }
     687              : 
     688              : std::map<std::string, NonlinearParameter>
     689          547 : Model::named_nonlinear_parameters(bool recursive) const
     690              : {
     691          547 :   if (!recursive)
     692           79 :     return _nl_params;
     693              : 
     694          468 :   auto all_nl_params = _nl_params;
     695              : 
     696          600 :   for (const auto & [pname, param] : _nl_params)
     697          132 :     for (auto && [pname, nl_param] : param.provider->named_nonlinear_parameters(true))
     698            0 :       all_nl_params[param.provider->name() + settings().parameter_name_separator() + pname] =
     699          132 :           nl_param;
     700              : 
     701          482 :   for (auto & submodel : registered_models())
     702           14 :     for (auto && [pname, nl_param] : submodel->named_nonlinear_parameters(true))
     703           14 :       all_nl_params[submodel->name() + settings().parameter_name_separator() + pname] = nl_param;
     704              : 
     705          468 :   return all_nl_params;
     706          468 : }
     707              : 
     708              : std::set<VariableName>
     709          237 : Model::consumed_items() const
     710              : {
     711          237 :   auto items = input_axis().variable_names();
     712          474 :   return {items.begin(), items.end()};
     713          237 : }
     714              : 
     715              : std::set<VariableName>
     716          237 : Model::provided_items() const
     717              : {
     718          237 :   auto items = output_axis().variable_names();
     719          474 :   return {items.begin(), items.end()};
     720          237 : }
     721              : 
     722              : void
     723          394 : Model::assign_input_stack(jit::Stack & stack)
     724              : {
     725              : #ifndef NDEBUG
     726          394 :   const auto nstack = input_axis().nvariable() + host<ParameterStore>()->named_parameters().size();
     727          394 :   neml_assert_dbg(
     728          394 :       stack.size() == nstack,
     729              :       "Stack size (",
     730          394 :       stack.size(),
     731              :       ") must equal to the number of input variables, parameters, and buffers in the model (",
     732              :       nstack,
     733              :       ").");
     734              : #endif
     735              : 
     736          394 :   assign_parameter_stack(stack);
     737          394 :   VariableStore::assign_input_stack(stack);
     738          394 : }
     739              : 
     740              : jit::Stack
     741         6366 : Model::collect_input_stack() const
     742              : {
     743         6366 :   auto stack = VariableStore::collect_input_stack();
     744         6366 :   const auto param_stack = collect_parameter_stack();
     745              : 
     746              :   // Recall stack is first in last out.
     747              :   // Parameter stack go after (on top of) input variables. This means that in assign_input_stack
     748              :   // we need to pop parameters first, then input variables.
     749         6366 :   stack.insert(stack.end(), param_stack.begin(), param_stack.end());
     750        12732 :   return stack;
     751         6366 : }
     752              : 
     753              : void
     754         1563 : Model::set_guess(const Sol<false> & x)
     755              : {
     756         1563 :   const auto sol_assember = VectorAssembler(input_axis().subaxis(STATE));
     757         1563 :   assign_input(sol_assember.split_by_variable(x));
     758         1563 : }
     759              : 
     760              : void
     761         2807 : Model::assemble(NonlinearSystem::Res<false> * residual, NonlinearSystem::Jac<false> * Jacobian)
     762              : {
     763         2807 :   forward_maybe_jit(residual, Jacobian, false);
     764              : 
     765         2807 :   if (residual)
     766              :   {
     767         1563 :     const auto res_assembler = VectorAssembler(output_axis().subaxis(RESIDUAL));
     768         1563 :     *residual = Res<false>(res_assembler.assemble_by_variable(collect_output()));
     769              :   }
     770         2807 :   if (Jacobian)
     771              :   {
     772              :     const auto jac_assembler =
     773         1244 :         MatrixAssembler(output_axis().subaxis(RESIDUAL), input_axis().subaxis(STATE));
     774         1244 :     *Jacobian = Jac<false>(jac_assembler.assemble_by_variable(collect_output_derivatives()));
     775              :   }
     776         2807 : }
     777              : 
     778              : bool
     779          229 : Model::AD_need_value(bool dout, bool d2out) const
     780              : {
     781          229 :   if (dout)
     782          180 :     if (!_ad_derivs.empty())
     783            2 :       return true;
     784              : 
     785          227 :   if (d2out)
     786           52 :     for (auto && [y, u1u2s] : _ad_secderivs)
     787            0 :       for (auto && [u1, u2s] : u1u2s)
     788            0 :         if (_ad_derivs.count(y) && _ad_derivs.at(y).count(u1))
     789            0 :           return true;
     790              : 
     791          227 :   return false;
     792              : }
     793              : 
     794              : void
     795          491 : Model::enable_AD()
     796              : {
     797          497 :   for (auto * ad_arg : _ad_args)
     798            6 :     ad_arg->requires_grad_();
     799          491 : }
     800              : 
     801              : void
     802          491 : Model::extract_AD_derivatives(bool dout, bool d2out)
     803              : {
     804          491 :   neml_assert(dout || d2out, "At least one of the output derivatives must be requested.");
     805              : 
     806          495 :   for (auto && [y, us] : _ad_derivs)
     807              :   {
     808            4 :     if (!dout && d2out)
     809            0 :       if (!_ad_secderivs.count(y))
     810            0 :         continue;
     811              : 
     812              :     // Gather all dependent variables
     813            4 :     std::vector<Tensor> uts;
     814           18 :     for (const auto * u : us)
     815           14 :       if (u->is_dependent())
     816           14 :         uts.push_back(u->tensor());
     817              : 
     818              :     // Check if we need to create the graph (i.e., if any of the second derivatives are requested)
     819            4 :     bool create_graph = false;
     820           18 :     for (const auto * u : us)
     821           14 :       if (u->is_dependent())
     822           14 :         if (!create_graph && !dout && d2out)
     823            0 :           if (_ad_secderivs.at(y).count(u))
     824            0 :             create_graph = true;
     825              : 
     826            4 :     const auto dy_dus = jacrev(y->tensor(),
     827              :                                uts,
     828              :                                /*retain_graph=*/true,
     829              :                                /*create_graph=*/create_graph,
     830            4 :                                /*allow_unused=*/true);
     831              : 
     832            4 :     std::size_t i = 0;
     833           18 :     for (const auto * u : us)
     834           14 :       if (u->is_dependent())
     835              :       {
     836           14 :         if (dy_dus[i].defined())
     837           14 :           y->d(*u) = dy_dus[i];
     838           14 :         i++;
     839              :       }
     840            4 :   }
     841              : 
     842          491 :   if (d2out)
     843              :   {
     844          100 :     for (auto && [y, u1u2s] : _ad_secderivs)
     845            0 :       for (auto && [u1, u2s] : u1u2s)
     846              :       {
     847            0 :         if (!u1->is_dependent())
     848            0 :           continue;
     849              : 
     850            0 :         const auto & dy_du1 = y->derivatives()[u1->name()];
     851              : 
     852            0 :         if (!dy_du1.defined() || !dy_du1.requires_grad())
     853            0 :           continue;
     854              : 
     855            0 :         std::vector<Tensor> u2ts;
     856            0 :         for (const auto * u2 : u2s)
     857            0 :           if (u2->is_dependent())
     858            0 :             u2ts.push_back(u2->tensor());
     859              : 
     860              :         const auto d2y_du1u2s = jacrev(dy_du1,
     861              :                                        u2ts,
     862              :                                        /*retain_graph=*/true,
     863              :                                        /*create_graph=*/false,
     864            0 :                                        /*allow_unused=*/true);
     865              : 
     866            0 :         std::size_t i = 0;
     867            0 :         for (const auto * u2 : u2s)
     868            0 :           if (u2->is_dependent())
     869              :           {
     870            0 :             if (d2y_du1u2s[i].defined())
     871            0 :               y->d(*u1, *u2) = d2y_du1u2s[i];
     872            0 :             i++;
     873              :           }
     874            0 :       }
     875              :   }
     876          491 : }
     877              : 
     878              : // LCOV_EXCL_START
     879              : std::ostream &
     880              : operator<<(std::ostream & os, const Model & model)
     881              : {
     882              :   bool first = false;
     883              :   const std::string tab = "            ";
     884              : 
     885              :   os << "Name:       " << model.name() << '\n';
     886              : 
     887              :   if (!model.input_variables().empty())
     888              :   {
     889              :     os << "Input:      ";
     890              :     first = true;
     891              :     for (auto && [name, var] : model.input_variables())
     892              :     {
     893              :       os << (first ? "" : tab);
     894              :       os << name << " [" << var->type() << "]\n";
     895              :       first = false;
     896              :     }
     897              :   }
     898              : 
     899              :   if (!model.input_variables().empty())
     900              :   {
     901              :     os << "Output:     ";
     902              :     first = true;
     903              :     for (auto && [name, var] : model.output_variables())
     904              :     {
     905              :       os << (first ? "" : tab);
     906              :       os << name << " [" << var->type() << "]\n";
     907              :       first = false;
     908              :     }
     909              :   }
     910              : 
     911              :   if (!model.named_parameters().empty())
     912              :   {
     913              :     os << "Parameters: ";
     914              :     first = true;
     915              :     for (auto && [name, param] : model.named_parameters())
     916              :     {
     917              :       os << (first ? "" : tab);
     918              :       os << name << " [" << param->type() << "][" << Tensor(*param).scalar_type() << "]["
     919              :          << Tensor(*param).device() << "]\n";
     920              :       first = false;
     921              :     }
     922              :   }
     923              : 
     924              :   if (!model.named_buffers().empty())
     925              :   {
     926              :     os << "Buffers:    ";
     927              :     first = true;
     928              :     for (auto && [name, buffer] : model.named_buffers())
     929              :     {
     930              :       os << (first ? "" : tab);
     931              :       os << name << " [" << buffer->type() << "][" << Tensor(*buffer).scalar_type() << "]["
     932              :          << Tensor(*buffer).device() << "]\n";
     933              :       first = false;
     934              :     }
     935              :   }
     936              : 
     937              :   return os;
     938              : }
     939              : 
     940              : void
     941              : Model::call_callbacks() const
     942              : {
     943              :   for (const auto & callback : _callbacks)
     944              :     callback(*this, input_variables(), output_variables());
     945              : }
     946              : 
     947              : // LCOV_EXCL_STOP
     948              : } // namespace neml2
        

Generated by: LCOV version 2.0-1