NEML2 2.0.0
Loading...
Searching...
No Matches
Model.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
29
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/solvers/NonlinearSystem.h"
34#include "neml2/models/NonlinearParameter.h"
35#include "neml2/tensors/jit.h"
36
// These headers are not used directly by Model. They are included here so that derived classes do
// not have to include them separately. This is a convenience for the user and a reasonable choice,
// since these headers are lightweight and bring in few additional dependencies.
40#include "neml2/base/LabeledAxis.h"
41#include "neml2/models/Variable.h"
42#include "neml2/models/Derivative.h"
43
44namespace neml2
45{
46class Model;
47
54std::shared_ptr<Model> load_model(const std::filesystem::path & path, const std::string & mname);
55
63class Model : public std::enable_shared_from_this<Model>,
64 public Data,
65 public ParameterStore,
66 public VariableStore,
67 public NonlinearSystem,
68 public DependencyDefinition<VariableName>,
70{
71public:
78 {
79 std::vector<Size> dynamic_dims;
80 std::vector<TensorShape> intmd_shapes;
81 at::DispatchKey dispatch_key;
82 bool operator==(const EvaluationSchema & other) const;
83 bool operator!=(const EvaluationSchema & other) const;
84 bool operator<(const EvaluationSchema & other) const;
85 };
86
88
94 Model(const OptionSet & options);
95
96 void setup() override;
97
98 void diagnose() const override;
99
101 virtual bool defines_values() const { return _defines_value; }
102
104 virtual bool defines_derivatives() const { return _defines_dvalue; }
105
107 virtual bool defines_second_derivatives() const { return _defines_d2value; }
108
110 virtual bool is_nonlinear_system() const { return _nonlinear_system; }
111
113 virtual bool is_jit_enabled() const { return _jit; }
114
116 virtual void to(const TensorOptions & options);
117
119 const std::vector<std::shared_ptr<Model>> & registered_models() const
120 {
121 return _registered_models;
122 }
124 std::shared_ptr<Model> registered_model(const std::string & name) const;
125
127 void register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param);
128
130 bool has_nl_param(bool recursive = false) const;
131
138 const VariableBase * nl_param(const std::string &) const;
139
141 virtual std::map<std::string, NonlinearParameter>
142 named_nonlinear_parameters(bool recursive = false) const;
143
145 std::set<VariableName> consumed_items() const override;
147 std::set<VariableName> provided_items() const override;
148
150 void request_AD(VariableBase & y, const VariableBase & u);
151
153 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
154
155 void clear_input() override;
156 void clear_output() override;
157 void zero_undefined_input() override;
158
167 void forward_maybe_jit(bool out, bool dout, bool d2out);
168
170 std::string variable_name_lookup(const ATensor & var) const;
171
173 virtual ValueMap value(const ValueMap & in);
174
176 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
177
179 virtual DerivMap dvalue(const ValueMap & in);
180
182 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
184
186 virtual SecDerivMap d2value(const ValueMap & in);
187
189 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(const ValueMap & in);
190
192 friend class ParameterStore;
193
195 friend class ComposedModel;
196
197protected:
198 void diagnostic_assert_state(const VariableBase & v) const;
199 void diagnostic_assert_old_state(const VariableBase & v) const;
200 void diagnostic_assert_force(const VariableBase & v) const;
201 void diagnostic_assert_old_force(const VariableBase & v) const;
202 void diagnostic_assert_residual(const VariableBase & v) const;
203 void diagnostic_check_input_variable(const VariableBase & v) const;
205
207 void diagnose_nl_sys() const;
208
209 virtual void link_input_variables();
210 virtual void link_input_variables(Model * submodel);
211 virtual void link_output_variables();
212 virtual void link_output_variables(Model * submodel);
213
215 void check_precision() const;
216
230 virtual void request_AD() {}
231
233 void forward(bool out, bool dout, bool d2out);
234
236 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
237
250 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
251 T & register_model(const std::string & name, bool nonlinear = false, bool merge_input = true)
252 {
253 auto model_name =
254 input_options().contains(name) ? input_options().get<std::string>(name) : name;
255 if (model_name == this->name())
256 throw SetupException("Model named '" + this->name() +
257 "' is trying to register itself as a sub-model. This is not allowed.");
258
259 OptionSet extra_opts;
260 extra_opts.set<NEML2Object *>("_host") = host();
261 extra_opts.set<bool>("_nonlinear_system") = nonlinear;
262
263 if (!host()->factory())
264 throw SetupException("Internal error: Host object '" + host()->name() +
265 "' does not have a factory set.");
266 auto model = host()->factory()->get_object<T>("Models", model_name, extra_opts);
267 if (std::find(_registered_models.begin(), _registered_models.end(), model) !=
268 _registered_models.end())
269 throw SetupException("Model named '" + model_name + "' has already been registered.");
270
271 if (merge_input)
272 for (auto && [name, var] : model->input_variables())
274
275 _registered_models.push_back(model);
276 return *model;
277 }
278
279 void assign_input_stack(jit::Stack & stack);
280
281 jit::Stack collect_input_stack() const;
282
283 void set_guess(const Sol<false> &) override;
284
285 void assemble(Res<false> *, Jac<false> *) override;
286
288 std::vector<std::shared_ptr<Model>> _registered_models;
289
290private:
291 template <typename T>
292 void forward_helper(T && in, bool out, bool dout, bool d2out)
293 {
295 assign_input(std::forward<T>(in));
297 forward_maybe_jit(out, dout, d2out);
298 }
299
301 void enable_AD();
302
304 void extract_AD_derivatives(bool dout, bool d2out);
305
307 std::size_t forward_operator_index(bool out, bool dout, bool d2out) const;
308
310 EvaluationSchema calculate_eval_schema() const;
311
314 bool _defines_value;
315 bool _defines_dvalue;
316 bool _defines_d2value;
318
320 bool _nonlinear_system;
321
323 std::map<std::string, NonlinearParameter> _nl_params;
324
327 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
328 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
329 _ad_secderivs;
330 std::set<VariableBase *> _ad_args;
332
334 const bool _jit;
335
337 const bool _production;
338
355 std::array<std::map<EvaluationSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
356
358 std::array<std::map<EvaluationSchema, std::unique_ptr<jit::GraphFunction>>, 8>
359 _traced_functions_nl_sys;
360};
361
362std::ostream & operator<<(std::ostream & os, const Model & model);
363} // namespace neml2
Definition ComposedModel.h:33
Definition Data.h:40
Definition DependencyDefinition.h:40
Interface for object making diagnostics about common setup errors.
Definition DiagnosticsInterface.h:83
The base class for all constitutive models.
Definition Model.h:70
void clear_input() override
Definition Model.cxx:329
std::vector< std::shared_ptr< Model > > _registered_models
Models this model may use during its evaluation.
Definition Model.h:288
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:125
void check_precision() const
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:515
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:210
void zero_undefined_input() override
Fill undefined input variables with zeros.
Definition Model.cxx:345
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:539
virtual void link_input_variables()
Definition Model.cxx:269
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
Definition Model.cxx:609
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:104
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameter.
Definition Model.cxx:619
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:252
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:101
const std::vector< std::shared_ptr< Model > > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:119
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:417
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:587
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:667
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:107
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:197
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:714
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:380
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivative.
Definition Model.cxx:562
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:528
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:251
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:638
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:230
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:110
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivative.
Definition Model.cxx:575
std::string variable_name_lookup(const ATensor &var) const
Look up the name of a variable in the traced graph.
Definition Model.cxx:483
std::shared_ptr< Model > registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:598
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:203
virtual void link_output_variables()
Definition Model.cxx:286
jit::Stack collect_input_stack() const
Definition Model.cxx:694
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:216
void setup() override
Setup this object.
Definition Model.cxx:139
void diagnose() const override
Check for common problems.
Definition Model.cxx:151
void clear_output() override
Definition Model.cxx:337
static OptionSet expected_options()
Definition Model.cxx:75
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:551
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:676
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:223
virtual void request_AD()
Definition Model.h:230
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:109
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:707
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:113
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:632
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:173
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:658
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:51
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:83
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:150
Factory * factory() const
Get the factory that created this object.
Definition NEML2Object.h:92
const OptionSet & input_options() const
Definition NEML2Object.h:69
Definition of a nonlinear system of equations.
Definition NonlinearSystem.h:59
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:51
T get(const std::string &) const
Definition OptionSet.h:242
bool contains(const std::string &) const
Definition OptionSet.cxx:47
T & set(const std::string &)
Definition OptionSet.h:254
Interface for object which can store parameters.
Definition ParameterStore.h:53
Definition errors.h:50
Base class of variable.
Definition VariableBase.h:53
Definition VariableStore.h:46
void assign_input(const ValueMap &vals, bool assembly=false)
Definition VariableStore.cxx:345
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:155
Definition DiagnosticsInterface.cxx:30
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
at::Tensor ATensor
Definition types.h:38
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
std::shared_ptr< Model > load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:45
c10::TensorOptions TensorOptions
Definition types.h:60
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
Schema for the traced forward operators.
Definition Model.h:78
bool operator==(const EvaluationSchema &other) const
Definition Model.cxx:52
at::DispatchKey dispatch_key
Definition Model.h:81
bool operator!=(const EvaluationSchema &other) const
Definition Model.cxx:59
std::vector< Size > dynamic_dims
Definition Model.h:79
bool operator<(const EvaluationSchema &other) const
Definition Model.cxx:65
std::vector< TensorShape > intmd_shapes
Definition Model.h:80
Nonlinear parameter.
Definition NonlinearParameter.h:51