NEML2 2.0.0
Loading...
Searching...
No Matches
Model.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/base/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
29
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/solvers/NonlinearSystem.h"
34
35namespace neml2
36{
/// The base class for all constitutive models.
///
/// A Model combines data, parameter and variable storage with the nonlinear-system interface and
/// a dependency definition keyed on VariableName (see the base-class list below). Concrete models
/// must implement the pure virtual set_value().
///
/// NOTE(review): this listing is lossy. The embedded line numbers skip lines that appear to have
/// held code; each such gap is marked below with a NOTE(review) and should be confirmed against
/// the real header.
44class Model : public std::enable_shared_from_this<Model>,
45 public Data,
46 public ParameterStore,
47 public VariableStore,
48 public NonlinearSystem,
49 public DependencyDefinition<VariableName>,
// NOTE(review): the base list ends with a dangling comma and listing line 50 is absent. Since
// DiagnosticsInterface.h is included and diagnose() is marked `override`, the missing base is
// presumably `public DiagnosticsInterface` -- confirm.
51{
52public:
// NOTE(review): listing line 53 is absent. The symbol index lists
// `static OptionSet expected_options()` (defined in Model.cxx), so it is presumably declared
// here -- confirm.
54
/// Construct a new Model object.
60 Model(const OptionSet & options);
61
/// Send model to a different device or dtype.
63 virtual void to(const torch::TensorOptions & options);
64
/// Check for common problems.
65 void diagnose(std::vector<Diagnosis> &) const override;
66
/// Whether this model defines one or more nonlinear equations to be solved.
68 virtual bool is_nonlinear_system() const { return _nonlinear_system; }
69
/// The models that may be used during the evaluation of this model.
71 const std::vector<Model *> & registered_models() const { return _registered_models; }
/// Get a registered model by its name.
73 Model * registered_model(const std::string & name) const;
74
/// The variables that this model depends on.
76 std::set<VariableName> consumed_items() const override;
/// The variables that this model defines as part of its output.
78 std::set<VariableName> provided_items() const override;
79
// Overrides for clearing/zeroing the input and output; the base declaring them is not visible
// in this listing.
80 void clear_input() override;
81 void clear_output() override;
82 void zero_input() override;
83 void zero_output() override;
84
// Presumably requests automatic differentiation for the first derivative of y with respect to
// u -- TODO confirm against Model.cxx.
86 void request_AD(VariableBase & y, const VariableBase & u);
87
// Presumably requests automatic differentiation for the second derivative of y with respect to
// u1 and u2 -- TODO confirm against Model.cxx.
89 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
90
/// Evaluate the model.
92 virtual void value();
/// Evaluate the model and compute its derivative.
94 virtual void value_and_dvalue();
/// Evaluate the derivative.
96 virtual void dvalue();
/// Evaluate the model and compute its first and second derivatives.
98 virtual void value_and_dvalue_and_d2value();
/// Evaluate the second derivatives.
100 virtual void d2value();
/// Evaluate the first and second derivatives.
102 virtual void dvalue_and_d2value();
103
// The overloads below take the input as a ValueMap and return the requested quantities by
// value instead of returning void.
105 virtual ValueMap value(const ValueMap & in);
106
108 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
109
111 virtual DerivMap dvalue(const ValueMap & in);
112
// NOTE(review): the declaration below has a return type but no declarator, and listing line 115
// is absent. By analogy with the neighbouring overloads it is presumably
// `value_and_dvalue_and_d2value(const ValueMap & in);` -- confirm.
114 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
116
118 virtual SecDerivMap d2value(const ValueMap & in);
119
121 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(const ValueMap & in);
122
// ParameterStore is granted access to Model's non-public members.
124 friend class ParameterStore;
125
// ComposedModel is granted access to Model's non-public members.
127 friend class ComposedModel;
128
129protected:
/// Set up this object.
130 void setup() override;
// Variable-linking hooks, with and without an explicit sub-model argument. Their exact
// semantics are defined in Model.cxx and are not visible here.
131 virtual void link_input_variables();
132 virtual void link_input_variables(Model * submodel);
133 virtual void link_output_variables();
134 virtual void link_output_variables(Model * submodel);
135
// Overridable hook with an empty default body. Presumably derived models override it to call
// the public request_AD overloads -- TODO confirm.
149 virtual void request_AD() {}
150
/// Additional diagnostics for a nonlinear system.
152 void diagnose_nl_sys(std::vector<Diagnosis> & diagnoses) const;
153
/// The map between input -> output, and optionally its derivatives.
// Pure virtual: every concrete model must implement it. The three flags select which of the
// value, first derivative and second derivative are requested.
155 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
156
/// Register a model that the current model may use during its evaluation.
///
/// @tparam T Type to return the registered model as; constrained to types derived from Model.
/// @param name Name of the model to retrieve from the "Models" section of the Factory. Must
///             differ from this model's own name.
/// @param nonlinear Value forwarded to the sub-model as its "_nonlinear_system" option.
/// @param merge_input When true, the sub-model's input variables are iterated (see the note on
///                    the loop below).
/// @return Reference to the registered model, cast to T.
169 template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
170 T & register_model(const std::string & name, bool nonlinear = false, bool merge_input = true)
171 {
// A model must not register itself as its own sub-model.
172 neml_assert(name != this->name(),
173 "Model named '",
174 this->name(),
175 "' is trying to register itself as a sub-model. This is not allowed.");
176
// NOTE(review): `extra_opts` is used without a visible declaration and listing line 177 is
// absent; presumably `OptionSet extra_opts;` -- confirm.
// The sub-model receives this model's host and the requested nonlinear flag as extra options.
178 extra_opts.set<NEML2Object *>("_host") = host();
179 extra_opts.set<bool>("_nonlinear_system") = nonlinear;
180
181 auto model = Factory::get_object_ptr<Model>("Models", name, extra_opts);
182
// NOTE(review): the range-for below has no visible body (listing line 185 is absent), so as
// shown the next statement would become the loop body. The symbol index lists
// VariableStore::clone_input_variable ("Clone a variable and put it on the input axis"), so the
// missing line presumably clones each sub-model input variable into this model -- confirm.
// Also note the structured binding `name` shadows the function parameter `name`.
183 if (merge_input)
184 for (auto && [name, var] : model->input_variables())
186
// Only a raw, non-owning pointer is stored here; presumably the Factory keeps the object
// alive -- TODO confirm.
187 _registered_models.push_back(model.get());
// NOTE(review): std::dynamic_pointer_cast yields an empty pointer when the object is not a T,
// and the result is dereferenced here without a check.
188 return *(std::dynamic_pointer_cast<T>(model));
189 }
190
/// Set the unscaled current guess.
191 void set_guess(const Sol<false> &) override;
192
/// Compute the unscaled residual and Jacobian.
193 void assemble(Res<false> *, Jac<false> *) override;
194
/// Models this model may use during its evaluation.
196 std::vector<Model *> _registered_models;
197
198private:
// Automatic-differentiation helpers; their behavior is defined in Model.cxx and is not visible
// in this listing.
201 bool AD_need_value(bool dout, bool d2out) const;
202
204 void enable_AD();
205
207 void extract_AD_derivatives(bool dout, bool d2out);
208
// Backing flag returned by is_nonlinear_system().
210 bool _nonlinear_system;
211
// Bookkeeping for automatic differentiation. Presumably these record, per output variable,
// the arguments for which first/second derivatives were requested through request_AD, plus
// the set of all such arguments -- TODO confirm against Model.cxx.
213 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
214 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
215 _ad_secderivs;
216 std::set<VariableBase *> _ad_args;
217};
218
/// Stream-insertion operator for Model. Only the declaration is visible here; the output format
/// is determined by the definition elsewhere.
219std::ostream & operator<<(std::ostream & os, const Model & model);
220} // namespace neml2
Definition ComposedModel.h:35
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Definition Data.h:36
Definition DependencyDefinition.h:40
Interface for objects that make diagnostics about common setup errors.
Definition DiagnosticsInterface.h:47
The base class for all constitutive models.
Definition Model.h:51
void clear_input() override
Definition Model.cxx:171
virtual void dvalue()
Evaluate the derivative.
Definition Model.cxx:277
void diagnose(std::vector< Diagnosis > &) const override
Check for common problems.
Definition Model.cxx:68
virtual void dvalue_and_d2value()
Evaluate the first and second derivatives.
Definition Model.cxx:293
virtual void value_and_dvalue_and_d2value()
Evaluate the model and compute its first and second derivatives.
Definition Model.cxx:285
virtual void link_input_variables()
Definition Model.cxx:123
const std::vector< Model * > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:71
std::vector< Model * > _registered_models
Models this model may use during its evaluation.
Definition Model.h:196
virtual void d2value()
Evaluate the second derivatives.
Definition Model.cxx:301
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:327
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:341
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:170
void zero_input() override
Definition Model.cxx:187
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:68
virtual void link_output_variables()
Definition Model.cxx:140
virtual void to(const torch::TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:58
void setup() override
Set up this object.
Definition Model.cxx:109
virtual void value()
Evaluate the model.
Definition Model.cxx:263
void clear_output() override
Definition Model.cxx:179
static OptionSet expected_options()
Definition Model.cxx:33
virtual void value_and_dvalue()
Evaluate the model and compute its derivative.
Definition Model.cxx:269
void diagnose_nl_sys(std::vector< Diagnosis > &diagnoses) const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:84
virtual void request_AD()
Definition Model.h:149
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:47
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:334
Model * registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:309
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:320
void zero_output() override
Definition Model.cxx:195
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:70
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:95
Definition of a nonlinear system of equations.
Definition NonlinearSystem.h:37
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:85
Interface for objects that can store parameters.
Definition ParameterStore.h:46
Base class of variable.
Definition Variable.h:47
Definition VariableStore.h:40
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.h:156
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types.h:35
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types.h:36
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64