27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/models/LabeledAxis.h"
30#include "neml2/models/Variable.h"
31#include "neml2/models/map_types.h"
32#include "neml2/tensors/tensors.h"
84 const torch::TensorOptions &
tensor_options()
const {
return _tensor_options; }
118 template <
typename T,
typename S>
122 if constexpr (!std::is_same_v<T, Tensor>)
124 "Creating a Variable of primitive tensor type does not require a base shape.");
126 const auto var_name = variable_name(std::forward<S>(
name));
130 const auto sz = list_sz * base_sz;
133 return *create_variable<T>(_input_variables, var_name, list_shape, base_shape);
137 template <
typename T,
typename S>
141 if constexpr (!std::is_same_v<T, Tensor>)
143 "Creating a Variable of primitive tensor type does not require a base shape.");
145 const auto var_name = variable_name(std::forward<S>(
name));
149 const auto sz = list_sz * base_sz;
152 return *create_variable<T>(_output_variables, var_name, list_shape, base_shape);
159 neml_assert(&
var.owner() != _object,
"Trying to clone a variable from the same model.");
163 !_input_variables.query_value(
var_name),
"Input variable ",
var_name,
" already exists.");
173 neml_assert(&
var.owner() != _object,
"Trying to clone a variable from the same model.");
177 !_output_variables.query_value(
var_name),
"Output variable ",
var_name,
" already exists.");
186 template <
typename S>
189 if constexpr (std::is_convertible_v<S, std::string>)
197 template <
typename T>
198 Variable<T> * create_variable(Storage<VariableName, VariableBase> & variables,
204 VariableBase * var_base_ptr = variables.query_value(
name);
206 "Trying to create variable ",
208 ", but a variable with the same name already exists.");
211 if constexpr (std::is_same_v<T, Tensor>)
213 auto var = std::make_unique<Variable<Tensor>>(
name, _object, list_shape, base_shape);
214 var_base_ptr = variables.set_pointer(
name, std::move(var));
218 auto var = std::make_unique<Variable<T>>(
name, _object, list_shape);
219 var_base_ptr = variables.set_pointer(
name, std::move(var));
223 auto var_ptr =
dynamic_cast<Variable<T> *
>(var_base_ptr);
225 var_ptr,
"Internal error: Failed to cast variable ",
name,
" to its concrete type.");
239 const OptionSet _object_options;
242 Storage<std::string, LabeledAxis> _axes;
245 LabeledAxis & _input_axis;
248 LabeledAxis & _output_axis;
251 Storage<VariableName, VariableBase> _input_variables;
254 Storage<VariableName, VariableBase> _output_variables;
257 torch::TensorOptions _tensor_options;
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:58
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:47
void add_variable(const LabeledAxisAccessor &name, Size sz)
Add a variable with known storage size.
Definition LabeledAxis.cxx:54
The base class for all constitutive models.
Definition Model.h:51
A custom map-like data structure. The keys are strings, and the values can be heterogeneously typed.
Definition OptionSet.h:85
Base class of all variables.
Definition Variable.h:47
Definition VariableStore.h:40
void assign_output_derivatives(const DerivMap &derivs)
Assign variable derivatives.
Definition VariableStore.cxx:138
VariableBase & output_variable(const VariableName &)
Definition VariableStore.cxx:75
Variable< T > & declare_output_variable(S &&name, TensorShapeRef list_shape={}, TensorShapeRef base_shape={})
Declare an output variable.
Definition VariableStore.h:139
virtual void zero_input()
Definition VariableStore.cxx:107
VariableBase * clone_output_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the output axis.
Definition VariableStore.h:171
ValueMap collect_output() const
Definition VariableStore.cxx:154
void assign_output(const ValueMap &vals)
Definition VariableStore.cxx:131
virtual ~VariableStore()=default
LabeledAxis & output_axis()
Definition VariableStore.h:63
SecDerivMap collect_output_second_derivatives() const
Collect variable second derivatives.
Definition VariableStore.cxx:172
Storage< VariableName, VariableBase > & input_variables()
Definition VariableStore.h:69
virtual void clear_input()
Definition VariableStore.cxx:91
Storage< VariableName, VariableBase > & output_variables()
Definition VariableStore.h:71
VariableStore(OptionSet options, Model *object)
Definition VariableStore.cxx:30
const Storage< VariableName, VariableBase > & input_variables() const
Definition VariableStore.h:70
const Storage< VariableName, VariableBase > & output_variables() const
Definition VariableStore.h:72
ValueMap collect_input() const
Definition VariableStore.cxx:145
VariableBase & input_variable(const VariableName &)
Definition VariableStore.cxx:59
VariableStore(const VariableStore &)=delete
virtual void zero_output()
Definition VariableStore.cxx:115
const torch::TensorOptions & tensor_options() const
Current tensor options.
Definition VariableStore.h:84
DerivMap collect_output_derivatives() const
Collect variable derivatives.
Definition VariableStore.cxx:163
const LabeledAxis & input_axis() const
Definition VariableStore.h:58
virtual void clear_output()
Definition VariableStore.cxx:99
virtual void setup_layout()
Setup the layout of all the registered axes.
Definition VariableStore.cxx:52
const Variable< T > & declare_input_variable(S &&name, TensorShapeRef list_shape={}, TensorShapeRef base_shape={})
Declare an input variable.
Definition VariableStore.h:120
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:40
VariableStore & operator=(const VariableStore &)=delete
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:123
VariableStore & operator=(VariableStore &&)=delete
const LabeledAxis & output_axis() const
Definition VariableStore.h:64
VariableStore(VariableStore &&)=delete
LabeledAxis & input_axis()
Definition VariableStore.h:57
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.h:156
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types.h:35
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types.h:36
torch::IntArrayRef TensorShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64