NEML2 2.0.0
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
Model.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
29
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/solvers/NonlinearSystem.h"
34#include "neml2/models/NonlinearParameter.h"
35
36// These headers are not directly used by Model, but are included here so that derived classes do
37// not have to include them separately. This is a convenience for the user, and is a reasonable
38// choice since these headers are light and bring in little dependency.
39#include "neml2/base/TensorName.h"
40#include "neml2/base/LabeledAxis.h"
41#include "neml2/tensors/TensorValue.h"
42#include "neml2/models/Variable.h"
43
44namespace neml2
45{
46class Model;
47
/// @brief A convenient function to manufacture a neml2::Model.
/// @param mname Name of the model to obtain.
/// @return Reference to the model.
55Model & get_model(const std::string & mname);
56
/// @brief A convenient function to load an input file and get a model.
/// @param path Path to the input file.
/// @param mname Name of the model to obtain from the loaded input.
/// @return Reference to the model.
66Model & load_model(const std::filesystem::path & path, const std::string & mname);
67
/// @brief Similar to neml2::load_model, but additionally clears the Factory before loading the
/// model. (The remainder of the upstream description is truncated in this documentation page.)
78Model & reload_model(const std::filesystem::path & path, const std::string & mname);
79
/// Check the current default precision and warn if it is not double precision.
81void check_precision();
82
/// @brief The base class for all constitutive models.
///
/// A Model aggregates data (Data), parameters (ParameterStore), input/output variables
/// (VariableStore), the nonlinear-system interface (NonlinearSystem), and dependency
/// resolution keyed by VariableName (DependencyDefinition<VariableName>).
90class Model : public std::enable_shared_from_this<Model>,
91 public Data,
92 public ParameterStore,
93 public VariableStore,
94 public NonlinearSystem,
95 public DependencyDefinition<VariableName>,
// NOTE(review): source line 96 was dropped from this listing (numbering jumps 95 -> 97), so the
// base-class list above ends with a dangling comma. Presumably the missing base is
// DiagnosticsInterface (its header is included above and the cross-reference index cites
// DiagnosticsInterface.cxx) -- confirm against the original Model.h.
97{
98public:
// NOTE(review): the line declaring this nested struct was dropped (numbering jumps 98 -> 105).
// The cross-reference index names it TraceSchema: "Schema for the traced forward operators."
105 {
/// Batch dimensions the operator was traced with.
106 std::vector<Size> batch_dims;
/// Dispatch key the operator was traced with.
107 at::DispatchKey dispatch_key;
108 bool operator==(const TraceSchema & other) const;
109 bool operator<(const TraceSchema & other) const;
110 };
111
113
/// Construct a new Model object.
119 Model(const OptionSet & options);
120
/// Setup this object.
121 void setup() override;
122
/// Check for common problems.
123 void diagnose() const override;
124
/// Whether this model defines output values.
126 virtual bool defines_values() const { return _defines_value; }
127
/// Whether this model defines first derivatives.
129 virtual bool defines_derivatives() const { return _defines_dvalue; }
130
/// Whether this model defines second derivatives.
132 virtual bool defines_second_derivatives() const { return _defines_d2value; }
133
/// Whether this model defines one or more nonlinear equations to be solved.
135 virtual bool is_nonlinear_system() const { return _nonlinear_system; }
136
/// Whether JIT is enabled.
138 virtual bool is_jit_enabled() const { return _jit; }
139
/// Send model to a different device or dtype.
141 virtual void to(const TensorOptions & options);
142
/// The models that may be used during the evaluation of this model.
144 const std::vector<Model *> & registered_models() const { return _registered_models; }
/// Get a registered model by its name.
146 Model * registered_model(const std::string & name) const;
147
/// Register a nonlinear parameter.
149 void register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param);
150
/// Whether this parameter store has any nonlinear parameters.
152 bool has_nl_param(bool recursive = false) const;
153
/// Query the existence of a nonlinear parameter.
160 const VariableBase * nl_param(const std::string &) const;
161
/// Get all nonlinear parameters, keyed by name.
163 virtual std::map<std::string, NonlinearParameter>
164 named_nonlinear_parameters(bool recursive = false) const;
165
/// The variables that this model depends on.
167 std::set<VariableName> consumed_items() const override;
/// The variables that this model defines as part of its output.
169 std::set<VariableName> provided_items() const override;
170
/// Request automatic differentiation of y with respect to u (presumably first derivative --
/// confirm against Model.cxx).
172 void request_AD(VariableBase & y, const VariableBase & u);
173
/// Request automatic differentiation of y with respect to u1 and u2 (presumably second
/// derivative -- confirm against Model.cxx).
175 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
176
/// Forward operator without jit.
178 void forward(bool out, bool dout, bool d2out);
179
/// Forward operator with jit.
188 void forward_maybe_jit(bool out, bool dout, bool d2out);
189
/// Look up the name of a variable in the traced graph.
191 std::string variable_name_lookup(const ATensor & var);
192
/// Convenient shortcut to construct and return the model value.
194 virtual ValueMap value(const ValueMap & in);
195 virtual ValueMap value(ValueMap && in);
196
/// Convenient shortcut to construct and return the model value and its derivative.
198 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
199 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(ValueMap && in);
200
/// Convenient shortcut to construct and return the derivative.
202 virtual DerivMap dvalue(const ValueMap & in);
203 virtual DerivMap dvalue(ValueMap && in);
204
/// Convenient shortcut to construct and return the model's value, first and second derivatives.
// NOTE(review): source line 207 (the remainder of the next declaration) is missing from this
// listing; the cross-reference index shows it as value_and_dvalue_and_d2value(const ValueMap & in).
206 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
208 virtual std::tuple<ValueMap, DerivMap, SecDerivMap> value_and_dvalue_and_d2value(ValueMap && in);
209
/// Convenient shortcut to construct and return the model's second derivative.
211 virtual SecDerivMap d2value(const ValueMap & in);
212 virtual SecDerivMap d2value(ValueMap && in);
213
/// Convenient shortcut to construct and return the model's first and second derivatives.
215 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(const ValueMap & in);
216 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(ValueMap && in);
217
/// Declaration of nonlinear parameters may require manipulation of input.
219 friend class ParameterStore;
220
/// ComposedModel's set_value needs to call each submodel's set_value.
222 friend class ComposedModel;
223
224protected:
225 void diagnostic_assert_state(const VariableBase & v) const;
226 void diagnostic_assert_old_state(const VariableBase & v) const;
227 void diagnostic_assert_force(const VariableBase & v) const;
228 void diagnostic_assert_old_force(const VariableBase & v) const;
229 void diagnostic_assert_residual(const VariableBase & v) const;
230 void diagnostic_check_input_variable(const VariableBase & v) const;
// NOTE(review): source line 231 is missing from this listing; the cross-reference index lists
// diagnostic_check_output_variable(const VariableBase & v) const, presumably declared there.
232
/// Additional diagnostics for a nonlinear system.
234 void diagnose_nl_sys() const;
235
236 virtual void link_input_variables();
237 virtual void link_input_variables(Model * submodel);
238 virtual void link_output_variables();
239 virtual void link_output_variables(Model * submodel);
240
241 void clear_input() override;
242 void clear_output() override;
243 void zero_input() override;
244 void zero_output() override;
245
/// Hook for derived classes to request AD; the default implementation does nothing.
259 virtual void request_AD() {}
260
/// The map between input -> output, and optionally its derivatives.
262 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
263
/// @brief Register a model that the current model may use during its evaluation.
///
/// The model is obtained from the Factory's "Models" section, with the "_host" option set to
/// this object's host and "_nonlinear_system" set to @p nonlinear.
///
/// @tparam T Concrete model type; must derive from Model.
/// @param name Name of the model to register.
/// @param nonlinear Value forwarded as the "_nonlinear_system" option.
/// @param merge_input If true, the submodel's input variables are processed (loop below).
/// @return Reference to the registered model.
/// @throws SetupException if @p name equals this model's own name, or if the model has
/// already been registered.
276 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
277 T & register_model(const std::string & name, bool nonlinear = false, bool merge_input = true)
278 {
279 if (name == this->name())
280 throw SetupException("Model named '" + this->name() +
281 "' is trying to register itself as a sub-model. This is not allowed.");
282
283 OptionSet extra_opts;
284 extra_opts.set<NEML2Object *>("_host") = host();
285 extra_opts.set<bool>("_nonlinear_system") = nonlinear;
286
287 auto model = Factory::get_object_ptr<T>("Models", name, extra_opts);
// Registering the same model instance twice is rejected.
288 if (std::find(_registered_models.begin(), _registered_models.end(), model.get()) !=
289 _registered_models.end())
290 throw SetupException("Model named '" + name + "' has already been registered.");
291
292 if (merge_input)
293 for (auto && [name, var] : model->input_variables())
// NOTE(review): the loop body (source line 294) is missing from this listing. The index lists
// VariableStore::clone_input_variable ("Clone a variable and put it on the input axis"), which
// is presumably what is called for each submodel input -- confirm against the original header.
295
296 _registered_models.push_back(model.get());
297 return *model;
298 }
299
300 void assign_input_stack(jit::Stack & stack);
301
302 jit::Stack collect_input_stack() const;
303
/// Set the unscaled current guess.
304 void set_guess(const Sol<false> &) override;
305
/// Compute the unscaled residual and Jacobian.
306 void assemble(Res<false> *, Jac<false> *) override;
307
/// Models this model may use during its evaluation.
309 std::vector<Model *> _registered_models;
310
311private:
/// Evaluate the model on @p in: zero the inputs, assign @p in (forwarded), zero the outputs,
/// then run forward_maybe_jit with the requested value/derivative flags.
312 template <typename T>
313 void forward_helper(T && in, bool out, bool dout, bool d2out)
314 {
316 zero_input();
317 assign_input(std::forward<T>(in));
318 zero_output();
319 forward_maybe_jit(out, dout, d2out);
320 }
321
324 bool AD_need_value(bool dout, bool d2out) const;
325
327 void enable_AD();
328
330 void extract_AD_derivatives(bool dout, bool d2out);
331
/// Map the (out, dout, d2out) flags to an index (presumably into the 8-slot trace caches below).
333 std::size_t forward_operator_index(bool out, bool dout, bool d2out) const;
334
336 TraceSchema compute_trace_schema() const;
337
// Flags returned by defines_values(), defines_derivatives(), and defines_second_derivatives().
340 bool _defines_value;
341 bool _defines_dvalue;
342 bool _defines_d2value;
344
// Flag returned by is_nonlinear_system() (presumably set from the "_nonlinear_system" option).
346 bool _nonlinear_system;
347
// Nonlinear parameters by name (presumably filled by register_nonlinear_parameter).
349 std::map<std::string, NonlinearParameter> _nl_params;
350
// AD bookkeeping, presumably filled by the request_AD overloads: first-derivative requests
// (y -> {u}), second-derivative requests (y -> u1 -> {u2}), and the set of AD arguments.
353 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
354 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
355 _ad_secderivs;
356 std::set<VariableBase *> _ad_args;
358
// Flag returned by is_jit_enabled().
360 const bool _jit;
361
// NOTE(review): purpose of this flag is not visible in this listing.
363 const bool _production;
364
// Traced forward operators cached per TraceSchema; 8 slots, presumably one per combination of
// the (out, dout, d2out) flags -- confirm.
381 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
382
// A separate cache with the same layout, presumably used when evaluating as a nonlinear system.
384 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
385 _traced_functions_nl_sys;
386};
387
/// Stream insertion for Model; the output format is defined in the implementation file.
388std::ostream & operator<<(std::ostream & os, const Model & model);
389} // namespace neml2
Data(const OptionSet &options)
Construct a new Data object.
Definition Data.cxx:38
static std::shared_ptr< T > get_object_ptr(const std::string &section, const std::string &name, const OptionSet &additional_options=OptionSet(), bool force_create=true)
Retrieve an object pointer under the given section with the given object name.
Definition Factory.h:165
The base class for all constitutive models.
Definition Model.h:97
friend class ComposedModel
ComposedModel's set_value needs to call each submodel's set_value.
Definition Model.h:222
void clear_input() override
Definition Model.cxx:326
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:126
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:207
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:527
virtual void link_input_variables()
Definition Model.cxx:266
const std::vector< Model * > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:144
std::vector< Model * > _registered_models
Models this model may use during its evaluation.
Definition Model.h:309
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
Definition Model.cxx:656
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:129
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameters.
Definition Model.cxx:666
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:249
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:126
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:407
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:623
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:711
friend class ParameterStore
Declaration of nonlinear parameters may require manipulation of input.
Definition Model.h:219
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:132
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:194
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:756
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:378
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivatives.
Definition Model.cxx:573
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:505
std::string variable_name_lookup(const ATensor &var)
Look up the name of a variable in the traced graph.
Definition Model.cxx:462
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:277
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:685
void zero_input() override
Definition Model.cxx:342
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:227
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:135
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivatives.
Definition Model.cxx:599
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:200
virtual void link_output_variables()
Definition Model.cxx:283
jit::Stack collect_input_stack() const
Definition Model.cxx:736
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:213
void setup() override
Set up this object.
Definition Model.cxx:140
void diagnose() const override
Check for common problems.
Definition Model.cxx:154
void clear_output() override
Definition Model.cxx:334
static OptionSet expected_options()
Definition Model.cxx:76
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:551
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:718
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:220
virtual void request_AD()
Definition Model.h:259
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:110
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:749
Model * registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:645
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:138
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:679
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:170
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:704
void zero_output() override
Definition Model.cxx:350
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:42
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:74
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:99
NonlinearSystem(const NonlinearSystem &)=default
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:52
T & set(const std::string &)
Definition OptionSet.h:273
Definition errors.h:50
Base class of variable.
Definition Variable.h:52
VariableStore(Model *object)
Definition VariableStore.cxx:36
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:257
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:120
Definition DiagnosticsInterface.cxx:30
Model & load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:48
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
Model & get_model(const std::string &mname)
A convenient function to manufacture a neml2::Model.
Definition Model.cxx:41
at::Tensor ATensor
Definition types.h:42
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
Model & reload_model(const std::filesystem::path &path, const std::string &mname)
Similar to neml2::load_model, but additionally clears the Factory before loading the model, ...
Definition Model.cxx:55
c10::TensorOptions TensorOptions
Definition types.h:63
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:32
void check_precision()
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:492
Schema for the traced forward operators.
Definition Model.h:105
bool operator==(const TraceSchema &other) const
Definition Model.cxx:62
std::vector< Size > batch_dims
Definition Model.h:106
at::DispatchKey dispatch_key
Definition Model.h:107
bool operator<(const TraceSchema &other) const
Definition Model.cxx:68
Nonlinear parameter.
Definition NonlinearParameter.h:49