NEML2 2.0.0
Loading...
Searching...
No Matches
Model.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
#pragma once

#include <algorithm>
#include <array>
#include <filesystem>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include "neml2/models/DependencyDefinition.h"
#include "neml2/base/DiagnosticsInterface.h"

#include "neml2/models/Data.h"
#include "neml2/models/ParameterStore.h"
#include "neml2/models/VariableStore.h"
#include "neml2/solvers/NonlinearSystem.h"
#include "neml2/models/NonlinearParameter.h"

// These headers are not directly used by Model, but are included here so that derived classes do
// not have to include them separately. This is a convenience for the user, and is a reasonable
// choice since these headers are light and bring in little dependency.
#include "neml2/base/LabeledAxis.h"
#include "neml2/models/Variable.h"

42namespace neml2
43{
class Model;

/**
 * @brief A convenient function to load an input file and get a model.
 *
 * @param path Path to the input file to load.
 * @param mname Name of the model to retrieve from the loaded input file.
 * @return Shared pointer to the requested model.
 */
std::shared_ptr<Model> load_model(const std::filesystem::path & path, const std::string & mname);
53
61class Model : public std::enable_shared_from_this<Model>,
62 public Data,
63 public ParameterStore,
64 public VariableStore,
65 public NonlinearSystem,
66 public DependencyDefinition<VariableName>,
68{
69public:
76 {
77 std::vector<Size> batch_dims;
78 at::DispatchKey dispatch_key;
79 bool operator==(const TraceSchema & other) const;
80 bool operator<(const TraceSchema & other) const;
81 };
82
84
90 Model(const OptionSet & options);
91
92 void setup() override;
93
94 void diagnose() const override;
95
97 virtual bool defines_values() const { return _defines_value; }
98
100 virtual bool defines_derivatives() const { return _defines_dvalue; }
101
103 virtual bool defines_second_derivatives() const { return _defines_d2value; }
104
106 virtual bool is_nonlinear_system() const { return _nonlinear_system; }
107
109 virtual bool is_jit_enabled() const { return _jit; }
110
112 virtual void to(const TensorOptions & options);
113
115 const std::vector<std::shared_ptr<Model>> & registered_models() const
116 {
117 return _registered_models;
118 }
120 std::shared_ptr<Model> registered_model(const std::string & name) const;
121
123 void register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param);
124
126 bool has_nl_param(bool recursive = false) const;
127
134 const VariableBase * nl_param(const std::string &) const;
135
137 virtual std::map<std::string, NonlinearParameter>
138 named_nonlinear_parameters(bool recursive = false) const;
139
141 std::set<VariableName> consumed_items() const override;
143 std::set<VariableName> provided_items() const override;
144
146 void request_AD(VariableBase & y, const VariableBase & u);
147
149 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
150
152 void forward(bool out, bool dout, bool d2out);
153
162 void forward_maybe_jit(bool out, bool dout, bool d2out);
163
165 std::string variable_name_lookup(const ATensor & var);
166
168 virtual ValueMap value(const ValueMap & in);
169 virtual ValueMap value(ValueMap && in);
170
172 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
173 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(ValueMap && in);
174
176 virtual DerivMap dvalue(const ValueMap & in);
177 virtual DerivMap dvalue(ValueMap && in);
178
180 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
182 virtual std::tuple<ValueMap, DerivMap, SecDerivMap> value_and_dvalue_and_d2value(ValueMap && in);
183
185 virtual SecDerivMap d2value(const ValueMap & in);
186 virtual SecDerivMap d2value(ValueMap && in);
187
189 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(const ValueMap & in);
190 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(ValueMap && in);
191
193 friend class ParameterStore;
194
196 friend class ComposedModel;
197
198protected:
199 void diagnostic_assert_state(const VariableBase & v) const;
200 void diagnostic_assert_old_state(const VariableBase & v) const;
201 void diagnostic_assert_force(const VariableBase & v) const;
202 void diagnostic_assert_old_force(const VariableBase & v) const;
203 void diagnostic_assert_residual(const VariableBase & v) const;
204 void diagnostic_check_input_variable(const VariableBase & v) const;
206
208 void diagnose_nl_sys() const;
209
210 virtual void link_input_variables();
211 virtual void link_input_variables(Model * submodel);
212 virtual void link_output_variables();
213 virtual void link_output_variables(Model * submodel);
214
215 void clear_input() override;
216 void clear_output() override;
217 void zero_input() override;
218 void zero_output() override;
219
221 void check_precision() const;
222
236 virtual void request_AD() {}
237
239 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
240
253 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
254 T & register_model(const std::string & name, bool nonlinear = false, bool merge_input = true)
255 {
256 auto model_name =
257 input_options().contains(name) ? input_options().get<std::string>(name) : name;
258 if (model_name == this->name())
259 throw SetupException("Model named '" + this->name() +
260 "' is trying to register itself as a sub-model. This is not allowed.");
261
262 OptionSet extra_opts;
263 extra_opts.set<NEML2Object *>("_host") = host();
264 extra_opts.set<bool>("_nonlinear_system") = nonlinear;
265
266 if (!host()->factory())
267 throw SetupException("Internal error: Host object '" + host()->name() +
268 "' does not have a factory set.");
269 auto model = host()->factory()->get_object<T>("Models", model_name, extra_opts);
270 if (std::find(_registered_models.begin(), _registered_models.end(), model) !=
271 _registered_models.end())
272 throw SetupException("Model named '" + model_name + "' has already been registered.");
273
274 if (merge_input)
275 for (auto && [name, var] : model->input_variables())
277
278 _registered_models.push_back(model);
279 return *model;
280 }
281
282 void assign_input_stack(jit::Stack & stack);
283
284 jit::Stack collect_input_stack() const;
285
286 void set_guess(const Sol<false> &) override;
287
288 void assemble(Res<false> *, Jac<false> *) override;
289
291 std::vector<std::shared_ptr<Model>> _registered_models;
292
293private:
294 template <typename T>
295 void forward_helper(T && in, bool out, bool dout, bool d2out)
296 {
298 zero_input();
299 assign_input(std::forward<T>(in));
300 zero_output();
301 forward_maybe_jit(out, dout, d2out);
302 }
303
306 bool AD_need_value(bool dout, bool d2out) const;
307
309 void enable_AD();
310
312 void extract_AD_derivatives(bool dout, bool d2out);
313
315 std::size_t forward_operator_index(bool out, bool dout, bool d2out) const;
316
318 TraceSchema compute_trace_schema() const;
319
322 bool _defines_value;
323 bool _defines_dvalue;
324 bool _defines_d2value;
326
328 bool _nonlinear_system;
329
331 std::map<std::string, NonlinearParameter> _nl_params;
332
335 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
336 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
337 _ad_secderivs;
338 std::set<VariableBase *> _ad_args;
340
342 const bool _jit;
343
345 const bool _production;
346
363 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
364
366 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
367 _traced_functions_nl_sys;
368};
369
370std::ostream & operator<<(std::ostream & os, const Model & model);
371} // namespace neml2
Definition ComposedModel.h:33
Definition Data.h:40
Definition DependencyDefinition.h:40
Interface for object making diagnostics about common setup errors.
Definition DiagnosticsInterface.h:83
The base class for all constitutive models.
Definition Model.h:68
void clear_input() override
Definition Model.cxx:320
std::vector< std::shared_ptr< Model > > _registered_models
Models this model may use during its evaluation.
Definition Model.h:291
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:114
void check_precision() const
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:498
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:201
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:533
virtual void link_input_variables()
Definition Model.cxx:260
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
Definition Model.cxx:662
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:100
bool has_nl_param(bool recursive=false) const
Whether this model has any nonlinear parameter.
Definition Model.cxx:672
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:243
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:97
const std::vector< std::shared_ptr< Model > > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:115
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:404
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:629
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:718
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:103
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:188
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:763
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:372
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivative.
Definition Model.cxx:579
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:511
std::string variable_name_lookup(const ATensor &var)
Look up the name of a variable in the traced graph.
Definition Model.cxx:468
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:254
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:691
void zero_input() override
Definition Model.cxx:336
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:221
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:106
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivative.
Definition Model.cxx:605
std::shared_ptr< Model > registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:651
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:194
virtual void link_output_variables()
Definition Model.cxx:277
jit::Stack collect_input_stack() const
Definition Model.cxx:743
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:207
void setup() override
Setup this object.
Definition Model.cxx:128
void diagnose() const override
Check for common problems.
Definition Model.cxx:142
void clear_output() override
Definition Model.cxx:328
static OptionSet expected_options()
Definition Model.cxx:64
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:557
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:725
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:214
virtual void request_AD()
Definition Model.h:236
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:98
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:756
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:109
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:685
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:164
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:711
void zero_output() override
Definition Model.cxx:344
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:51
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:83
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:150
Factory * factory() const
Get the factory that created this object.
Definition NEML2Object.h:92
const OptionSet & input_options() const
Definition NEML2Object.h:69
Definition of a nonlinear system of equations.
Definition NonlinearSystem.h:59
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:51
T get(const std::string &) const
Definition OptionSet.h:242
bool contains(const std::string &) const
Definition OptionSet.cxx:47
T & set(const std::string &)
Definition OptionSet.h:254
Interface for object which can store parameters.
Definition ParameterStore.h:53
Definition errors.h:50
Base class of variable.
Definition Variable.h:52
Definition VariableStore.h:46
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:279
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:126
Definition DiagnosticsInterface.cxx:30
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
at::Tensor ATensor
Definition types.h:38
std::shared_ptr< Model > load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:43
c10::TensorOptions TensorOptions
Definition types.h:60
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
Schema for the traced forward operators.
Definition Model.h:76
bool operator==(const TraceSchema &other) const
Definition Model.cxx:50
std::vector< Size > batch_dims
Definition Model.h:77
at::DispatchKey dispatch_key
Definition Model.h:78
bool operator<(const TraceSchema &other) const
Definition Model.cxx:56
Nonlinear parameter.
Definition NonlinearParameter.h:51