Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/models/ParameterStore.h"
26 : #include "neml2/models/Model.h"
27 : #include "neml2/models/InputParameter.h"
28 : #include "neml2/misc/assertions.h"
29 : #include "neml2/base/TensorName.h"
30 : #include "neml2/base/Parser.h"
31 : #include "neml2/base/Settings.h"
32 : #include "neml2/tensors/tensors.h"
33 : #include "neml2/tensors/TensorValue.h"
34 :
35 : namespace neml2
36 : {
37 432 : ParameterStore::ParameterStore(Model * object)
38 432 : : _object(object)
39 : {
40 432 : }
41 :
42 : void
43 91 : ParameterStore::send_parameters_to(const TensorOptions & options)
44 : {
45 130 : for (auto && [name, param] : _param_values)
46 39 : param->to_(options);
47 91 : }
48 :
49 : void
50 668 : ParameterStore::set_parameter(const std::string & name, const Tensor & value)
51 : {
52 668 : neml_assert(_object->host() == _object, "This method should only be called on the host model.");
53 668 : neml_assert(named_parameters().count(name), "There is no parameter named ", name);
54 668 : *named_parameters()[name] = value;
55 668 : }
56 :
57 : TensorValueBase &
58 7 : ParameterStore::get_parameter(const std::string & name)
59 : {
60 7 : neml_assert(_object->host() == _object, "This method should only be called on the host model.");
61 7 : neml_assert(_param_values.count(name), "Parameter named ", name, " does not exist.");
62 7 : return *_param_values[name];
63 : }
64 :
65 : void
66 0 : ParameterStore::set_parameters(const std::map<std::string, Tensor> & param_values)
67 : {
68 0 : for (const auto & [name, value] : param_values)
69 0 : set_parameter(name, value);
70 0 : }
71 :
72 : std::map<std::string, std::unique_ptr<TensorValueBase>> &
73 161471 : ParameterStore::named_parameters()
74 : {
75 161471 : neml_assert(_object->host() == _object,
76 : "named_parameters() should only be called on the host model.");
77 161471 : return _param_values;
78 : }
79 :
80 : template <typename T, typename>
81 : const T &
82 293 : ParameterStore::declare_parameter(const std::string & name, const T & rawval)
83 : {
84 293 : if (_object->host() != _object)
85 234 : return _object->host<ParameterStore>()->declare_parameter(
86 156 : _object->name() + _object->settings().parameter_name_separator() + name, rawval);
87 :
88 215 : TensorValueBase * base_ptr = nullptr;
89 :
90 : // If the parameter already exists, get it
91 215 : if (_param_values.count(name))
92 5 : base_ptr = &get_parameter(name);
93 : // If the parameter doesn't exist, create it
94 : else
95 : {
96 210 : auto val = std::make_unique<TensorValue<T>>(rawval);
97 210 : auto [it, success] = _param_values.emplace(name, std::move(val));
98 210 : base_ptr = it->second.get();
99 210 : }
100 :
101 215 : auto ptr = dynamic_cast<TensorValue<T> *>(base_ptr);
102 215 : neml_assert(ptr, "Internal error: Failed to cast parameter to a concrete type.");
103 215 : return ptr->value();
104 : }
105 :
106 : template <typename T>
107 : const T &
108 66 : resolve_tensor_name(const TensorName<T> & tn, Model * caller, const std::string & pname)
109 : {
110 66 : if (!caller)
111 0 : throw ParserException("A non-nullptr caller must be provided to resolve a tensor name");
112 :
113 : if constexpr (std::is_same_v<T, ATensor> || std::is_same_v<T, Tensor>)
114 : {
115 : throw ParserException("ATensr and Tensor cannot be resolved to a model output variable");
116 : }
117 : else
118 : {
119 : // When we retrieve a model, we want it to register its own parameters and buffers in the
120 : // host of the caller.
121 66 : OptionSet extra_opts;
122 132 : extra_opts.set<NEML2Object *>("_host") = caller->host();
123 :
124 : // The raw string is interpreted as a _variable specifier_ which takes three possible forms
125 : // 1. "model_name.variable_name"
126 : // 2. "model_name"
127 : // 3. "variable_name"
128 66 : std::shared_ptr<Model> provider = nullptr;
129 66 : VariableName var_name;
130 :
131 : // Split the raw string into tokens with the delimiter '.'
132 : // There must be either one or two tokens
133 66 : auto tokens = utils::split(tn.raw(), ".");
134 66 : if (tokens.size() != 1 && tokens.size() != 2)
135 0 : throw ParserException("Invalid variable specifier '" + tn.raw() +
136 : "'. It should take the form 'model_name', 'variable_name', or "
137 : "'model_name.variable_name'");
138 :
139 : // When there is only one token, it must be a model name, and the model must have one and only
140 : // one output variable.
141 66 : if (tokens.size() == 1)
142 : {
143 : // Try to parse it as a model name
144 66 : const auto & mname = tokens[0];
145 : try
146 : {
147 : // Get the model
148 189 : provider = caller->factory()->get_object<Model>(
149 : "Models", mname, extra_opts, /*force_create=*/false);
150 :
151 : // Apparently, the model must have one and only one output variable.
152 9 : const auto nvar = provider->output_axis().nvariable();
153 9 : if (nvar == 0)
154 0 : throw ParserException(
155 0 : "Invalid variable specifier '" + tn.raw() +
156 : "' (interpreted as model name). The model does not define any output variable.");
157 9 : if (nvar > 1)
158 0 : throw ParserException(
159 0 : "Invalid variable specifier '" + tn.raw() +
160 : "' (interpreted as model name). The model must have one and only one output "
161 : "variable. However, it has " +
162 : utils::stringify(nvar) +
163 : " output variables. To disambiguite, please specify the variable name using "
164 : "format 'model_name.variable_name'. The model's output axis is:\n" +
165 0 : utils::stringify(provider->output_axis()));
166 :
167 : // Retrieve the output variable
168 9 : var_name = provider->output_axis().variable_names()[0];
169 : }
170 : // Try to parse it as a variable name
171 114 : catch (const FactoryException & err_model)
172 : {
173 57 : auto success = utils::parse_<VariableName>(var_name, tokens[0]);
174 57 : if (!success)
175 0 : throw ParserException(
176 0 : "Invalid variable specifier '" + tn.raw() +
177 : "'. It should take the form 'model_name', 'variable_name', or "
178 : "'model_name.variable_name'. Since there is no '.' delimiter, it can either be a "
179 : "model name or a variable name. Interpreting it as a model name failed with error "
180 : "message: " +
181 0 : err_model.what() + ". It also cannot be parsed as a valid variable name.");
182 :
183 : // Create a dummy model that defines this parameter
184 57 : const auto obj_name = "__parameter_" + var_name.str() + "__";
185 57 : const auto obj_type = utils::demangle(typeid(T).name()).substr(7) + "InputParameter";
186 57 : auto options = InputParameter<T>::expected_options();
187 114 : options.template set<std::string>("name") = obj_name;
188 114 : options.template set<std::string>("type") = obj_type;
189 114 : options.template set<VariableName>("from") = var_name;
190 228 : options.template set<VariableName>("to") = var_name.with_suffix("_autogenerated");
191 57 : options.set("to").user_specified() = true;
192 57 : options.name() = obj_name;
193 57 : options.type() = obj_type;
194 171 : if (caller->factory()->input_file()["Models"].count(obj_name))
195 : {
196 0 : const auto & existing_options = caller->factory()->input_file()["Models"][obj_name];
197 0 : if (!options_compatible(existing_options, options))
198 0 : throw ParserException(
199 : "Option clash when declaring an input parameter. Existing options:\n" +
200 : utils::stringify(existing_options) + ". New options:\n" +
201 : utils::stringify(options));
202 : }
203 : else
204 171 : caller->factory()->input_file()["Models"][obj_name] = std::move(options);
205 :
206 : // Get the model
207 114 : provider = caller->factory()->get_object<Model>(
208 : "Models", obj_name, extra_opts, /*force_create=*/false);
209 :
210 : // Retrieve the output variable
211 57 : var_name = provider->output_axis().variable_names()[0];
212 57 : }
213 : }
214 : else
215 : {
216 : // The first token is the model name
217 0 : const auto & mname = tokens[0];
218 :
219 : // Get the model
220 0 : provider =
221 : caller->factory()->get_object<Model>("Models", mname, extra_opts, /*force_create=*/false);
222 :
223 : // The second token is the variable name
224 0 : auto success = utils::parse_<VariableName>(var_name, tokens[1]);
225 0 : if (!success)
226 0 : throw ParserException("Invalid variable specifier '" + tn.raw() + "'. '" + tokens[1] +
227 : "' cannot be parsed as a valid variable name.");
228 0 : if (!provider->output_axis().has_variable(var_name))
229 0 : throw ParserException("Invalid variable specifier '" + tn.raw() + "'. Model '" + mname +
230 : "' does not have an output variable named '" +
231 : utils::stringify(var_name) + "'");
232 : }
233 :
234 : // Declare the input variable
235 66 : caller->declare_input_variable<T>(var_name);
236 :
237 : // Get the variable
238 66 : const auto * var = &provider->output_variable(var_name);
239 66 : const auto * var_ptr = dynamic_cast<const Variable<T> *>(var);
240 66 : if (!var_ptr)
241 0 : throw ParserException("The variable specifier '" + tn.raw() +
242 : "' is valid, but the variable cannot be cast to type " +
243 : utils::demangle(typeid(T).name()));
244 :
245 : // For bookkeeping, the caller shall record the model that provides this variable
246 : // This is needed for two reasons:
247 : // 1. When the caller is composed with others, we need this information to automatically
248 : // bring in the provider.
249 : // 2. When the caller is sent to a different device/dtype, the caller needs to forward the
250 : // call to the provider.
251 66 : caller->register_nonlinear_parameter(pname, NonlinearParameter{provider, var_name, var_ptr});
252 :
253 : // Done!
254 132 : return var_ptr->value();
255 66 : }
256 66 : }
257 :
258 : template <typename T, typename>
259 : const T &
260 257 : ParameterStore::declare_parameter(const std::string & name,
261 : const TensorName<T> & tensorname,
262 : bool allow_nonlinear)
263 : {
264 257 : auto * factory = _object->factory();
265 257 : neml_assert(factory, "Internal error: factory != nullptr");
266 :
267 : try
268 : {
269 257 : return declare_parameter(name, tensorname.resolve(factory));
270 : }
271 132 : catch (const SetupException & err_tensor)
272 : {
273 66 : if (allow_nonlinear)
274 : try
275 : {
276 66 : return resolve_tensor_name(tensorname, _object, name);
277 : }
278 0 : catch (const SetupException & err_var)
279 : {
280 0 : throw ParserException(std::string(err_tensor.what()) +
281 : "\nAn additional attempt was made to interpret the tensor name as a "
282 : "variable specifier, but it failed with error message:" +
283 0 : err_var.what());
284 : }
285 : else
286 0 : throw ParserException(
287 0 : std::string(err_tensor.what()) +
288 : "\nThe tensor name cannot be interpreted as a variable specifier because variable "
289 : "coupling has not been implemented for this parameter. If this is intended, please "
290 : "consider opening an issue on the NEML2 GitHub repository.");
291 : }
292 : }
293 :
294 : template <typename T, typename>
295 : const T &
296 191 : ParameterStore::declare_parameter(const std::string & name,
297 : const std::string & input_option_name,
298 : bool allow_nonlinear)
299 : {
300 191 : if (_object->input_options().contains(input_option_name))
301 382 : return declare_parameter<T>(
302 573 : name, _object->input_options().get<TensorName<T>>(input_option_name), allow_nonlinear);
303 :
304 0 : throw NEMLException("Trying to register parameter named " + name + " from input option named " +
305 : input_option_name + " of type " + utils::demangle(typeid(T).name()) +
306 : ". Make sure you provided the correct parameter name, option name, and "
307 : "parameter type.");
308 : }
309 :
310 : #define PARAMETERSTORE_INTANTIATE_TENSORBASE(T) \
311 : template const T & ParameterStore::declare_parameter<T>(const std::string &, const T &)
312 : FOR_ALL_TENSORBASE(PARAMETERSTORE_INTANTIATE_TENSORBASE);
313 :
314 : #define PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR(T) \
315 : template const T & ParameterStore::declare_parameter<T>( \
316 : const std::string &, const TensorName<T> &, bool); \
317 : template const T & ParameterStore::declare_parameter<T>( \
318 : const std::string &, const std::string &, bool)
319 : FOR_ALL_PRIMITIVETENSOR(PARAMETERSTORE_INTANTIATE_PRIMITIVETENSOR);
320 :
321 : void
322 394 : ParameterStore::assign_parameter_stack(jit::Stack & stack)
323 : {
324 394 : const auto & params = _object->host<ParameterStore>()->named_parameters();
325 :
326 394 : neml_assert_dbg(stack.size() >= params.size(),
327 : "Stack size (",
328 394 : stack.size(),
329 : ") is smaller than the number of parameters in the model (",
330 394 : params.size(),
331 : ").");
332 :
333 : // Last n tensors in the stack are the parameters
334 394 : std::size_t i = stack.size() - params.size();
335 856 : for (auto && [name, param] : params)
336 : {
337 462 : const auto tensor = stack[i++].toTensor();
338 462 : *param = Tensor(tensor, tensor.dim() - Tensor(*param).base_dim());
339 462 : }
340 :
341 : // Drop the input variables from the stack
342 394 : jit::drop(stack, params.size());
343 394 : }
344 :
345 : jit::Stack
346 6366 : ParameterStore::collect_parameter_stack() const
347 : {
348 6366 : const auto & params = _object->host<ParameterStore>()->named_parameters();
349 6366 : jit::Stack stack;
350 6366 : stack.reserve(params.size());
351 21861 : for (auto && [name, param] : params)
352 15495 : stack.emplace_back(Tensor(*param));
353 6366 : return stack;
354 0 : }
355 : } // namespace neml2
|