LCOV - code coverage report
Current view: top level - models - ParameterStore.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 76.5 % 136 104
Test Date: 2025-06-29 01:25:44 Functions: 18.6 % 97 18

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/models/ParameterStore.h"
      26              : #include "neml2/models/Model.h"
      27              : #include "neml2/models/InputParameter.h"
      28              : #include "neml2/misc/assertions.h"
      29              : #include "neml2/base/TensorName.h"
      30              : #include "neml2/base/Parser.h"
      31              : #include "neml2/base/Settings.h"
      32              : #include "neml2/tensors/tensors.h"
      33              : #include "neml2/tensors/TensorValue.h"
      34              : 
      35              : namespace neml2
      36              : {
      37          432 : ParameterStore::ParameterStore(Model * object)
      38          432 :   : _object(object)
      39              : {
      40          432 : }
      41              : 
      42              : void
      43           91 : ParameterStore::send_parameters_to(const TensorOptions & options)
      44              : {
      45          130 :   for (auto && [name, param] : _param_values)
      46           39 :     param->to_(options);
      47           91 : }
      48              : 
      49              : void
      50          668 : ParameterStore::set_parameter(const std::string & name, const Tensor & value)
      51              : {
      52          668 :   neml_assert(_object->host() == _object, "This method should only be called on the host model.");
      53          668 :   neml_assert(named_parameters().count(name), "There is no parameter named ", name);
      54          668 :   *named_parameters()[name] = value;
      55          668 : }
      56              : 
      57              : TensorValueBase &
      58            7 : ParameterStore::get_parameter(const std::string & name)
      59              : {
      60            7 :   neml_assert(_object->host() == _object, "This method should only be called on the host model.");
      61            7 :   neml_assert(_param_values.count(name), "Parameter named ", name, " does not exist.");
      62            7 :   return *_param_values[name];
      63              : }
      64              : 
      65              : void
      66            0 : ParameterStore::set_parameters(const std::map<std::string, Tensor> & param_values)
      67              : {
      68            0 :   for (const auto & [name, value] : param_values)
      69            0 :     set_parameter(name, value);
      70            0 : }
      71              : 
      72              : std::map<std::string, std::unique_ptr<TensorValueBase>> &
      73       161471 : ParameterStore::named_parameters()
      74              : {
      75       161471 :   neml_assert(_object->host() == _object,
      76              :               "named_parameters() should only be called on the host model.");
      77       161471 :   return _param_values;
      78              : }
      79              : 
      80              : template <typename T, typename>
      81              : const T &
      82          293 : ParameterStore::declare_parameter(const std::string & name, const T & rawval)
      83              : {
      84          293 :   if (_object->host() != _object)
      85          234 :     return _object->host<ParameterStore>()->declare_parameter(
      86          156 :         _object->name() + _object->settings().parameter_name_separator() + name, rawval);
      87              : 
      88          215 :   TensorValueBase * base_ptr = nullptr;
      89              : 
      90              :   // If the parameter already exists, get it
      91          215 :   if (_param_values.count(name))
      92            5 :     base_ptr = &get_parameter(name);
      93              :   // If the parameter doesn't exist, create it
      94              :   else
      95              :   {
      96          210 :     auto val = std::make_unique<TensorValue<T>>(rawval);
      97          210 :     auto [it, success] = _param_values.emplace(name, std::move(val));
      98          210 :     base_ptr = it->second.get();
      99          210 :   }
     100              : 
     101          215 :   auto ptr = dynamic_cast<TensorValue<T> *>(base_ptr);
     102          215 :   neml_assert(ptr, "Internal error: Failed to cast parameter to a concrete type.");
     103          215 :   return ptr->value();
     104              : }
     105              : 
     106              : template <typename T>
     107              : const T &
     108           66 : resolve_tensor_name(const TensorName<T> & tn, Model * caller, const std::string & pname)
     109              : {
     110           66 :   if (!caller)
     111            0 :     throw ParserException("A non-nullptr caller must be provided to resolve a tensor name");
     112              : 
     113              :   if constexpr (std::is_same_v<T, ATensor> || std::is_same_v<T, Tensor>)
     114              :   {
     115              :     throw ParserException("ATensr and Tensor cannot be resolved to a model output variable");
     116              :   }
     117              :   else
     118              :   {
     119              :     // When we retrieve a model, we want it to register its own parameters and buffers in the
     120              :     // host of the caller.
     121           66 :     OptionSet extra_opts;
     122          132 :     extra_opts.set<NEML2Object *>("_host") = caller->host();
     123              : 
     124              :     // The raw string is interpreted as a _variable specifier_ which takes three possible forms
     125              :     // 1. "model_name.variable_name"
     126              :     // 2. "model_name"
     127              :     // 3. "variable_name"
     128           66 :     std::shared_ptr<Model> provider = nullptr;
     129           66 :     VariableName var_name;
     130              : 
     131              :     // Split the raw string into tokens with the delimiter '.'
     132              :     // There must be either one or two tokens
     133           66 :     auto tokens = utils::split(tn.raw(), ".");
     134           66 :     if (tokens.size() != 1 && tokens.size() != 2)
     135            0 :       throw ParserException("Invalid variable specifier '" + tn.raw() +
     136              :                             "'. It should take the form 'model_name', 'variable_name', or "
     137              :                             "'model_name.variable_name'");
     138              : 
     139              :     // When there is only one token, it must be a model name, and the model must have one and only
     140              :     // one output variable.
     141           66 :     if (tokens.size() == 1)
     142              :     {
     143              :       // Try to parse it as a model name
     144           66 :       const auto & mname = tokens[0];
     145              :       try
     146              :       {
     147              :         // Get the model
     148          189 :         provider = caller->factory()->get_object<Model>(
     149              :             "Models", mname, extra_opts, /*force_create=*/false);
     150              : 
     151              :         // Apparently, the model must have one and only one output variable.
     152            9 :         const auto nvar = provider->output_axis().nvariable();
     153            9 :         if (nvar == 0)
     154            0 :           throw ParserException(
     155            0 :               "Invalid variable specifier '" + tn.raw() +
     156              :               "' (interpreted as model name). The model does not define any output variable.");
     157            9 :         if (nvar > 1)
     158            0 :           throw ParserException(
     159            0 :               "Invalid variable specifier '" + tn.raw() +
     160              :               "' (interpreted as model name). The model must have one and only one output "
     161              :               "variable. However, it has " +
     162              :               utils::stringify(nvar) +
     163              :               " output variables. To disambiguite, please specify the variable name using "
     164              :               "format 'model_name.variable_name'. The model's output axis is:\n" +
     165            0 :               utils::stringify(provider->output_axis()));
     166              : 
     167              :         // Retrieve the output variable
     168            9 :         var_name = provider->output_axis().variable_names()[0];
     169              :       }
     170              :       // Try to parse it as a variable name
     171          114 :       catch (const FactoryException & err_model)
     172              :       {
     173           57 :         auto success = utils::parse_<VariableName>(var_name, tokens[0]);
     174           57 :         if (!success)
     175            0 :           throw ParserException(
     176            0 :               "Invalid variable specifier '" + tn.raw() +
     177              :               "'. It should take the form 'model_name', 'variable_name', or "
     178              :               "'model_name.variable_name'. Since there is no '.' delimiter, it can either be a "
     179              :               "model name or a variable name. Interpreting it as a model name failed with error "
     180              :               "message: " +
     181            0 :               err_model.what() + ". It also cannot be parsed as a valid variable name.");
     182              : 
     183              :         // Create a dummy model that defines this parameter
     184           57 :         const auto obj_name = "__parameter_" + var_name.str() + "__";
     185           57 :         const auto obj_type = utils::demangle(typeid(T).name()).substr(7) + "InputParameter";
     186           57 :         auto options = InputParameter<T>::expected_options();
     187          114 :         options.template set<std::string>("name") = obj_name;
     188          114 :         options.template set<std::string>("type") = obj_type;
     189          114 :         options.template set<VariableName>("from") = var_name;
     190          228 :         options.template set<VariableName>("to") = var_name.with_suffix("_autogenerated");
     191           57 :         options.set("to").user_specified() = true;
     192           57 :         options.name() = obj_name;
     193           57 :         options.type() = obj_type;
     194          171 :         if (caller->factory()->input_file()["Models"].count(obj_name))
     195              :         {
     196            0 :           const auto & existing_options = caller->factory()->input_file()["Models"][obj_name];
     197            0 :           if (!options_compatible(existing_options, options))
     198            0 :             throw ParserException(
     199              :                 "Option clash when declaring an input parameter. Existing options:\n" +
     200              :                 utils::stringify(existing_options) + ". New options:\n" +
     201              :                 utils::stringify(options));
     202              :         }
     203              :         else
     204          171 :           caller->factory()->input_file()["Models"][obj_name] = std::move(options);
     205              : 
     206              :         // Get the model
     207          114 :         provider = caller->factory()->get_object<Model>(
     208              :             "Models", obj_name, extra_opts, /*force_create=*/false);
     209              : 
     210              :         // Retrieve the output variable
     211           57 :         var_name = provider->output_axis().variable_names()[0];
     212           57 :       }
     213              :     }
     214              :     else
     215              :     {
     216              :       // The first token is the model name
     217            0 :       const auto & mname = tokens[0];
     218              : 
     219              :       // Get the model
     220            0 :       provider =
     221              :           caller->factory()->get_object<Model>("Models", mname, extra_opts, /*force_create=*/false);
     222              : 
     223              :       // The second token is the variable name
     224            0 :       auto success = utils::parse_<VariableName>(var_name, tokens[1]);
     225            0 :       if (!success)
     226            0 :         throw ParserException("Invalid variable specifier '" + tn.raw() + "'. '" + tokens[1] +
     227              :                               "' cannot be parsed as a valid variable name.");
     228            0 :       if (!provider->output_axis().has_variable(var_name))
     229            0 :         throw ParserException("Invalid variable specifier '" + tn.raw() + "'. Model '" + mname +
     230              :                               "' does not have an output variable named '" +
     231              :                               utils::stringify(var_name) + "'");
     232              :     }
     233              : 
     234              :     // Declare the input variable
     235           66 :     caller->declare_input_variable<T>(var_name);
     236              : 
     237              :     // Get the variable
     238           66 :     const auto * var = &provider->output_variable(var_name);
     239           66 :     const auto * var_ptr = dynamic_cast<const Variable<T> *>(var);
     240           66 :     if (!var_ptr)
     241            0 :       throw ParserException("The variable specifier '" + tn.raw() +
     242              :                             "' is valid, but the variable cannot be cast to type " +
     243              :                             utils::demangle(typeid(T).name()));
     244              : 
     245              :     // For bookkeeping, the caller shall record the model that provides this variable
     246              :     // This is needed for two reasons:
     247              :     //   1. When the caller is composed with others, we need this information to automatically
     248              :     //      bring in the provider.
     249              :     //   2. When the caller is sent to a different device/dtype, the caller needs to forward the
     250              :     //      call to the provider.
     251           66 :     caller->register_nonlinear_parameter(pname, NonlinearParameter{provider, var_name, var_ptr});
     252              : 
     253              :     // Done!
     254          132 :     return var_ptr->value();
     255           66 :   }
     256           66 : }
     257              : 
     258              : template <typename T, typename>
     259              : const T &
     260          257 : ParameterStore::declare_parameter(const std::string & name,
     261              :                                   const TensorName<T> & tensorname,
     262              :                                   bool allow_nonlinear)
     263              : {
     264          257 :   auto * factory = _object->factory();
     265          257 :   neml_assert(factory, "Internal error: factory != nullptr");
     266              : 
     267              :   try
     268              :   {
     269          257 :     return declare_parameter(name, tensorname.resolve(factory));
     270              :   }
     271          132 :   catch (const SetupException & err_tensor)
     272              :   {
     273           66 :     if (allow_nonlinear)
     274              :       try
     275              :       {
     276           66 :         return resolve_tensor_name(tensorname, _object, name);
     277              :       }
     278            0 :       catch (const SetupException & err_var)
     279              :       {
     280            0 :         throw ParserException(std::string(err_tensor.what()) +
     281              :                               "\nAn additional attempt was made to interpret the tensor name as a "
     282              :                               "variable specifier, but it failed with error message:" +
     283            0 :                               err_var.what());
     284              :       }
     285              :     else
     286            0 :       throw ParserException(
     287            0 :           std::string(err_tensor.what()) +
     288              :           "\nThe tensor name cannot be interpreted as a variable specifier because variable "
     289              :           "coupling has not been implemented for this parameter. If this is intended, please "
     290              :           "consider opening an issue on the NEML2 GitHub repository.");
     291              :   }
     292              : }
     293              : 
     294              : template <typename T, typename>
     295              : const T &
     296          191 : ParameterStore::declare_parameter(const std::string & name,
     297              :                                   const std::string & input_option_name,
     298              :                                   bool allow_nonlinear)
     299              : {
     300          191 :   if (_object->input_options().contains(input_option_name))
     301          382 :     return declare_parameter<T>(
     302          573 :         name, _object->input_options().get<TensorName<T>>(input_option_name), allow_nonlinear);
     303              : 
     304            0 :   throw NEMLException("Trying to register parameter named " + name + " from input option named " +
     305              :                       input_option_name + " of type " + utils::demangle(typeid(T).name()) +
     306              :                       ". Make sure you provided the correct parameter name, option name, and "
     307              :                       "parameter type.");
     308              : }
     309              : 
     310              : #define PARAMETERSTORE_INTANTIATE_TENSORBASE(T)                                                    \
     311              :   template const T & ParameterStore::declare_parameter<T>(const std::string &, const T &)
     312              : FOR_ALL_TENSORBASE(PARAMETERSTORE_INTANTIATE_TENSORBASE);
     313              : 
     314              : #define PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR(T)                                               \
     315              :   template const T & ParameterStore::declare_parameter<T>(                                         \
     316              :       const std::string &, const TensorName<T> &, bool);                                           \
     317              :   template const T & ParameterStore::declare_parameter<T>(                                         \
     318              :       const std::string &, const std::string &, bool)
     319              : FOR_ALL_PRIMITIVETENSOR(PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR);
     320              : 
     321              : void
     322          394 : ParameterStore::assign_parameter_stack(jit::Stack & stack)
     323              : {
     324          394 :   const auto & params = _object->host<ParameterStore>()->named_parameters();
     325              : 
     326          394 :   neml_assert_dbg(stack.size() >= params.size(),
     327              :                   "Stack size (",
     328          394 :                   stack.size(),
     329              :                   ") is smaller than the number of parameters in the model (",
     330          394 :                   params.size(),
     331              :                   ").");
     332              : 
     333              :   // Last n tensors in the stack are the parameters
     334          394 :   std::size_t i = stack.size() - params.size();
     335          856 :   for (auto && [name, param] : params)
     336              :   {
     337          462 :     const auto tensor = stack[i++].toTensor();
     338          462 :     *param = Tensor(tensor, tensor.dim() - Tensor(*param).base_dim());
     339          462 :   }
     340              : 
     341              :   // Drop the input variables from the stack
     342          394 :   jit::drop(stack, params.size());
     343          394 : }
     344              : 
     345              : jit::Stack
     346         6366 : ParameterStore::collect_parameter_stack() const
     347              : {
     348         6366 :   const auto & params = _object->host<ParameterStore>()->named_parameters();
     349         6366 :   jit::Stack stack;
     350         6366 :   stack.reserve(params.size());
     351        21861 :   for (auto && [name, param] : params)
     352        15495 :     stack.emplace_back(Tensor(*param));
     353         6366 :   return stack;
     354            0 : }
     355              : } // namespace neml2
        

Generated by: LCOV version 2.0-1