LCOV - code coverage report
Current view: top level - models - Model.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 73.4 % 466 342
Test Date: 2025-10-02 16:03:03 Functions: 78.3 % 60 47

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <c10/core/InferenceMode.h>
      26              : #include <torch/csrc/jit/frontend/tracer.h>
      27              : 
      28              : #include "neml2/misc/assertions.h"
      29              : #include "neml2/base/guards.h"
      30              : #include "neml2/base/Factory.h"
      31              : #include "neml2/base/Settings.h"
      32              : #include "neml2/jit/utils.h"
      33              : #include "neml2/tensors/functions/jacrev.h"
      34              : #include "neml2/tensors/tensors.h"
      35              : #include "neml2/tensors/TensorValue.h"
      36              : #include "neml2/models/Model.h"
      37              : #include "neml2/models/Assembler.h"
      38              : #include "neml2/models/map_types_fwd.h"
      39              : 
      40              : namespace neml2
      41              : {
      42              : std::shared_ptr<Model>
      43           15 : load_model(const std::filesystem::path & path, const std::string & mname)
      44              : {
      45           15 :   auto factory = load_input(path);
      46           30 :   return factory->get_model(mname);
      47           15 : }
      48              : 
      49              : bool
      50            0 : Model::TraceSchema::operator==(const TraceSchema & other) const
      51              : {
      52            0 :   return batch_dims == other.batch_dims && dispatch_key == other.dispatch_key;
      53              : }
      54              : 
      55              : bool
      56        13028 : Model::TraceSchema::operator<(const TraceSchema & other) const
      57              : {
      58        13028 :   if (dispatch_key != other.dispatch_key)
      59            0 :     return dispatch_key < other.dispatch_key;
      60        13028 :   return batch_dims < other.batch_dims;
      61              : }
      62              : 
      63              : OptionSet
      64          481 : Model::expected_options()
      65              : {
      66          481 :   OptionSet options = Data::expected_options();
      67          481 :   options += NonlinearSystem::expected_options();
      68          481 :   NonlinearSystem::disable_automatic_scaling(options);
      69              : 
      70          481 :   options.section() = "Models";
      71              : 
      72              :   // Model defaults to defining value and dvalue, but not d2value
      73          962 :   options.set<bool>("define_values") = true;
      74          962 :   options.set<bool>("define_derivatives") = true;
      75          962 :   options.set<bool>("define_second_derivatives") = false;
      76          962 :   options.set("define_values").suppressed() = true;
      77          962 :   options.set("define_derivatives").suppressed() = true;
      78          962 :   options.set("define_second_derivatives").suppressed() = true;
      79              : 
      80              :   // Model defaults to _not_ being part of a nonlinear system
      81              :   // Model::get_model will set this to true if the model is expected to be part of a nonlinear
      82              :   // system, and additional diagnostics will be performed
      83          962 :   options.set<bool>("_nonlinear_system") = false;
      84          962 :   options.set("_nonlinear_system").suppressed() = true;
      85              : 
      86          962 :   options.set<bool>("jit") = true;
      87          962 :   options.set("jit").doc() = "Use JIT compilation for the forward operator";
      88              : 
      89          962 :   options.set<bool>("production") = false;
      90          962 :   options.set("production").doc() =
      91              :       "Production mode. This option is used to disable features like function graph tracking and "
      92              :       "tensor version tracking which are useful for training (i.e., calibrating model parameters) "
      93          481 :       "but are not necessary in production runs.";
      94              : 
      95          481 :   return options;
      96            0 : }
      97              : 
      98          481 : Model::Model(const OptionSet & options)
      99              :   : Data(options),
     100              :     ParameterStore(this),
     101              :     VariableStore(this),
     102              :     NonlinearSystem(options),
     103              :     DiagnosticsInterface(this),
     104          962 :     _defines_value(options.get<bool>("define_values")),
     105          962 :     _defines_dvalue(options.get<bool>("define_derivatives")),
     106          962 :     _defines_d2value(options.get<bool>("define_second_derivatives")),
     107          481 :     _nonlinear_system(options.get<bool>("_nonlinear_system")),
     108         1431 :     _jit(settings().disable_jit() ? false : options.get<bool>("jit")),
     109         1924 :     _production(options.get<bool>("production"))
     110              : {
     111          481 : }
     112              : 
     113              : void
     114           91 : Model::to(const TensorOptions & options)
     115              : {
     116           91 :   send_buffers_to(options);
     117           91 :   send_parameters_to(options);
     118           91 :   send_variables_to(options);
     119              : 
     120          174 :   for (auto & submodel : registered_models())
     121           83 :     submodel->to(options);
     122              : 
     123           92 :   for (auto & [name, param] : named_nonlinear_parameters())
     124           92 :     param.provider->to(options);
     125           91 : }
     126              : 
     127              : void
     128          481 : Model::setup()
     129              : {
     130          481 :   setup_layout();
     131              : 
     132          481 :   if (host() == this)
     133              :   {
     134          208 :     link_output_variables();
     135          208 :     link_input_variables();
     136              :   }
     137              : 
     138          481 :   request_AD();
     139          481 : }
     140              : 
     141              : void
     142           99 : Model::diagnose() const
     143              : {
     144          187 :   for (auto & submodel : registered_models())
     145           88 :     neml2::diagnose(*submodel);
     146              : 
     147              :   // Make sure variables are defined on the reserved subaxes
     148          401 :   for (auto && [name, var] : input_variables())
     149          302 :     diagnostic_check_input_variable(*var);
     150          219 :   for (auto && [name, var] : output_variables())
     151          120 :     diagnostic_check_output_variable(*var);
     152              : 
     153           99 :   if (is_nonlinear_system())
     154            5 :     diagnose_nl_sys();
     155              : 
     156           99 :   if (settings().disable_jit())
     157            9 :     if (input_options().user_specified("jit"))
     158            3 :       diagnostic_assert(!input_options().get<bool>("jit"),
     159              :                         "JIT compilation is disabled globally by Settings/disable_jit=true, and it "
     160              :                         "is an error to explicitly set jit to true for any model.");
     161           99 : }
     162              : 
     163              : void
     164           72 : Model::diagnose_nl_sys() const
     165              : {
     166          139 :   for (auto & submodel : registered_models())
     167           67 :     submodel->diagnose_nl_sys();
     168              : 
     169              :   // Check if any input variable is solve-dependent
     170           72 :   bool input_solve_dep = false;
     171          253 :   for (auto && [name, var] : input_variables())
     172          181 :     if (var->is_solve_dependent())
     173          109 :       input_solve_dep = true;
     174              : 
     175              :   // If any input variable is solve-dependent, ALL output variables must be solve-dependent!
     176           72 :   if (input_solve_dep)
     177          145 :     for (auto && [name, var] : output_variables())
     178           76 :       diagnostic_assert(
     179           76 :           var->is_solve_dependent(),
     180              :           "This model is part of a nonlinear system. At least one of the input variables is "
     181              :           "solve-dependent, so all output variables MUST be solve-dependent, i.e., they must be "
     182              :           "on one of the following sub-axes: state, residual, parameters. However, got output "
     183              :           "variable ",
     184              :           name);
     185           72 : }
     186              : 
     187              : void
     188           26 : Model::diagnostic_assert_state(const VariableBase & v) const
     189              : {
     190           26 :   diagnostic_assert(v.is_state(), "Variable ", v.name(), " must be on the ", STATE, " sub-axis.");
     191           26 : }
     192              : 
     193              : void
     194            0 : Model::diagnostic_assert_old_state(const VariableBase & v) const
     195              : {
     196            0 :   diagnostic_assert(
     197            0 :       v.is_old_state(), "Variable ", v.name(), " must be on the ", OLD_STATE, " sub-axis.");
     198            0 : }
     199              : 
     200              : void
     201           19 : Model::diagnostic_assert_force(const VariableBase & v) const
     202              : {
     203           19 :   diagnostic_assert(v.is_force(), "Variable ", v.name(), " must be on the ", FORCES, " sub-axis.");
     204           19 : }
     205              : 
     206              : void
     207            0 : Model::diagnostic_assert_old_force(const VariableBase & v) const
     208              : {
     209            0 :   diagnostic_assert(
     210            0 :       v.is_old_force(), "Variable ", v.name(), " must be on the ", OLD_FORCES, " sub-axis.");
     211            0 : }
     212              : 
     213              : void
     214            0 : Model::diagnostic_assert_residual(const VariableBase & v) const
     215              : {
     216            0 :   diagnostic_assert(
     217            0 :       v.is_residual(), "Variable ", v.name(), " must be on the ", RESIDUAL, " sub-axis.");
     218            0 : }
     219              : 
     220              : void
     221          302 : Model::diagnostic_check_input_variable(const VariableBase & v) const
     222              : {
     223          770 :   diagnostic_assert(v.is_state() || v.is_old_state() || v.is_force() || v.is_old_force() ||
     224          468 :                         v.is_residual() || v.is_parameter(),
     225              :                     "Input variable ",
     226          302 :                     v.name(),
     227              :                     " must be on one of the following sub-axes: ",
     228              :                     STATE,
     229              :                     ", ",
     230              :                     OLD_STATE,
     231              :                     ", ",
     232              :                     FORCES,
     233              :                     ", ",
     234              :                     OLD_FORCES,
     235              :                     ", ",
     236              :                     RESIDUAL,
     237              :                     ", ",
     238              :                     PARAMETERS,
     239              :                     ".");
     240          302 : }
     241              : 
     242              : void
     243          120 : Model::diagnostic_check_output_variable(const VariableBase & v) const
     244              : {
     245          120 :   diagnostic_assert(v.is_state() || v.is_force() || v.is_residual() || v.is_parameter(),
     246              :                     "Output variable ",
     247          120 :                     v.name(),
     248              :                     " must be on one of the following sub-axes: ",
     249              :                     STATE,
     250              :                     ", ",
     251              :                     FORCES,
     252              :                     ", ",
     253              :                     RESIDUAL,
     254              :                     ", ",
     255              :                     PARAMETERS,
     256              :                     ".");
     257          120 : }
     258              : 
     259              : void
     260          482 : Model::link_input_variables()
     261              : {
     262          756 :   for (auto & submodel : _registered_models)
     263              :   {
     264          274 :     link_input_variables(submodel.get());
     265          274 :     submodel->link_input_variables();
     266              :   }
     267          482 : }
     268              : 
     269              : void
     270           14 : Model::link_input_variables(Model * submodel)
     271              : {
     272           98 :   for (auto && [name, var] : submodel->input_variables())
     273           84 :     var->ref(input_variable(name), submodel->is_nonlinear_system());
     274           14 : }
     275              : 
     276              : void
     277          482 : Model::link_output_variables()
     278              : {
     279          756 :   for (auto & submodel : _registered_models)
     280              :   {
     281          274 :     link_output_variables(submodel.get());
     282          274 :     submodel->link_output_variables();
     283              :   }
     284          482 : }
     285              : 
     286              : void
     287           14 : Model::link_output_variables(Model * /*submodel*/)
     288              : {
     289           14 : }
     290              : 
     291              : void
     292           14 : Model::request_AD(VariableBase & y, const VariableBase & u)
     293              : {
     294           14 :   neml_assert(_defines_value,
     295              :               "Model of type '",
     296           14 :               type(),
     297              :               "' is requesting automatic differentiation of first derivatives, but it does not "
     298              :               "define output values.");
     299           14 :   _defines_dvalue = true;
     300           14 :   _ad_derivs[&y].insert(&u);
     301              :   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     302           14 :   _ad_args.insert(const_cast<VariableBase *>(&u));
     303           14 : }
     304              : 
     305              : void
     306           48 : Model::request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2)
     307              : {
     308           48 :   neml_assert(_defines_dvalue,
     309              :               "Model of type '",
     310           48 :               type(),
     311              :               "' is requesting automatic differentiation of second derivatives, but it does not "
     312              :               "define first derivatives.");
     313           48 :   _defines_d2value = true;
     314           48 :   _ad_secderivs[&y][&u1].insert(&u2);
     315              :   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     316           48 :   _ad_args.insert(const_cast<VariableBase *>(&u2));
     317           48 : }
     318              : 
     319              : void
     320         9219 : Model::clear_input()
     321              : {
     322         9219 :   VariableStore::clear_input();
     323        14646 :   for (auto & submodel : _registered_models)
     324         5427 :     submodel->clear_input();
     325         9219 : }
     326              : 
     327              : void
     328         9219 : Model::clear_output()
     329              : {
     330         9219 :   VariableStore::clear_output();
     331        14646 :   for (auto & submodel : _registered_models)
     332         5427 :     submodel->clear_output();
     333         9219 : }
     334              : 
     335              : void
     336         9221 : Model::zero_input()
     337              : {
     338         9221 :   VariableStore::zero_input();
     339        14648 :   for (auto & submodel : _registered_models)
     340         5427 :     submodel->zero_input();
     341         9221 : }
     342              : 
     343              : void
     344         9221 : Model::zero_output()
     345              : {
     346         9221 :   VariableStore::zero_output();
     347        14648 :   for (auto & submodel : _registered_models)
     348         5427 :     submodel->zero_output();
     349         9221 : }
     350              : 
     351              : Model::TraceSchema
     352         6968 : Model::compute_trace_schema() const
     353              : {
     354         6968 :   std::vector<Size> batch_dims;
     355        43921 :   for (auto && [name, var] : input_variables())
     356        36953 :     batch_dims.push_back(var->batch_dim());
     357        23609 :   for (auto && [name, param] : host<ParameterStore>()->named_parameters())
     358        16641 :     batch_dims.push_back(Tensor(*param).batch_dim());
     359              : 
     360         6968 :   const auto dispatch_key = variable_options().computeDispatchKey();
     361              : 
     362        13936 :   return TraceSchema{batch_dims, dispatch_key};
     363         6968 : }
     364              : 
     365              : std::size_t
     366         6968 : Model::forward_operator_index(bool out, bool dout, bool d2out) const
     367              : {
     368         6968 :   return (out ? 4 : 0) + (dout ? 2 : 0) + (d2out ? 1 : 0);
     369              : }
     370              : 
     371              : void
     372         1213 : Model::forward(bool out, bool dout, bool d2out)
     373              : {
     374         1213 :   neml_assert_dbg(defines_values() || (defines_values() == out),
     375              :                   "Model of type '",
     376         1213 :                   type(),
     377              :                   "' is requested to compute output values, but it does not define them.");
     378         1213 :   neml_assert_dbg(defines_derivatives() || (defines_derivatives() == dout),
     379              :                   "Model of type '",
     380         1213 :                   type(),
     381              :                   "' is requested to compute first derivatives, but it does not define them.");
     382         1213 :   neml_assert_dbg(defines_second_derivatives() || (defines_second_derivatives() == d2out),
     383              :                   "Model of type '",
     384         1213 :                   type(),
     385              :                   "' is requested to compute second derivatives, but it does not define them.");
     386              : 
     387         1213 :   if (dout || d2out)
     388          547 :     clear_derivatives();
     389              : 
     390         1213 :   c10::InferenceMode mode_guard(_production && !jit::tracer::isTracing());
     391              : 
     392         1213 :   if (dout || d2out)
     393          547 :     enable_AD();
     394              : 
     395         1213 :   set_value(out || AD_need_value(dout, d2out), dout, d2out);
     396              : 
     397         1213 :   if (dout || d2out)
     398          547 :     extract_AD_derivatives(dout, d2out);
     399              : 
     400         2426 :   return;
     401         1213 : }
     402              : 
     403              : void
     404         7727 : Model::forward_maybe_jit(bool out, bool dout, bool d2out)
     405              : {
     406         7727 :   if (!is_jit_enabled() || jit::tracer::isTracing())
     407              :   {
     408          759 :     forward(out, dout, d2out);
     409          759 :     return;
     410              :   }
     411              : 
     412              :   auto & traced_functions =
     413         6968 :       currently_solving_nonlinear_system() ? _traced_functions_nl_sys : _traced_functions;
     414              : 
     415         6968 :   const auto forward_op_idx = forward_operator_index(out, dout, d2out);
     416         6968 :   const auto new_schema = compute_trace_schema();
     417         6968 :   auto traced_schema_and_function = traced_functions[forward_op_idx].find(new_schema);
     418              : 
     419         6968 :   if (traced_schema_and_function != traced_functions[forward_op_idx].end())
     420              :   {
     421         6514 :     auto & [trace_schema, traced_function] = *traced_schema_and_function;
     422         6514 :     c10::InferenceMode mode_guard(_production);
     423         6514 :     auto stack = collect_input_stack();
     424         6514 :     traced_function->run(stack);
     425         6512 :     assign_output_stack(stack, out, dout, d2out);
     426         6516 :   }
     427              :   else
     428              :   {
     429              :     // All other models in the world should wait for this model to finish tracing
     430              :     // This is not our fault, torch jit tracing is not thread-safe
     431          454 :     std::shared_ptr<jit::tracer::TracingState> trace;
     432              :     static std::mutex trace_mutex;
     433          454 :     trace_mutex.lock();
     434              :     try
     435              :     {
     436          454 :       auto forward_wrap = [&](jit::Stack inputs) -> jit::Stack
     437              :       {
     438          454 :         assign_input_stack(inputs);
     439          454 :         forward(out, dout, d2out);
     440          454 :         return collect_output_stack(out, dout, d2out);
     441          454 :       };
     442         2270 :       trace = std::get<0>(jit::tracer::trace(
     443          908 :           collect_input_stack(),
     444              :           forward_wrap,
     445        42801 :           [this](const ATensor & var) { return variable_name_lookup(var); },
     446              :           /*strict=*/false,
     447          454 :           /*force_outplace=*/false));
     448              :     }
     449            0 :     catch (const std::exception & e)
     450              :     {
     451            0 :       trace_mutex.unlock();
     452            0 :       throw NEMLException("Failed to trace model '" + name() + "': " + e.what());
     453            0 :     }
     454          454 :     trace_mutex.unlock();
     455              : 
     456          908 :     auto new_function = std::make_unique<jit::GraphFunction>(name() + ".forward",
     457          454 :                                                              trace->graph,
     458            0 :                                                              /*function_creator=*/nullptr,
     459          908 :                                                              jit::ExecutorExecutionMode::PROFILING);
     460          454 :     traced_functions[forward_op_idx].emplace(new_schema, std::move(new_function));
     461              : 
     462              :     // Rerun this method -- this time using the jitted graph (without tracing)
     463          454 :     forward_maybe_jit(out, dout, d2out);
     464          454 :   }
     465         6968 : }
     466              : 
     467              : std::string
     468       160272 : Model::variable_name_lookup(const ATensor & var)
     469              : {
     470              :   // Look for the variable in the input and output variables
     471       566448 :   for (auto && [ivar, val] : input_variables())
     472       407830 :     if (val->tensor().data_ptr() == var.data_ptr())
     473         1654 :       return name() + "::" + utils::stringify(ivar);
     474       340697 :   for (auto && [ovar, val] : output_variables())
     475       182420 :     if (val->tensor().data_ptr() == var.data_ptr())
     476          341 :       return name() + "::" + utils::stringify(ovar);
     477              : 
     478              :   // Look for the variable in the parameter and buffer store
     479      1101077 :   for (auto && [pname, pval] : host<ParameterStore>()->named_parameters())
     480       943646 :     if (Tensor(*pval).data_ptr() == var.data_ptr())
     481          846 :       return name() + "::" + utils::stringify(pname);
     482       858975 :   for (auto && [bname, bval] : host<BufferStore>()->named_buffers())
     483       701705 :     if (Tensor(*bval).data_ptr() == var.data_ptr())
     484          161 :       return name() + "::" + utils::stringify(bname);
     485              : 
     486              :   // Look for the variable in the registered models
     487       275253 :   for (auto & submodel : registered_models())
     488              :   {
     489       118379 :     auto name = submodel->variable_name_lookup(var);
     490       118379 :     if (!name.empty())
     491          396 :       return name;
     492       118379 :   }
     493              : 
     494       313748 :   return "";
     495              : }
     496              : 
     497              : void
     498         3794 : Model::check_precision() const
     499              : {
     500         3794 :   if (settings().require_double_precision())
     501         3794 :     neml_assert(
     502         7588 :         default_tensor_options().dtype() == kFloat64,
     503              :         "By default, NEML2 requires double precision for all computations. Please set the default "
     504              :         "dtype to Float64. In Python, this can be done by calling "
     505              :         "`torch.set_default_dtype(torch.double)`. In C++, this can be done by calling "
     506              :         "`neml2::set_default_dtype(neml2::kFloat64)`. If other precisions are truly needed, you "
     507              :         "can disable this error check with Settings/require_double_precision=false.");
     508         3794 : }
     509              : 
     510              : ValueMap
     511         2678 : Model::value(const ValueMap & in)
     512              : {
     513         2678 :   forward_helper(in, true, false, false);
     514              : 
     515         2676 :   auto values = collect_output();
     516         2676 :   clear_input();
     517         2676 :   clear_output();
     518         2676 :   return values;
     519            0 : }
     520              : 
     521              : ValueMap
     522            2 : Model::value(ValueMap && in)
     523              : {
     524            2 :   forward_helper(std::move(in), true, false, false);
     525              : 
     526            2 :   auto values = collect_output();
     527            2 :   clear_input();
     528            2 :   clear_output();
     529            2 :   return values;
     530            0 : }
     531              : 
     532              : std::tuple<ValueMap, DerivMap>
     533            0 : Model::value_and_dvalue(const ValueMap & in)
     534              : {
     535            0 :   forward_helper(in, true, true, false);
     536              : 
     537            0 :   const auto values = collect_output();
     538            0 :   const auto derivs = collect_output_derivatives();
     539            0 :   clear_input();
     540            0 :   clear_output();
     541            0 :   return {values, derivs};
     542            0 : }
     543              : 
     544              : std::tuple<ValueMap, DerivMap>
     545            0 : Model::value_and_dvalue(ValueMap && in)
     546              : {
     547            0 :   forward_helper(std::move(in), true, true, false);
     548              : 
     549            0 :   const auto values = collect_output();
     550            0 :   const auto derivs = collect_output_derivatives();
     551            0 :   clear_input();
     552            0 :   clear_output();
     553            0 :   return {values, derivs};
     554            0 : }
     555              : 
     556              : DerivMap
     557         1052 : Model::dvalue(const ValueMap & in)
     558              : {
     559         1052 :   forward_helper(in, false, true, false);
     560              : 
     561         1052 :   auto derivs = collect_output_derivatives();
     562         1052 :   clear_input();
     563         1052 :   clear_output();
     564         1052 :   return derivs;
     565            0 : }
     566              : 
     567              : DerivMap
     568            0 : Model::dvalue(ValueMap && in)
     569              : {
     570            0 :   forward_helper(std::move(in), false, true, false);
     571              : 
     572            0 :   auto derivs = collect_output_derivatives();
     573            0 :   clear_input();
     574            0 :   clear_output();
     575            0 :   return derivs;
     576            0 : }
     577              : 
     578              : std::tuple<ValueMap, DerivMap, SecDerivMap>
     579            0 : Model::value_and_dvalue_and_d2value(const ValueMap & in)
     580              : {
     581            0 :   forward_helper(in, true, true, true);
     582              : 
     583            0 :   const auto values = collect_output();
     584            0 :   const auto derivs = collect_output_derivatives();
     585            0 :   const auto secderivs = collect_output_second_derivatives();
     586            0 :   clear_input();
     587            0 :   clear_output();
     588            0 :   return {values, derivs, secderivs};
     589            0 : }
     590              : 
     591              : std::tuple<ValueMap, DerivMap, SecDerivMap>
     592            0 : Model::value_and_dvalue_and_d2value(ValueMap && in)
     593              : {
     594            0 :   forward_helper(std::move(in), true, true, true);
     595              : 
     596            0 :   const auto values = collect_output();
     597            0 :   const auto derivs = collect_output_derivatives();
     598            0 :   const auto secderivs = collect_output_second_derivatives();
     599            0 :   clear_input();
     600            0 :   clear_output();
     601            0 :   return {values, derivs, secderivs};
     602            0 : }
     603              : 
     604              : std::tuple<DerivMap, SecDerivMap>
     605            0 : Model::dvalue_and_d2value(const ValueMap & in)
     606              : {
     607            0 :   forward_helper(in, false, true, true);
     608              : 
     609            0 :   const auto derivs = collect_output_derivatives();
     610            0 :   const auto secderivs = collect_output_second_derivatives();
     611            0 :   clear_input();
     612            0 :   clear_output();
     613            0 :   return {derivs, secderivs};
     614            0 : }
     615              : 
     616              : std::tuple<DerivMap, SecDerivMap>
     617            0 : Model::dvalue_and_d2value(ValueMap && in)
     618              : {
     619            0 :   forward_helper(std::move(in), false, true, true);
     620              : 
     621            0 :   const auto derivs = collect_output_derivatives();
     622            0 :   const auto secderivs = collect_output_second_derivatives();
     623            0 :   clear_input();
     624            0 :   clear_output();
     625            0 :   return {derivs, secderivs};
     626            0 : }
     627              : 
     628              : SecDerivMap
     629           62 : Model::d2value(const ValueMap & in)
     630              : {
     631           62 :   forward_helper(in, false, false, true);
     632              : 
     633           62 :   auto secderivs = collect_output_second_derivatives();
     634           62 :   clear_input();
     635           62 :   clear_output();
     636           62 :   return secderivs;
     637            0 : }
     638              : 
     639              : SecDerivMap
     640            0 : Model::d2value(ValueMap && in)
     641              : {
     642            0 :   forward_helper(std::move(in), false, false, true);
     643              : 
     644            0 :   auto secderivs = collect_output_second_derivatives();
     645            0 :   clear_input();
     646            0 :   clear_output();
     647            0 :   return secderivs;
     648            0 : }
     649              : 
     650              : std::shared_ptr<Model>
     651          260 : Model::registered_model(const std::string & name) const
     652              : {
     653          861 :   for (auto & submodel : _registered_models)
     654          861 :     if (submodel->name() == name)
     655          260 :       return submodel;
     656              : 
     657            0 :   throw NEMLException("There is no registered model named '" + name + "' in '" + this->name() +
     658            0 :                       "'");
     659              : }
     660              : 
     661              : void
     662           78 : Model::register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param)
     663              : {
     664           78 :   neml_assert(_nl_params.count(pname) == 0,
     665              :               "Nonlinear parameter named '",
     666              :               pname,
     667              :               "' has already been registered.");
     668           78 :   _nl_params[pname] = param;
     669           78 : }
     670              : 
     671              : bool
     672            0 : Model::has_nl_param(bool recursive) const
     673              : {
     674            0 :   if (!recursive)
     675            0 :     return !_nl_params.empty();
     676              : 
     677            0 :   for (auto & submodel : registered_models())
     678            0 :     if (submodel->has_nl_param(true))
     679            0 :       return true;
     680              : 
     681            0 :   return false;
     682              : }
     683              : 
     684              : const VariableBase *
     685          385 : Model::nl_param(const std::string & name) const
     686              : {
     687          385 :   return _nl_params.count(name) ? _nl_params.at(name).value : nullptr;
     688              : }
     689              : 
     690              : std::map<std::string, NonlinearParameter>
     691          597 : Model::named_nonlinear_parameters(bool recursive) const
     692              : {
     693          597 :   if (!recursive)
     694           79 :     return _nl_params;
     695              : 
     696          518 :   auto all_nl_params = _nl_params;
     697              : 
     698          674 :   for (const auto & [pname, param] : _nl_params)
     699          156 :     for (auto && [pname, nl_param] : param.provider->named_nonlinear_parameters(true))
     700            0 :       all_nl_params[param.provider->name() + settings().parameter_name_separator() + pname] =
     701          156 :           nl_param;
     702              : 
     703          532 :   for (auto & submodel : registered_models())
     704           14 :     for (auto && [pname, nl_param] : submodel->named_nonlinear_parameters(true))
     705           14 :       all_nl_params[submodel->name() + settings().parameter_name_separator() + pname] = nl_param;
     706              : 
     707          518 :   return all_nl_params;
     708          518 : }
     709              : 
     710              : std::set<VariableName>
     711          260 : Model::consumed_items() const
     712              : {
     713          260 :   auto items = input_axis().variable_names();
     714          520 :   return {items.begin(), items.end()};
     715          260 : }
     716              : 
     717              : std::set<VariableName>
     718          260 : Model::provided_items() const
     719              : {
     720          260 :   auto items = output_axis().variable_names();
     721          520 :   return {items.begin(), items.end()};
     722          260 : }
     723              : 
     724              : void
     725          454 : Model::assign_input_stack(jit::Stack & stack)
     726              : {
     727              : #ifndef NDEBUG
     728          454 :   const auto nstack = input_axis().nvariable() + host<ParameterStore>()->named_parameters().size();
     729          454 :   neml_assert_dbg(
     730          454 :       stack.size() == nstack,
     731              :       "Stack size (",
     732          454 :       stack.size(),
     733              :       ") must equal to the number of input variables, parameters, and buffers in the model (",
     734              :       nstack,
     735              :       ").");
     736              : #endif
     737              : 
     738          454 :   assign_parameter_stack(stack);
     739          454 :   VariableStore::assign_input_stack(stack);
     740          454 : }
     741              : 
     742              : jit::Stack
     743         6968 : Model::collect_input_stack() const
     744              : {
     745         6968 :   auto stack = VariableStore::collect_input_stack();
     746         6968 :   const auto param_stack = collect_parameter_stack();
     747              : 
     748              :   // Recall stack is first in last out.
     749              :   // Parameter stack go after (on top of) input variables. This means that in assign_input_stack
     750              :   // we need to pop parameters first, then input variables.
     751         6968 :   stack.insert(stack.end(), param_stack.begin(), param_stack.end());
     752        13936 :   return stack;
     753         6968 : }
     754              : 
     755              : void
     756         1572 : Model::set_guess(const Sol<false> & x)
     757              : {
     758         1572 :   const auto sol_assember = VectorAssembler(input_axis().subaxis(STATE));
     759         1572 :   assign_input(sol_assember.split_by_variable(x));
     760         1572 : }
     761              : 
     762              : void
     763         2823 : Model::assemble(NonlinearSystem::Res<false> * residual, NonlinearSystem::Jac<false> * Jacobian)
     764              : {
     765         2823 :   forward_maybe_jit(residual, Jacobian, false);
     766              : 
     767         2823 :   if (residual)
     768              :   {
     769         1572 :     const auto res_assembler = VectorAssembler(output_axis().subaxis(RESIDUAL));
     770         1572 :     *residual = Res<false>(res_assembler.assemble_by_variable(collect_output()));
     771              :   }
     772         2823 :   if (Jacobian)
     773              :   {
     774              :     const auto jac_assembler =
     775         1251 :         MatrixAssembler(output_axis().subaxis(RESIDUAL), input_axis().subaxis(STATE));
     776         1251 :     *Jacobian = Jac<false>(jac_assembler.assemble_by_variable(collect_output_derivatives()));
     777              :   }
     778         2823 : }
     779              : 
     780              : bool
     781          266 : Model::AD_need_value(bool dout, bool d2out) const
     782              : {
     783          266 :   if (dout)
     784          203 :     if (!_ad_derivs.empty())
     785            2 :       return true;
     786              : 
     787          264 :   if (d2out)
     788           66 :     for (auto && [y, u1u2s] : _ad_secderivs)
     789            0 :       for (auto && [u1, u2s] : u1u2s)
     790            0 :         if (_ad_derivs.count(y) && _ad_derivs.at(y).count(u1))
     791            0 :           return true;
     792              : 
     793          264 :   return false;
     794              : }
     795              : 
     796              : void
     797          547 : Model::enable_AD()
     798              : {
     799          553 :   for (auto * ad_arg : _ad_args)
     800            6 :     ad_arg->requires_grad_();
     801          547 : }
     802              : 
     803              : void
     804          547 : Model::extract_AD_derivatives(bool dout, bool d2out)
     805              : {
     806          547 :   neml_assert(dout || d2out, "At least one of the output derivatives must be requested.");
     807              : 
     808          551 :   for (auto && [y, us] : _ad_derivs)
     809              :   {
     810            4 :     if (!dout && d2out)
     811            0 :       if (!_ad_secderivs.count(y))
     812            0 :         continue;
     813              : 
     814              :     // Gather all dependent variables
     815            4 :     std::vector<Tensor> uts;
     816           18 :     for (const auto * u : us)
     817           14 :       if (u->is_dependent())
     818           14 :         uts.push_back(u->tensor());
     819              : 
     820              :     // Check if we need to create the graph (i.e., if any of the second derivatives are requested)
     821            4 :     bool create_graph = false;
     822           18 :     for (const auto * u : us)
     823           14 :       if (u->is_dependent())
     824           14 :         if (!create_graph && !dout && d2out)
     825            0 :           if (_ad_secderivs.at(y).count(u))
     826            0 :             create_graph = true;
     827              : 
     828            4 :     const auto dy_dus = jacrev(y->tensor(),
     829              :                                uts,
     830              :                                /*retain_graph=*/true,
     831              :                                /*create_graph=*/create_graph,
     832            4 :                                /*allow_unused=*/true);
     833              : 
     834            4 :     std::size_t i = 0;
     835           18 :     for (const auto * u : us)
     836           14 :       if (u->is_dependent())
     837              :       {
     838           14 :         if (dy_dus[i].defined())
     839           14 :           y->d(*u) = dy_dus[i];
     840           14 :         i++;
     841              :       }
     842            4 :   }
     843              : 
     844          547 :   if (d2out)
     845              :   {
     846          114 :     for (auto && [y, u1u2s] : _ad_secderivs)
     847            0 :       for (auto && [u1, u2s] : u1u2s)
     848              :       {
     849            0 :         if (!u1->is_dependent())
     850            0 :           continue;
     851              : 
     852            0 :         const auto & dy_du1 = y->derivatives()[u1->name()];
     853              : 
     854            0 :         if (!dy_du1.defined() || !dy_du1.requires_grad())
     855            0 :           continue;
     856              : 
     857            0 :         std::vector<Tensor> u2ts;
     858            0 :         for (const auto * u2 : u2s)
     859            0 :           if (u2->is_dependent())
     860            0 :             u2ts.push_back(u2->tensor());
     861              : 
     862              :         const auto d2y_du1u2s = jacrev(dy_du1,
     863              :                                        u2ts,
     864              :                                        /*retain_graph=*/true,
     865              :                                        /*create_graph=*/false,
     866            0 :                                        /*allow_unused=*/true);
     867              : 
     868            0 :         std::size_t i = 0;
     869            0 :         for (const auto * u2 : u2s)
     870            0 :           if (u2->is_dependent())
     871              :           {
     872            0 :             if (d2y_du1u2s[i].defined())
     873            0 :               y->d(*u1, *u2) = d2y_du1u2s[i];
     874            0 :             i++;
     875              :           }
     876            0 :       }
     877              :   }
     878          547 : }
     879              : 
     880              : // LCOV_EXCL_START
     881              : std::ostream &
     882              : operator<<(std::ostream & os, const Model & model)
     883              : {
     884              :   bool first = false;
     885              :   const std::string tab = "            ";
     886              : 
     887              :   os << "Name:       " << model.name() << '\n';
     888              : 
     889              :   if (!model.input_variables().empty())
     890              :   {
     891              :     os << "Input:      ";
     892              :     first = true;
     893              :     for (auto && [name, var] : model.input_variables())
     894              :     {
     895              :       os << (first ? "" : tab);
     896              :       os << name << " [" << var->type() << "]\n";
     897              :       first = false;
     898              :     }
     899              :   }
     900              : 
     901              :   if (!model.input_variables().empty())
     902              :   {
     903              :     os << "Output:     ";
     904              :     first = true;
     905              :     for (auto && [name, var] : model.output_variables())
     906              :     {
     907              :       os << (first ? "" : tab);
     908              :       os << name << " [" << var->type() << "]\n";
     909              :       first = false;
     910              :     }
     911              :   }
     912              : 
     913              :   if (!model.named_parameters().empty())
     914              :   {
     915              :     os << "Parameters: ";
     916              :     first = true;
     917              :     for (auto && [name, param] : model.named_parameters())
     918              :     {
     919              :       os << (first ? "" : tab);
     920              :       os << name << " [" << param->type() << "][" << Tensor(*param).scalar_type() << "]["
     921              :          << Tensor(*param).device() << "]\n";
     922              :       first = false;
     923              :     }
     924              :   }
     925              : 
     926              :   if (!model.named_buffers().empty())
     927              :   {
     928              :     os << "Buffers:    ";
     929              :     first = true;
     930              :     for (auto && [name, buffer] : model.named_buffers())
     931              :     {
     932              :       os << (first ? "" : tab);
     933              :       os << name << " [" << buffer->type() << "][" << Tensor(*buffer).scalar_type() << "]["
     934              :          << Tensor(*buffer).device() << "]\n";
     935              :       first = false;
     936              :     }
     937              :   }
     938              : 
     939              :   return os;
     940              : }
     941              : // LCOV_EXCL_STOP
     942              : } // namespace neml2
        

Generated by: LCOV version 2.0-1