27#include "neml2/models/DependencyDefinition.h"
28#include "neml2/base/DiagnosticsInterface.h"
30#include "neml2/models/Data.h"
31#include "neml2/models/ParameterStore.h"
32#include "neml2/models/VariableStore.h"
33#include "neml2/solvers/NonlinearSystem.h"
34#include "neml2/models/NonlinearParameter.h"
39#include "neml2/base/TensorName.h"
40#include "neml2/base/LabeledAxis.h"
41#include "neml2/tensors/TensorValue.h"
42#include "neml2/models/Variable.h"
66Model &
load_model(
const std::filesystem::path & path,
const std::string & mname);
78Model &
reload_model(
const std::filesystem::path & path,
const std::string & mname);
90class Model :
public std::enable_shared_from_this<Model>,
121 void setup()
override;
163 virtual std::map<std::string, NonlinearParameter>
178 void forward(
bool out,
bool dout,
bool d2out);
206 virtual std::tuple<ValueMap, DerivMap, SecDerivMap>
262 virtual void set_value(
bool out,
bool dout_din,
bool d2out_din2) = 0;
276 template <
typename T = Model,
typename =
typename std::enable_if_t<std::is_base_of_v<Model, T>>>
281 "' is trying to register itself as a sub-model. This is not allowed.");
285 extra_opts.
set<
bool>(
"_nonlinear_system") = nonlinear;
293 for (
auto && [
name, var] : model->input_variables())
304 void set_guess(
const Sol<false> &)
override;
306 void assemble(Res<false> *, Jac<false> *)
override;
312 template <
typename T>
313 void forward_helper(T && in,
bool out,
bool dout,
bool d2out)
324 bool AD_need_value(
bool dout,
bool d2out)
const;
330 void extract_AD_derivatives(
bool dout,
bool d2out);
333 std::size_t forward_operator_index(
bool out,
bool dout,
bool d2out)
const;
336 TraceSchema compute_trace_schema()
const;
341 bool _defines_dvalue;
342 bool _defines_d2value;
346 bool _nonlinear_system;
349 std::map<std::string, NonlinearParameter> _nl_params;
353 std::map<VariableBase *, std::set<const VariableBase *>> _ad_derivs;
354 std::map<VariableBase *, std::map<const VariableBase *, std::set<const VariableBase *>>>
356 std::set<VariableBase *> _ad_args;
363 const bool _production;
381 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8> _traced_functions;
384 std::array<std::map<TraceSchema, std::unique_ptr<jit::GraphFunction>>, 8>
385 _traced_functions_nl_sys;
90class Model :
public std::enable_shared_from_this<Model>, {
…};
388std::ostream &
operator<<(std::ostream & os,
const Model & model);
Data(const OptionSet &options)
Construct a new Data object.
Definition Data.cxx:38
DependencyDefinition()=default
DiagnosticsInterface()=delete
static std::shared_ptr< T > get_object_ptr(const std::string &section, const std::string &name, const OptionSet &additional_options=OptionSet(), bool force_create=true)
Retrieve an object pointer under the given section with the given object name.
Definition Factory.h:165
The base class for all constitutive models.
Definition Model.h:97
friend class ComposedModel
ComposedModel's set_value needs to call a submodel's set_value.
Definition Model.h:222
void clear_input() override
Definition Model.cxx:326
virtual void to(const TensorOptions &options)
Send model to a different device or dtype.
Definition Model.cxx:126
void diagnostic_assert_force(const VariableBase &v) const
Definition Model.cxx:207
virtual std::tuple< ValueMap, DerivMap > value_and_dvalue(const ValueMap &in)
Convenient shortcut to construct and return the model value and its derivative.
Definition Model.cxx:527
virtual void link_input_variables()
Definition Model.cxx:266
const std::vector< Model * > & registered_models() const
The models that may be used during the evaluation of this model.
Definition Model.h:144
std::vector< Model * > _registered_models
Models this model may use during its evaluation.
Definition Model.h:309
void register_nonlinear_parameter(const std::string &pname, const NonlinearParameter &param)
Register a nonlinear parameter.
Definition Model.cxx:656
virtual bool defines_derivatives() const
Whether this model defines first derivatives.
Definition Model.h:129
bool has_nl_param(bool recursive=false) const
Whether this parameter store has any nonlinear parameter.
Definition Model.cxx:666
void diagnostic_check_output_variable(const VariableBase &v) const
Definition Model.cxx:249
virtual bool defines_values() const
Whether this model defines output values.
Definition Model.h:126
void forward_maybe_jit(bool out, bool dout, bool d2out)
Forward operator with jit.
Definition Model.cxx:407
virtual SecDerivMap d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's second derivative.
Definition Model.cxx:623
virtual void set_value(bool out, bool dout_din, bool d2out_din2)=0
The map between input -> output, and optionally its derivatives.
std::set< VariableName > provided_items() const override
The variables that this model defines as part of its output.
Definition Model.cxx:711
friend class ParameterStore
Declaration of nonlinear parameters may require manipulation of input.
Definition Model.h:219
virtual bool defines_second_derivatives() const
Whether this model defines second derivatives.
Definition Model.h:132
void diagnostic_assert_state(const VariableBase &v) const
Definition Model.cxx:194
void assemble(Res< false > *, Jac< false > *) override
Compute the unscaled residual and Jacobian.
Definition Model.cxx:756
void forward(bool out, bool dout, bool d2out)
Forward operator without jit.
Definition Model.cxx:378
virtual std::tuple< ValueMap, DerivMap, SecDerivMap > value_and_dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's value, first and second derivative.
Definition Model.cxx:573
virtual ValueMap value(const ValueMap &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:505
std::string variable_name_lookup(const ATensor &var)
Look up the name of a variable in the traced graph.
Definition Model.cxx:462
T & register_model(const std::string &name, bool nonlinear=false, bool merge_input=true)
Register a model that the current model may use during its evaluation.
Definition Model.h:277
virtual std::map< std::string, NonlinearParameter > named_nonlinear_parameters(bool recursive=false) const
Get all nonlinear parameters.
Definition Model.cxx:685
void zero_input() override
Definition Model.cxx:342
void diagnostic_check_input_variable(const VariableBase &v) const
Definition Model.cxx:227
virtual bool is_nonlinear_system() const
Whether this model defines one or more nonlinear equations to be solved.
Definition Model.h:135
virtual std::tuple< DerivMap, SecDerivMap > dvalue_and_d2value(const ValueMap &in)
Convenient shortcut to construct and return the model's first and second derivative.
Definition Model.cxx:599
void diagnostic_assert_old_state(const VariableBase &v) const
Definition Model.cxx:200
virtual void link_output_variables()
Definition Model.cxx:283
jit::Stack collect_input_stack() const
Definition Model.cxx:736
void diagnostic_assert_old_force(const VariableBase &v) const
Definition Model.cxx:213
void setup() override
Setup this object.
Definition Model.cxx:140
void diagnose() const override
Check for common problems.
Definition Model.cxx:154
void clear_output() override
Definition Model.cxx:334
static OptionSet expected_options()
Definition Model.cxx:76
virtual DerivMap dvalue(const ValueMap &in)
Convenient shortcut to construct and return the derivative.
Definition Model.cxx:551
void assign_input_stack(jit::Stack &stack)
Definition Model.cxx:718
void diagnostic_assert_residual(const VariableBase &v) const
Definition Model.cxx:220
virtual void request_AD()
Definition Model.h:259
Model(const OptionSet &options)
Construct a new Model object.
Definition Model.cxx:110
void set_guess(const Sol< false > &) override
Set the unscaled current guess.
Definition Model.cxx:749
Model * registered_model(const std::string &name) const
Get a registered model by its name.
Definition Model.cxx:645
virtual bool is_jit_enabled() const
Whether JIT is enabled.
Definition Model.h:138
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition Model.cxx:679
void diagnose_nl_sys() const
Additional diagnostics for a nonlinear system.
Definition Model.cxx:170
std::set< VariableName > consumed_items() const override
The variables that this model depends on.
Definition Model.cxx:704
void zero_output() override
Definition Model.cxx:350
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:42
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:74
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:99
NonlinearSystem(const NonlinearSystem &)=default
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:52
T & set(const std::string &)
Definition OptionSet.h:273
Base class of variable.
Definition Variable.h:52
VariableStore(Model *object)
Definition VariableStore.cxx:36
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:257
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.cxx:120
Definition DiagnosticsInterface.cxx:30
Model & load_model(const std::filesystem::path &path, const std::string &mname)
A convenient function to load an input file and get a model.
Definition Model.cxx:48
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
Model & get_model(const std::string &mname)
A convenient function to manufacture a neml2::Model.
Definition Model.cxx:41
at::Tensor ATensor
Definition types.h:42
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types_fwd.h:35
Model & reload_model(const std::filesystem::path &path, const std::string &mname)
Similar to neml2::load_model, but additionally clear the Factory before loading the model,...
Definition Model.cxx:55
c10::TensorOptions TensorOptions
Definition types.h:63
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:32
void check_precision()
Check the current default precision and warn if it's not double precision.
Definition Model.cxx:492
Schema for the traced forward operators.
Definition Model.h:105
bool operator==(const TraceSchema &other) const
Definition Model.cxx:62
std::vector< Size > batch_dims
Definition Model.h:106
at::DispatchKey dispatch_key
Definition Model.h:107
bool operator<(const TraceSchema &other) const
Definition Model.cxx:68
Nonlinear parameter.
Definition NonlinearParameter.h:49