NEML2 2.0.0
Loading...
Searching...
No Matches
Model.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
29
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/models/Variable.h"
34#include "neml2/solvers/NonlinearSystem.h"
35#include "neml2/models/NonlinearParameter.h"
36
37// These headers are not directly used by Model, but are included here so that derived classes do
38// not have to include them separately. This is a convenience for the user, and is a reasonable
39// choice since these headers are light and bring in little dependency.
40#include "neml2/base/LabeledAxis.h"
41
42namespace neml2
43{
44class Model;
45
48 std::function<void(const Model &,
49 const std::map<VariableName, std::unique_ptr<VariableBase>> &,
50 const std::map<VariableName, std::unique_ptr<VariableBase>> &)>;
51
58std::shared_ptr<Model> load_model(const std::filesystem::path & path, const std::string & mname);
59
67class Model : public std::enable_shared_from_this<Model>,
68 public Data,
69 public ParameterStore,
70 public VariableStore,
71 public NonlinearSystem,
72 public DependencyDefinition<VariableName>,
74{
75public:
82 {
83 std::vector<Size> batch_dims;
84 at::DispatchKey dispatch_key;
85 bool operator==(const TraceSchema & other) const;
86 bool operator<(const TraceSchema & other) const;
87 };
88
90
96 Model(const OptionSet & options);
97
98 void setup() override;
99
100 void diagnose() const override;
101
103 virtual bool defines_values() const { return _defines_value; }
104
106 virtual bool defines_derivatives() const { return _defines_dvalue; }
107
109 virtual bool defines_second_derivatives() const { return _defines_d2value; }
110
112 virtual bool is_nonlinear_system() const { return _nonlinear_system; }
113
115 virtual bool is_jit_enabled() const { return _jit; }
116
118 virtual void to(const TensorOptions & options);
119
121 const std::vector<std::shared_ptr<Model>> & registered_models() const
122 {
123 return _registered_models;
124 }
125
126 std::shared_ptr<Model> registered_model(const std::string & name) const;
127
129 void register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param);
130
132 bool has_nl_param(bool recursive = false) const;
133
140 const VariableBase * nl_param(const std::string &) const;
141
143 virtual std::map<std::string, NonlinearParameter>
144 named_nonlinear_parameters(bool recursive = false) const;
145
147 std::set<VariableName> consumed_items() const override;
149 std::set<VariableName> provided_items() const override;
150
152 void request_AD(VariableBase & y, const VariableBase & u);
153
155 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
156
158 void register_callback(const ModelCallback & callback);
159
161 void register_callback_recursive(const ModelCallback & callback);
162
164 void forward(bool out, bool dout, bool d2out);
165
174 void forward_maybe_jit(bool out, bool dout, bool d2out);
175
177 std::string variable_name_lookup(const ATensor & var);
178
180 virtual ValueMap value(const ValueMap & in);
181 virtual ValueMap value(ValueMap && in);
182
184 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
185 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(ValueMap && in);
186
188 virtual DerivMap dvalue(const ValueMap & in);
189 virtual DerivMap dvalue(ValueMap && in);
190
192 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
194 virtual std::tuple<ValueMap, DerivMap, SecDerivMap> value_and_dvalue_and_d2value(ValueMap && in);
195
197 virtual SecDerivMap d2value(const ValueMap & in);
198 virtual SecDerivMap d2value(ValueMap && in);
199
201 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(const ValueMap & in);
202 virtual std::tuple<DerivMap, SecDerivMap> dvalue_and_d2value(ValueMap && in);
203
205 friend class ParameterStore;
206
208 friend class ComposedModel;
209
210protected:
211 void diagnostic_assert_state(const VariableBase & v) const;
212 void diagnostic_assert_old_state(const VariableBase & v) const;
213 void diagnostic_assert_force(const VariableBase & v) const;
214 void diagnostic_assert_old_force(const VariableBase & v) const;
215 void diagnostic_assert_residual(const VariableBase & v) const;
216 void diagnostic_check_input_variable(const VariableBase & v) const;
218
220 void diagnose_nl_sys() const;
221
222 virtual void link_input_variables();
223 virtual void link_input_variables(Model * submodel);
224 virtual void link_output_variables();
225 virtual void link_output_variables(Model * submodel);
226
227 void clear_input() override;
228 void clear_output() override;
229 void zero_input() override;
230 void zero_output() override;
231
233 void check_precision() const;
234
248 virtual void request_AD() {}
249
251 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
252
265 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
266 T & register_model(const std::string & name, bool nonlinear = false, bool merge_input = true)
267 {
268 auto model_name =
269 input_options().contains(name) ? input_options().get<std::string>(name) : name;
270 if (model_name == this->name())
271 throw SetupException("Model named '" + this->name() +
272 "' is trying to register itself as a sub-model. This is not allowed.");
273
274 OptionSet extra_opts;
275 extra_opts.set<NEML2Object *>("_host") = host();
276 extra_opts.set<bool>("_nonlinear_system") = nonlinear;
277
278 if (!host()->factory())
279 throw SetupException("Internal error: Host object '" + host()->name() +
280 "' does not have a factory set.");
281 auto model = host()->factory()->get_object<T>("Models", model_name, extra_opts);
282 if (std::find(_registered_models.begin(), _registered_models.end(), model) !=
283 _registered_models.end())
284 throw SetupException("Model named '" + model_name + "' has already been registered.");
285
286 if (merge_input)
287 for (auto && [name, var] : model->input_variables())
289
290 _registered_models.push_back(model);
291 return *model;
292 }
293
294 void assign_input_stack(jit::Stack & stack);
295
296 jit::Stack collect_input_stack() const;
297
298 void set_guess(const Sol<false> &) override;
299
300 void assemble(Res<false> *, Jac<false> *) override;
301
303 std::vector<std::shared_ptr<Model>> _registered_models;
304
305private:
307 void call_callbacks() const;
308
309 template <typename T>
310 void forward_helper(T && in, bool out, bool dout, bool d2out)
311 {
313 zero_input();
314 assign_input(std::forward<T>(in));
315 zero_output();
316 forward_maybe_jit(out, dout, d2out);
317 }
318
321 bool AD_need_value(bool dout, bool d2out) const;
322
324 void enable_AD();
325
327 void extract_AD_derivatives(bool dout, bool d2out);
328
330 std::size_t forward_operator_index(bool out, bool dout, bool d2out) const;
331
333 TraceSchema compute_trace_schema() const;
334
337 bool _defines_value;
338 bool _defines_dvalue;
339 bool _defines_d2value;
341
343 bool _nonlinear_system;
344
346 std::map<std::string, NonlinearParameter> _nl_params;
347
350 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
351 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
352 _ad_secderivs;
353 std::set<VariableBase *> _ad_args;
355
357 const bool _jit;
358
360 const bool _production;
361
378 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
379
381 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
382 _traced_functions_nl_sys;
383
385 std::vector<ModelCallback> _callbacks;
386};
387
388std::ostream & operator<<(std::ostream & os, const Model & model);
389} // namespace neml2
Data(const OptionSet &options)
Construct a new Data object.
Definition Data.cxx:38
The base class for all constitutive models.
Definition Model.h:74
friend class ComposedModel
ComposedModel's set_value needs to call its submodels' set_value.
Definition Model.h:208
void clear_input() override
Definition Model.cxx:312
std::vector< std::shared_ptr< Model > > _registered_models
Models this model may use during its evaluation.
Definition Model.h:303
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:112
void check_precision() const
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:496
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:193
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:531
virtual void link_input_variables()
Definition Model.cxx:252
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
Definition Model.cxx:660
void register_callback(const ModelCallback &callback)
Register a callback to be called when the model is evaluated.
Definition Model.cxx:364
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:106
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameters.
Definition Model.cxx:670
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:235
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:103
const std::vector< std::shared_ptr< Model > > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:121
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:411
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:627
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map from input to output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:716
friend class ParameterStore
Declaring nonlinear parameters may require manipulating the model's input.
Definition Model.h:205
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:109
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:180
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:761
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:379
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivative.
Definition Model.cxx:577
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:509
std::string variable_name_lookup(const ATensor &var)
Look up the name of a variable in the traced graph.
Definition Model.cxx:466
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:266
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:689
void zero_input() override
Definition Model.cxx:328
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:213
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:112
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivative.
Definition Model.cxx:603
std::shared_ptr< Model > registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:649
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:186
virtual void link_output_variables()
Definition Model.cxx:269
jit::Stack collect_input_stack() const
Definition Model.cxx:741
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:199
void register_callback_recursive(const ModelCallback &callback)
Register a callback on this and all submodels.
Definition Model.cxx:370
void setup() override
Setup this object.
Definition Model.cxx:126
void diagnose() const override
Check for common problems.
Definition Model.cxx:140
void clear_output() override
Definition Model.cxx:320
static OptionSet expected_options()
Definition Model.cxx:62
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:555
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:723
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:206
virtual void request_AD()
Definition Model.h:248
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:96
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:754
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:115
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:683
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:156
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:709
void zero_output() override
Definition Model.cxx:336
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:51
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:83
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:150
Factory * factory() const
Get the factory that created this object.
Definition NEML2Object.h:92
const OptionSet & input_options() const
Definition NEML2Object.h:69
NonlinearSystem(const NonlinearSystem &)=default
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:51
T get(const std::string &) const
Definition OptionSet.h:242
bool contains(const std::string &) const
Definition OptionSet.cxx:47
T & set(const std::string &)
Definition OptionSet.h:254
Definition errors.h:50
Base class of variable.
Definition Variable.h:52
VariableStore(Model *object)
Definition VariableStore.cxx:36
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:257
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:120
Definition DiagnosticsInterface.cxx:29
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
at::Tensor ATensor
Definition types.h:38
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
std::shared_ptr< Model > load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:41
c10::TensorOptions TensorOptions
Definition types.h:60
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
std::function< void(const Model &, const std::map< VariableName, std::unique_ptr< VariableBase > > &, const std::map< VariableName, std::unique_ptr< VariableBase > > &)> ModelCallback
typedef giving the call signature for a model callback
Definition Model.h:47
Schema for the traced forward operators.
Definition Model.h:82
bool operator==(const TraceSchema &other) const
Definition Model.cxx:48
std::vector< Size > batch_dims
Definition Model.h:83
at::DispatchKey dispatch_key
Definition Model.h:84
bool operator<(const TraceSchema &other) const
Definition Model.cxx:54
Nonlinear parameter.
Definition NonlinearParameter.h:51