27#include "neml2/base/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/solvers/NonlinearSystem.h"
44class Model :
public std::enable_shared_from_this<Model>,
63 virtual void to(
const torch::TensorOptions & options);
65 void diagnose(std::vector<Diagnosis> &)
const override;
114 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
130 void setup()
override;
169 template <
typename T,
typename =
typename std::enable_if_t<std::is_base_of_v<Model, T>>>
175 "' is trying to register itself as a sub-model. This is not allowed.");
181 auto model = Factory::get_object_ptr<Model>(
"Models",
name,
extra_opts);
184 for (
auto && [
name,
var] : model->input_variables())
188 return *(std::dynamic_pointer_cast<T>(model));
201 bool AD_need_value(
bool dout,
bool d2out)
const;
207 void extract_AD_derivatives(
bool dout,
bool d2out);
210 bool _nonlinear_system;
213 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
214 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
216 std::set<VariableBase *> _ad_args;
Definition ComposedModel.h:35
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Definition DependencyDefinition.h:40
Interface for object making diagnostics about common setup errors.
Definition DiagnosticsInterface.h:47
The base class for all constitutive models.
Definition Model.h:51
void clear_input() override
Definition Model.cxx:171
virtual void dvalue()
Evalute the derivative.
Definition Model.cxx:277
void diagnose(std::vector< Diagnosis > &) const override
Check for common problems.
Definition Model.cxx:68
virtual void dvalue_and_d2value()
Evalute the first and second derivatives.
Definition Model.cxx:293
virtual void value_and_dvalue_and_d2value()
Evalute the model and compute its first and second derivatives.
Definition Model.cxx:285
virtual void link_input_variables()
Definition Model.cxx:123
const std::vector< Model * > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:71
std::vector< Model * > _registered_models
Models this model may use during its evaluation.
Definition Model.h:196
virtual void d2value()
Evalute the second derivatives.
Definition Model.cxx:301
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:327
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:341
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:170
void zero_input() override
Definition Model.cxx:187
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:68
virtual void link_output_variables()
Definition Model.cxx:140
virtual void to(const torch::TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:58
void setup() override
Setup this object.
Definition Model.cxx:109
virtual void value()
Evalute the model.
Definition Model.cxx:263
void clear_output() override
Definition Model.cxx:179
static OptionSet expected_options()
Definition Model.cxx:33
virtual void value_and_dvalue()
Evalute the model and compute its derivative.
Definition Model.cxx:269
void diagnose_nl_sys(std::vector< Diagnosis > &diagnoses) const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:84
virtual void request_AD()
Definition Model.h:149
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:47
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:334
Model * registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:309
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:320
void zero_output() override
Definition Model.cxx:195
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:70
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:95
Definition of a nonlinear system of equations.
Definition NonlinearSystem.h:37
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:85
Interface for object which can store parameters.
Definition ParameterStore.h:46
Base class of variable.
Definition Variable.h:47
Definition VariableStore.h:40
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.h:156
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types.h:35
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types.h:36
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64