NEML2 2.1.0
Loading...
Searching...
No Matches
Model.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
29
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/models/NonlinearParameter.h"
34#include "neml2/tensors/jit.h"
35
36// These headers are not directly used by Model, but are included here so that derived classes do
37// not have to include them separately. This is a convenience for the user, and is a reasonable
38// choice since these headers are light and bring in little dependency.
39#include "neml2/models/Variable.h"
40#include "neml2/tensors/Derivative.h"
41
42namespace neml2
43{
44class Model;
45
52
53// Guard a region where implicit solve is being performed
66
/// The base class for all constitutive models.
///
/// NOTE(review): this is a Doxygen source rendering; gaps in the embedded line numbers mark lines
/// omitted from the listing. The briefs below are restored from the page's cross-reference.
76class Model : public std::enable_shared_from_this<Model>,
77 public Data,
78 public ParameterStore,
79 public VariableStore,
80 public DependencyDefinition<VariableName>,
// NOTE(review): Model.h:81 (the last base class) is omitted from this listing; given the
// DiagnosticsInterface.h include and `diagnose() const override`, it is presumably
// `public DiagnosticsInterface` -- confirm against Model.h.
82{
83public:
/// Schema for the traced forward operators.
// NOTE(review): the declaring line (presumably `struct EvaluationSchema`) is omitted here.
90 {
91 std::vector<Size> dynamic_dims;
92 std::vector<TensorShape> intmd_shapes;
93 at::DispatchKey dispatch_key;
94 bool operator==(const EvaluationSchema & other) const;
95 bool operator!=(const EvaluationSchema & other) const;
/// Ordering required because EvaluationSchema is the key of the std::map caches below.
96 bool operator<(const EvaluationSchema & other) const;
97 };
98
// NOTE(review): `static OptionSet expected_options()` appears in this page's cross-reference;
// its declaration line is presumably among the omitted lines here.
100
/// Construct a new Model object.
106 Model(const OptionSet & options);
107
/// Setup this object.
108 void setup() override;
109
/// Check for common problems.
110 void diagnose() const override;
111
/// Whether this model defines output values.
113 virtual bool defines_values() const { return _defines_value; }
114
/// Whether this model defines first derivatives.
116 virtual bool defines_derivatives() const { return _defines_dvalue; }
117
/// Whether this model defines second derivatives.
119 virtual bool defines_second_derivatives() const { return _defines_d2value; }
120
/// Whether JIT is enabled (reports the const member _jit, fixed at construction).
122 virtual bool is_jit_enabled() const { return _jit; }
123
/// Send model to a different device or dtype.
125 virtual void to(const TensorOptions & options);
126
/// The models that may be used during the evaluation of this model.
128 const std::vector<std::shared_ptr<Model>> & registered_models() const
129 {
130 return _registered_models;
131 }
132
/// Get a registered model by its name.
133 std::shared_ptr<Model> registered_model(const std::string & name) const;
134
/// Register a nonlinear parameter.
136 void register_nonlinear_parameter(const std::string & pname, const NonlinearParameter & param);
137
/// Whether this model has any nonlinear parameter.
/// NOTE(review): `recursive` presumably extends the search to registered sub-models -- confirm.
139 bool has_nl_param(bool recursive = false) const;
140
/// Query the existence of a nonlinear parameter.
/// NOTE(review): presumably returns nullptr when no such parameter exists -- confirm.
147 const VariableBase * nl_param(const std::string &) const;
148
/// Get all nonlinear parameters.
150 virtual std::map<std::string, NonlinearParameter>
151 named_nonlinear_parameters(bool recursive = false) const;
152
/// The variables that this model depends on.
154 std::set<VariableName> consumed_items() const override;
/// The variables that this model defines as part of its output.
156 std::set<VariableName> provided_items() const override;
157
// NOTE(review): the first-derivative overload `request_AD(VariableBase & y, const VariableBase & u)`
// ("Request to use AD to compute the first derivative of a variable") appears in this page's
// cross-reference but is omitted from this listing.
160
/// Request to use AD to compute the second derivative of a variable.
162 void request_AD(VariableBase & y, const VariableBase & u1, const VariableBase & u2);
163
164 void clear_input() override;
165 void clear_output() override;
/// Fill undefined input variables with zeros.
166 void zero_undefined_input() override;
167
/// Forward operator with jit.
/// NOTE(review): the "maybe" suggests tracing is conditional (e.g. on is_jit_enabled()) -- confirm.
176 void forward_maybe_jit(bool out, bool dout, bool d2out);
177
/// Look up the name of a variable in the traced graph.
179 std::string variable_name_lookup(const ATensor & var) const;
180
/// Convenient shortcut to construct and return the model value.
182 virtual ValueMap value(const ValueMap & in);
183
/// Convenient shortcut to construct and return the derivative.
185 virtual DerivMap dvalue(const ValueMap & in);
186
/// Convenient shortcut to construct and return the model value and its derivative.
188 virtual std::tuple<ValueMap, DerivMap> value_and_dvalue(const ValueMap & in);
189
192 void
193 set_output_derivative_filter(const std::vector<std::pair<VariableName, VariableName>> & derivs);
194
/// Declaration of nonlinear parameters may require manipulation of input.
196 friend class ParameterStore;
197
/// ComposedModel's set_value needs to call the sub-models' set_value.
199 friend class ComposedModel;
200
// NOTE(review): `friend class ModelNonlinearSystem;` (Model.h:202 per the cross-reference, which
// says it "needs access to some setup methods") is omitted from this listing.
203
/// ImplicitUpdate sets the nl_sys derivative filter on its sub-model at construction time.
205 friend class ImplicitUpdate;
206
207protected:
// NOTE(review): the head of this declaration, `void set_output_derivative_filter_nl_sys(`, is
// omitted from this listing; only its trailing parameter line is shown.
211 const std::vector<std::pair<VariableName, VariableName>> & derivs);
212
213 virtual void link_input_variables();
214 virtual void link_input_variables(Model * submodel);
215 virtual void link_output_variables();
216 virtual void link_output_variables(Model * submodel);
217
/// Check the current default precision and warn if it's not double precision.
219 void check_precision() const;
220
/// Additional hint to include in the error message when an exception is encountered during
/// graph execution.
222 virtual std::string failed_graph_execution_hint() const;
223
/// Virtual hook whose default body does nothing.
/// NOTE(review): presumably overridden by derived models to issue their request_AD(...) calls -- confirm.
237 virtual void request_AD() {}
238
/// Forward operator without jit.
240 void forward(bool out, bool dout, bool d2out);
241
/// The map between input -> output, and optionally its derivatives.
243 virtual void set_value(bool out, bool dout_din, bool d2out_din2) = 0;
244
/// Register a model that the current model may use during its evaluation.
///
/// If `name` is a key in this object's input options, that option's string value is used as the
/// model name; otherwise `name` itself is used. The model is obtained from the host's factory
/// (section "Models") and then registered through the shared_ptr overload.
/// Throws SetupException if the model would register itself or if the host has no factory.
256 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
257 T & register_model(const std::string & name, bool merge_input = true)
258 {
259 auto model_name =
260 input_options().contains(name) ? input_options().get<std::string>(name) : name;
// A model may not register itself as its own sub-model.
261 if (model_name == this->name())
262 throw SetupException("Model named '" + this->name() +
263 "' is trying to register itself as a sub-model. This is not allowed.");
264
// Hand this object's host to the sub-model as the private "_host" option.
265 OptionSet extra_opts;
266 extra_opts.add_private<NEML2Object *>("_host", host());
267
268 if (!host()->factory())
269 throw SetupException("Internal error: Host object '" + host()->name() +
270 "' does not have a factory set.");
271 auto model = host()->factory()->get_object<T>("Models", model_name, extra_opts);
272
273 register_model(model, merge_input);
274 return *model;
275 }
276
/// Register a model that the current model may use during its evaluation.
///
/// Throws SetupException if `model` has already been registered. When `merge_input` is true, the
/// sub-model's input variables whose names this model does not yet have are presumably cloned onto
/// this model's input axis (the statement at Model.h:298 is omitted from this listing) -- confirm.
288 template <typename T = Model, typename = typename std::enable_if_t<std::is_base_of_v<Model, T>>>
289 void register_model(const std::shared_ptr<T> & model, bool merge_input = true)
290 {
// Duplicate detection is a linear scan comparing shared_ptr targets.
291 if (std::find(_registered_models.begin(), _registered_models.end(), model) !=
292 _registered_models.end())
293 throw SetupException("Model named '" + model->name() + "' has already been registered.");
294
295 if (merge_input)
296 for (auto && [name, var] : model->input_variables())
297 if (input_variables().find(name) == input_variables().end())
// NOTE(review): the body of this `if` (Model.h:298) is omitted from this listing; presumably a
// clone_input_variable(...) call on `var` -- confirm against Model.h.
299
300 _registered_models.push_back(model);
301 }
302
303 void assign_input_stack(jit::Stack & stack);
304
305 jit::Stack collect_input_stack() const;
306
/// Models this model may use during its evaluation.
308 std::vector<std::shared_ptr<Model>> _registered_models;
309
310private:
/// Assign the (perfectly forwarded) inputs, then run forward_maybe_jit with the requested flags.
311 template <typename T>
312 void forward_helper(T && in, bool out, bool dout, bool d2out)
313 {
315 assign_input(std::forward<T>(in));
317 forward_maybe_jit(out, dout, d2out);
318 }
319
321 void enable_AD();
322
324 void extract_AD_derivatives(bool dout, bool d2out);
325
// NOTE(review): three boolean flags give 8 combinations, matching the size-8 _traced_functions
// arrays below; presumably this maps (out, dout, d2out) to an index into them -- confirm.
327 std::size_t forward_operator_index(bool out, bool dout, bool d2out) const;
328
330 EvaluationSchema calculate_eval_schema() const;
331
/// Flags reported by defines_values(), defines_derivatives() and defines_second_derivatives().
334 bool _defines_value;
335 bool _defines_dvalue;
336 bool _defines_d2value;
338
/// Nonlinear parameters keyed by string (presumably populated by register_nonlinear_parameter -- confirm).
340 std::map<std::string, NonlinearParameter> _nl_params;
341
// NOTE(review): AD bookkeeping -- presumably first-derivative requests (_ad_derivs),
// second-derivative requests (_ad_secderivs) and the variables AD is enabled on (_ad_args),
// matching the request_AD overloads and enable_AD -- confirm.
344 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
345 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
346 _ad_secderivs;
347 std::set<VariableBase *> _ad_args;
349
/// Whether JIT is enabled (reported by is_jit_enabled()).
351 const bool _jit;
352
354 const bool _production;
355
/// Cache of traced forward operators: 8 maps, each from EvaluationSchema to a traced graph function.
372 std::array<std::map<EvaluationSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
373
/// Same layout as _traced_functions; presumably the cache used while assembling a nonlinear
/// system (see _currently_assembling_nonlinear_system) -- confirm.
375 std::array<std::map<EvaluationSchema, std::unique_ptr<jit::GraphFunction>>, 8>
376 _traced_functions_nl_sys;
377
379 bool _currently_assembling_nonlinear_system = false;
380};
381
/// Stream a Model to `os`.
/// NOTE(review): only declared here; the output format is defined outside this header listing.
382std::ostream & operator<<(std::ostream & os, const Model & model);
383} // namespace neml2
Data(const OptionSet &options)
Construct a new Data object.
The base class for all constitutive models.
Definition Model.h:82
void request_AD(VariableBase &y, const VariableBase &u)
Request to use AD to compute the first derivative of a variable.
virtual std::string failed_graph_execution_hint() const
Additional hint to include in the error message when an exception is encountered during execution of ...
friend class ComposedModel
ComposedModel's set_value needs to call the submodels' set_value.
Definition Model.h:199
void register_model(const std::shared_ptr< T > &model, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:289
virtual void link_output_variables()
void clear_input() override
std::vector< std::shared_ptr< Model > > _registered_models
Models this model may use during its evaluation.
Definition Model.h:308
virtual void link_output_variables(Model *submodel)
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
void check_precision() const
Check the current default precision and warn if it's not double precision.
void zero_undefined_input() override
Fill undefined input variables with zeros.
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:116
friend class ModelNonlinearSystem
ModelNonlinearSystem needs access to some setup methods.
Definition Model.h:202
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameter.
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:113
const std::vector< std::shared_ptr< Model > > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:128
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
friend class ImplicitUpdate
ImplicitUpdate sets the nl_sys derivative filter on its sub-model at construction time.
Definition Model.h:205
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
friend class ParameterStore
Declaration of nonlinear parameters may require manipulation of input.
Definition Model.h:196
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:119
virtual void link_input_variables()
void request_AD(VariableBase &y, const VariableBase &u1, const VariableBase &u2)
Request to use AD to compute the second derivative of a variable.
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
void set_output_derivative_filter_nl_sys(const std::vector< std::pair< VariableName, VariableName > > &derivs)
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
std::string variable_name_lookup(const ATensor &var) const
Look up the name of a variable in the traced graph.
std::shared_ptr< Model > registered_model(const std::string &name) const
Get a registered model by its name.
jit::Stack collect_input_stack() const
void setup() override
Setup this object.
void diagnose() const override
Check for common problems.
void clear_output() override
void set_output_derivative_filter(const std::vector< std::pair< VariableName, VariableName > > &derivs)
static OptionSet expected_options()
void assign_input_stack(jit::Stack &stack)
virtual void request_AD()
Definition Model.h:237
Model(const OptionSet &options)
Construct a new Model object.
virtual void link_input_variables(Model *submodel)
T & register_model(const std::string &name, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:257
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:122
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:52
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:84
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:192
Factory * factory() const
Get the factory that created this object.
Definition NEML2Object.h:93
const OptionSet & input_options() const
Definition NEML2Object.h:70
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:54
T get(const std::string &) const
Definition OptionSet.h:381
void add_private(const std::string &name, const T &default_value)
Create a private option with its default value.
Definition OptionSet.h:488
bool contains(const std::string &) const
Definition errors.h:50
Base class of variable.
Definition VariableBase.h:58
void assign_input(const ValueMap &, bool allow_nonexistent=false)
VariableStore(Model *object)
VariableStorage & input_variables()
Definition VariableStore.h:63
const VariableBase * clone_input_variable(const VariableBase &var, std::optional< VariableName > new_name=std::nullopt)
Clone a variable and put it on the input axis.
Definition DiagnosticsInterface.h:31
std::map< VariableName, ValueMap > DerivMap
Definition Tensor.h:40
std::ostream & operator<<(std::ostream &, const EnumSelection &)
at::Tensor ATensor
Definition types.h:42
std::map< VariableName, Tensor > ValueMap
Definition Tensor.h:39
bool & currently_assembling_nonlinear_system()
c10::TensorOptions TensorOptions
Definition types.h:66
AssemblyingNonlinearSystem & operator=(const AssemblyingNonlinearSystem &)=delete
AssemblyingNonlinearSystem(AssemblyingNonlinearSystem &&)=delete
AssemblyingNonlinearSystem & operator=(AssemblyingNonlinearSystem &&)=delete
AssemblyingNonlinearSystem(const AssemblyingNonlinearSystem &)=delete
AssemblyingNonlinearSystem(bool assembling=true)
const bool prev_bool
Definition Model.h:64
Schema for the traced forward operators.
Definition Model.h:90
bool operator==(const EvaluationSchema &other) const
at::DispatchKey dispatch_key
Definition Model.h:93
bool operator!=(const EvaluationSchema &other) const
std::vector< Size > dynamic_dims
Definition Model.h:91
bool operator<(const EvaluationSchema &other) const
std::vector< TensorShape > intmd_shapes
Definition Model.h:92
Nonlinear parameter.
Definition NonlinearParameter.h:51