27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/models/Variable.h"
34#include "neml2/solvers/NonlinearSystem.h"
35#include "neml2/models/NonlinearParameter.h"
40#include "neml2/base/LabeledAxis.h"
48 std::function<void(
const Model &,
49 const std::map<
VariableName, std::unique_ptr<VariableBase>> &,
50 const std::map<
VariableName, std::unique_ptr<VariableBase>> &)>;
58std::shared_ptr<Model>
load_model(
const std::filesystem::path & path,
const std::string & mname);
67class Model :
public std::enable_shared_from_this<Model>,
98 void setup()
override;
143 virtual std::map<std::string, NonlinearParameter>
164 void forward(
bool out,
bool dout,
bool d2out);
192 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
251 virtual void set_value(
bool out,
bool dout_din,
bool d2out_din2) = 0;
265 template <
typename T = Model,
typename =
typename std::enable_if_t<std::is_base_of_v<Model, T>>>
270 if (model_name == this->
name())
272 "' is trying to register itself as a sub-model. This is not allowed.");
276 extra_opts.
set<
bool>(
"_nonlinear_system") = nonlinear;
280 "' does not have a factory set.");
281 auto model =
host()->factory()->get_object<T>(
"Models", model_name, extra_opts);
284 throw SetupException(
"Model named '" + model_name +
"' has already been registered.");
287 for (
auto && [
name, var] : model->input_variables())
298 void set_guess(
const Sol<false> &)
override;
300 void assemble(Res<false> *, Jac<false> *)
override;
307 void call_callbacks()
const;
309 template <
typename T>
310 void forward_helper(T && in,
bool out,
bool dout,
bool d2out)
321 bool AD_need_value(
bool dout,
bool d2out)
const;
327 void extract_AD_derivatives(
bool dout,
bool d2out);
330 std::size_t forward_operator_index(
bool out,
bool dout,
bool d2out)
const;
333 TraceSchema compute_trace_schema()
const;
338 bool _defines_dvalue;
339 bool _defines_d2value;
343 bool _nonlinear_system;
346 std::map<std::string, NonlinearParameter> _nl_params;
350 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
351 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
353 std::set<VariableBase *> _ad_args;
360 const bool _production;
378 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
381 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
382 _traced_functions_nl_sys;
385 std::vector<ModelCallback> _callbacks;
388std::ostream &
operator<<(std::ostream & os,
const Model & model);
Data(const OptionSet &options)
Construct a new Data object.
Definition Data.cxx:38
DependencyDefinition()=default
DiagnosticsInterface()=delete
The base class for all constitutive models.
Definition Model.h:74
friend class ComposedModel
ComposedModel's set_value need to call submodel's set_value.
Definition Model.h:208
void clear_input() override
Definition Model.cxx:312
std::vector< std::shared_ptr< Model > > _registered_models
Models this model may use during its evaluation.
Definition Model.h:303
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:112
void check_precision() const
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:496
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:193
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:531
virtual void link_input_variables()
Definition Model.cxx:252
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter ¶m)
Register a nonlinear parameter.
Definition Model.cxx:660
void register_callback(const ModelCallback &callback)
Register a callback to be called when the model is evaluated.
Definition Model.cxx:364
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:106
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameter.
Definition Model.cxx:670
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:235
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:103
const std::vector< std::shared_ptr< Model > > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:121
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:411
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:627
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:716
friend class ParameterStore
Declaration of nonlinear parameters may require manipulation of input.
Definition Model.h:205
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:109
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:180
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:761
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:379
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivative.
Definition Model.cxx:577
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:509
std::string variable_name_lookup(const ATensor &var)
Look up the name of a variable in the traced graph.
Definition Model.cxx:466
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:266
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:689
void zero_input() override
Definition Model.cxx:328
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:213
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:112
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivative.
Definition Model.cxx:603
std::shared_ptr< Model > registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:649
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:186
virtual void link_output_variables()
Definition Model.cxx:269
jit::Stack collect_input_stack() const
Definition Model.cxx:741
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:199
void register_callback_recursive(const ModelCallback &callback)
Register a callback on this and all submodels.
Definition Model.cxx:370
void setup() override
Setup this object.
Definition Model.cxx:126
void diagnose() const override
Check for common problems.
Definition Model.cxx:140
void clear_output() override
Definition Model.cxx:320
static OptionSet expected_options()
Definition Model.cxx:62
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:555
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:723
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:206
virtual void request_AD()
Definition Model.h:248
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:96
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:754
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:115
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:683
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:156
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:709
void zero_output() override
Definition Model.cxx:336
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:51
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:83
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:150
Factory * factory() const
Get the factory that created this object.
Definition NEML2Object.h:92
const OptionSet & input_options() const
Definition NEML2Object.h:69
NonlinearSystem(const NonlinearSystem &)=default
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:51
T get(const std::string &) const
Definition OptionSet.h:242
bool contains(const std::string &) const
Definition OptionSet.cxx:47
T & set(const std::string &)
Definition OptionSet.h:254
Base class of variable.
Definition Variable.h:52
VariableStore(Model *object)
Definition VariableStore.cxx:36
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:257
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:120
Definition DiagnosticsInterface.cxx:29
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
at::Tensor ATensor
Definition types.h:38
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
std::shared_ptr< Model > load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:41
c10::TensorOptions TensorOptions
Definition types.h:60
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
std::function< void(const Model &, const std::map< VariableName, std::unique_ptr< VariableBase > > &, const std::map< VariableName, std::unique_ptr< VariableBase > > &)> ModelCallback
typedef giving the call signature for a model callback
Definition Model.h:47
Schema for the traced forward operators.
Definition Model.h:82
bool operator==(const TraceSchema &other) const
Definition Model.cxx:48
std::vector< Size > batch_dims
Definition Model.h:83
at::DispatchKey dispatch_key
Definition Model.h:84
bool operator<(const TraceSchema &other) const
Definition Model.cxx:54
Nonlinear parameter.
Definition NonlinearParameter.h:51