LCOV - code coverage report
Current view: top level - models - ParameterStore.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 77.8 % 135 105
Test Date: 2025-10-02 16:03:03 Functions: 18.6 % 97 18

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/models/ParameterStore.h"
      26              : #include "neml2/models/Model.h"
      27              : #include "neml2/models/InputParameter.h"
      28              : #include "neml2/misc/assertions.h"
      29              : #include "neml2/base/TensorName.h"
      30              : #include "neml2/base/Parser.h"
      31              : #include "neml2/base/Settings.h"
      32              : #include "neml2/tensors/tensors.h"
      33              : #include "neml2/tensors/TensorValue.h"
      34              : 
      35              : namespace neml2
      36              : {
      37          481 : ParameterStore::ParameterStore(Model * object)
      38          481 :   : _object(object)
      39              : {
      40          481 : }
      41              : 
      42              : void
      43           91 : ParameterStore::send_parameters_to(const TensorOptions & options)
      44              : {
      45          130 :   for (auto && [name, param] : _param_values)
      46           39 :     param->to_(options);
      47           91 : }
      48              : 
      49              : void
      50          804 : ParameterStore::set_parameter(const std::string & name, const Tensor & value)
      51              : {
      52          804 :   neml_assert(_object->host() == _object, "This method should only be called on the host model.");
      53          804 :   neml_assert(named_parameters().count(name), "There is no parameter named ", name);
      54          804 :   *named_parameters()[name] = value;
      55          804 : }
      56              : 
      57              : TensorValueBase &
      58            7 : ParameterStore::get_parameter(const std::string & name)
      59              : {
      60            7 :   neml_assert(_object->host() == _object, "This method should only be called on the host model.");
      61            7 :   neml_assert(_param_values.count(name), "Parameter named ", name, " does not exist.");
      62            7 :   return *_param_values[name];
      63              : }
      64              : 
      65              : void
      66            0 : ParameterStore::set_parameters(const std::map<std::string, Tensor> & param_values)
      67              : {
      68            0 :   for (const auto & [name, value] : param_values)
      69            0 :     set_parameter(name, value);
      70            0 : }
      71              : 
      72              : std::map<std::string, std::unique_ptr<TensorValueBase>> &
      73       175227 : ParameterStore::named_parameters()
      74              : {
      75       175227 :   neml_assert(_object->host() == _object,
      76              :               "named_parameters() should only be called on the host model.");
      77       175227 :   return _param_values;
      78              : }
      79              : 
      80              : template <typename T, typename>
      81              : const T &
      82          350 : ParameterStore::declare_parameter(const std::string & name, const T & rawval)
      83              : {
      84          350 :   if (_object->host() != _object)
      85          264 :     return _object->host<ParameterStore>()->declare_parameter(
      86          176 :         _object->name() + _object->settings().parameter_name_separator() + name, rawval);
      87              : 
      88          262 :   TensorValueBase * base_ptr = nullptr;
      89              : 
      90              :   // If the parameter already exists, get it
      91          262 :   if (_param_values.count(name))
      92            5 :     base_ptr = &get_parameter(name);
      93              :   // If the parameter doesn't exist, create it
      94              :   else
      95              :   {
      96          257 :     auto val = std::make_unique<TensorValue<T>>(rawval);
      97          257 :     auto [it, success] = _param_values.emplace(name, std::move(val));
      98          257 :     base_ptr = it->second.get();
      99          257 :   }
     100              : 
     101          262 :   auto ptr = dynamic_cast<TensorValue<T> *>(base_ptr);
     102          262 :   neml_assert(ptr, "Internal error: Failed to cast parameter to a concrete type.");
     103          262 :   return ptr->value();
     104              : }
     105              : 
     106              : template <typename T>
     107              : const T &
     108           78 : resolve_tensor_name(const TensorName<T> & tn, Model * caller, const std::string & pname)
     109              : {
     110           78 :   if (!caller)
     111            0 :     throw ParserException("A non-nullptr caller must be provided to resolve a tensor name");
     112              : 
     113              :   if constexpr (std::is_same_v<T, ATensor> || std::is_same_v<T, Tensor>)
     114              :     throw ParserException("ATensr and Tensor cannot be resolved to a model output variable");
     115              : 
     116              :   // When we retrieve a model, we want it to register its own parameters and buffers in the
     117              :   // host of the caller.
     118           78 :   OptionSet extra_opts;
     119          156 :   extra_opts.set<NEML2Object *>("_host") = caller->host();
     120              : 
     121              :   // The raw string is interpreted as a _variable specifier_ which takes three possible forms
     122              :   // 1. "model_name.variable_name"
     123              :   // 2. "model_name"
     124              :   // 3. "variable_name"
     125           78 :   std::shared_ptr<Model> provider = nullptr;
     126           78 :   VariableName var_name;
     127              : 
     128              :   // Split the raw string into tokens with the delimiter '.'
     129              :   // There must be either one or two tokens
     130           78 :   auto tokens = utils::split(tn.raw(), ".");
     131           78 :   if (tokens.size() != 1 && tokens.size() != 2)
     132            0 :     throw ParserException("Invalid variable specifier '" + tn.raw() +
     133              :                           "'. It should take the form 'model_name', 'variable_name', or "
     134              :                           "'model_name.variable_name'");
     135              : 
     136              :   // When there is only one token, it must be a model name, and the model must have one and only
     137              :   // one output variable.
     138           78 :   if (tokens.size() == 1)
     139              :   {
     140              :     // Try to parse it as a model name
     141           78 :     const auto & mname = tokens[0];
     142              :     try
     143              :     {
     144              :       // Get the model
     145          223 :       provider =
     146              :           caller->factory()->get_object<Model>("Models", mname, extra_opts, /*force_create=*/false);
     147              : 
     148              :       // Apparently, the model must have one and only one output variable.
     149           11 :       const auto nvar = provider->output_axis().nvariable();
     150           11 :       if (nvar == 0)
     151            0 :         throw ParserException(
     152            0 :             "Invalid variable specifier '" + tn.raw() +
     153              :             "' (interpreted as model name). The model does not define any output variable.");
     154           11 :       if (nvar > 1)
     155            0 :         throw ParserException(
     156            0 :             "Invalid variable specifier '" + tn.raw() +
     157              :             "' (interpreted as model name). The model must have one and only one output "
     158              :             "variable. However, it has " +
     159              :             utils::stringify(nvar) +
     160              :             " output variables. To disambiguite, please specify the variable name using "
     161              :             "format 'model_name.variable_name'. The model's output axis is:\n" +
     162            0 :             utils::stringify(provider->output_axis()));
     163              : 
     164              :       // Retrieve the output variable
     165           11 :       var_name = provider->output_axis().variable_names()[0];
     166              :     }
     167              :     // Try to parse it as a variable name
     168          134 :     catch (const FactoryException & err_model)
     169              :     {
     170           67 :       auto success = utils::parse_<VariableName>(var_name, tokens[0]);
     171           67 :       if (!success)
     172            0 :         throw ParserException(
     173            0 :             "Invalid variable specifier '" + tn.raw() +
     174              :             "'. It should take the form 'model_name', 'variable_name', or "
     175              :             "'model_name.variable_name'. Since there is no '.' delimiter, it can either be a "
     176              :             "model name or a variable name. Interpreting it as a model name failed with error "
     177              :             "message: " +
     178            0 :             err_model.what() + ". It also cannot be parsed as a valid variable name.");
     179              : 
     180              :       // Create a dummy model that defines this parameter
     181           67 :       const auto obj_name = "__parameter_" + var_name.str() + "__";
     182           67 :       const auto obj_type = utils::demangle(typeid(T).name()).substr(7) + "InputParameter";
     183           67 :       auto options = InputParameter<T>::expected_options();
     184          134 :       options.template set<std::string>("name") = obj_name;
     185          134 :       options.template set<std::string>("type") = obj_type;
     186          134 :       options.template set<VariableName>("from") = var_name;
     187          268 :       options.template set<VariableName>("to") = var_name.with_suffix("_autogenerated");
     188           67 :       options.set("to").user_specified() = true;
     189           67 :       options.name() = obj_name;
     190           67 :       options.type() = obj_type;
     191          201 :       if (caller->factory()->input_file()["Models"].count(obj_name))
     192              :       {
     193            4 :         const auto & existing_options = caller->factory()->input_file()["Models"][obj_name];
     194            2 :         if (!options_compatible(existing_options, options))
     195            0 :           throw ParserException(
     196              :               "Option clash when declaring an input parameter. Existing options:\n" +
     197              :               utils::stringify(existing_options) + ". New options:\n" + utils::stringify(options));
     198              :       }
     199              :       else
     200          195 :         caller->factory()->input_file()["Models"][obj_name] = std::move(options);
     201              : 
     202              :       // Get the model
     203          134 :       provider = caller->factory()->get_object<Model>(
     204              :           "Models", obj_name, extra_opts, /*force_create=*/false);
     205              : 
     206              :       // Retrieve the output variable
     207           67 :       var_name = provider->output_axis().variable_names()[0];
     208           67 :     }
     209              :   }
     210              :   else
     211              :   {
     212              :     // The first token is the model name
     213            0 :     const auto & mname = tokens[0];
     214              : 
     215              :     // Get the model
     216            0 :     provider =
     217              :         caller->factory()->get_object<Model>("Models", mname, extra_opts, /*force_create=*/false);
     218              : 
     219              :     // The second token is the variable name
     220            0 :     auto success = utils::parse_<VariableName>(var_name, tokens[1]);
     221            0 :     if (!success)
     222            0 :       throw ParserException("Invalid variable specifier '" + tn.raw() + "'. '" + tokens[1] +
     223              :                             "' cannot be parsed as a valid variable name.");
     224            0 :     if (!provider->output_axis().has_variable(var_name))
     225            0 :       throw ParserException("Invalid variable specifier '" + tn.raw() + "'. Model '" + mname +
     226              :                             "' does not have an output variable named '" +
     227              :                             utils::stringify(var_name) + "'");
     228              :   }
     229              : 
     230              :   // Declare the input variable
     231           78 :   caller->declare_input_variable<T>(var_name, {}, /*allow_duplicate=*/true);
     232              : 
     233              :   // Get the variable
     234           78 :   const auto * var = &provider->output_variable(var_name);
     235           78 :   const auto * var_ptr = dynamic_cast<const Variable<T> *>(var);
     236           78 :   if (!var_ptr)
     237            0 :     throw ParserException("The variable specifier '" + tn.raw() +
     238              :                           "' is valid, but the variable cannot be cast to type " +
     239              :                           utils::demangle(typeid(T).name()));
     240              : 
     241              :   // For bookkeeping, the caller shall record the model that provides this variable
     242              :   // This is needed for two reasons:
     243              :   //   1. When the caller is composed with others, we need this information to automatically
     244              :   //      bring in the provider.
     245              :   //   2. When the caller is sent to a different device/dtype, the caller needs to forward the
     246              :   //      call to the provider.
     247           78 :   caller->register_nonlinear_parameter(pname, NonlinearParameter{provider, var_name, var_ptr});
     248              : 
     249              :   // Done!
     250          156 :   return var_ptr->value();
     251          156 : }
     252              : 
     253              : template <typename T, typename>
     254              : const T &
     255          313 : ParameterStore::declare_parameter(const std::string & name,
     256              :                                   const TensorName<T> & tensorname,
     257              :                                   bool allow_nonlinear)
     258              : {
     259          313 :   auto * factory = _object->factory();
     260          313 :   neml_assert(factory, "Internal error: factory != nullptr");
     261              : 
     262              :   try
     263              :   {
     264          313 :     return declare_parameter(name, tensorname.resolve(factory));
     265              :   }
     266          156 :   catch (const SetupException & err_tensor)
     267              :   {
     268           78 :     if (allow_nonlinear)
     269              :       try
     270              :       {
     271           78 :         return resolve_tensor_name(tensorname, _object, name);
     272              :       }
     273            0 :       catch (const SetupException & err_var)
     274              :       {
     275            0 :         throw ParserException(std::string(err_tensor.what()) +
     276              :                               "\nAn additional attempt was made to interpret the tensor name as a "
     277              :                               "variable specifier, but it failed with error message:" +
     278            0 :                               err_var.what());
     279              :       }
     280              :     else
     281            0 :       throw ParserException(
     282            0 :           std::string(err_tensor.what()) +
     283              :           "\nThe tensor name cannot be interpreted as a variable specifier because variable "
     284              :           "coupling has not been implemented for this parameter. If this is intended, please "
     285              :           "consider opening an issue on the NEML2 GitHub repository.");
     286              :   }
     287              : }
     288              : 
     289              : template <typename T, typename>
     290              : const T &
     291          223 : ParameterStore::declare_parameter(const std::string & name,
     292              :                                   const std::string & input_option_name,
     293              :                                   bool allow_nonlinear)
     294              : {
     295          223 :   if (_object->input_options().contains(input_option_name))
     296          446 :     return declare_parameter<T>(
     297          669 :         name, _object->input_options().get<TensorName<T>>(input_option_name), allow_nonlinear);
     298              : 
     299            0 :   throw NEMLException("Trying to register parameter named " + name + " from input option named " +
     300              :                       input_option_name + " of type " + utils::demangle(typeid(T).name()) +
     301              :                       ". Make sure you provided the correct parameter name, option name, and "
     302              :                       "parameter type.");
     303              : }
     304              : 
     305              : #define PARAMETERSTORE_INTANTIATE_TENSORBASE(T)                                                    \
     306              :   template const T & ParameterStore::declare_parameter<T>(const std::string &, const T &)
     307              : FOR_ALL_TENSORBASE(PARAMETERSTORE_INTANTIATE_TENSORBASE);
     308              : 
     309              : #define PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR(T)                                               \
     310              :   template const T & ParameterStore::declare_parameter<T>(                                         \
     311              :       const std::string &, const TensorName<T> &, bool);                                           \
     312              :   template const T & ParameterStore::declare_parameter<T>(                                         \
     313              :       const std::string &, const std::string &, bool)
     314              : FOR_ALL_PRIMITIVETENSOR(PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR);
     315              : 
     316              : void
     317          454 : ParameterStore::assign_parameter_stack(jit::Stack & stack)
     318              : {
     319          454 :   const auto & params = _object->host<ParameterStore>()->named_parameters();
     320              : 
     321          454 :   neml_assert_dbg(stack.size() >= params.size(),
     322              :                   "Stack size (",
     323          454 :                   stack.size(),
     324              :                   ") is smaller than the number of parameters in the model (",
     325          454 :                   params.size(),
     326              :                   ").");
     327              : 
     328              :   // Last n tensors in the stack are the parameters
     329          454 :   std::size_t i = stack.size() - params.size();
     330         1032 :   for (auto && [name, param] : params)
     331              :   {
     332          578 :     const auto tensor = stack[i++].toTensor();
     333          578 :     *param = Tensor(tensor, tensor.dim() - Tensor(*param).base_dim());
     334          578 :   }
     335              : 
     336              :   // Drop the input variables from the stack
     337          454 :   jit::drop(stack, params.size());
     338          454 : }
     339              : 
     340              : jit::Stack
     341         6968 : ParameterStore::collect_parameter_stack() const
     342              : {
     343         6968 :   const auto & params = _object->host<ParameterStore>()->named_parameters();
     344         6968 :   jit::Stack stack;
     345         6968 :   stack.reserve(params.size());
     346        23609 :   for (auto && [name, param] : params)
     347        16641 :     stack.emplace_back(Tensor(*param));
     348         6968 :   return stack;
     349            0 : }
     350              : } // namespace neml2
        

Generated by: LCOV version 2.0-1