NEML2 2.0.0
Loading...
Searching...
No Matches
VariableStore.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/models/LabeledAxis.h"
30#include "neml2/models/Variable.h"
31#include "neml2/models/map_types.h"
32#include "neml2/tensors/tensors.h"
33
34namespace neml2
35{
36// Forward declaration
37class Model;
38
40{
41public:
/// Construct a variable store for the model `object`, configured by `options`.
/// NOTE(review): defined out of line (VariableStore.cxx:30); how `options` relates to the
/// `_object_options` member is not visible here -- confirm.
42 VariableStore(OptionSet options, Model * object);
43
/// Copy construction is deleted. The cross-referenced declarations of this class indicate
/// that move construction and copy/move assignment are deleted as well.
44 VariableStore(const VariableStore &) = delete;
/// Defaulted virtual destructor (the class declares virtual member functions).
48 virtual ~VariableStore() = default;
49
/// Declare an axis called `name` and return a reference to it.
/// NOTE(review): presumably registered in `_axes`; behavior on a duplicate name is defined
/// in VariableStore.cxx:40 -- confirm.
50 LabeledAxis & declare_axis(const std::string & name);
51
/// Set up the layout of all the registered axes (virtual, so derived classes may override).
53 virtual void setup_layout();
54
/// The input axis (mutable and const access to `_input_axis`).
57 LabeledAxis & input_axis() { return _input_axis; }
58 const LabeledAxis & input_axis() const { return _input_axis; }
60
/// The output axis (mutable and const access to `_output_axis`).
63 LabeledAxis & output_axis() { return _output_axis; }
64 const LabeledAxis & output_axis() const { return _output_axis; }
66
/// Read-only access to all input variables, keyed by VariableName.
70 const Storage<VariableName, VariableBase> & input_variables() const { return _input_variables; }
/// All output variables, keyed by VariableName (mutable and const access).
71 Storage<VariableName, VariableBase> & output_variables() { return _output_variables; }
72 const Storage<VariableName, VariableBase> & output_variables() const { return _output_variables; }
74
/// Look up a single input / output variable by name.
/// NOTE(review): defined out of line; behavior for an unknown name is not visible here
/// (presumably an assertion failure) -- confirm.
78 const VariableBase & input_variable(const VariableName &) const;
80 const VariableBase & output_variable(const VariableName &) const;
82
/// Current tensor options (returns `_tensor_options`).
84 const torch::TensorOptions & tensor_options() const { return _tensor_options; }
85
/// Clear the input / output variables.
/// NOTE(review): virtual, defined in VariableStore.cxx; exactly what is cleared is not
/// visible here -- confirm.
88 virtual void clear_input();
89 virtual void clear_output();
91
/// Zero the input / output variables.
/// NOTE(review): virtual, defined in VariableStore.cxx; presumably sets the variable values
/// to zero -- confirm.
94 virtual void zero_input();
95 virtual void zero_output();
97
/// Assign values to the input / output variables from `vals`, a map from variable name
/// (LabeledAxisAccessor) to Tensor.
/// NOTE(review): handling of names not present in this store is not visible here -- confirm.
100 void assign_input(const ValueMap & vals);
101 void assign_output(const ValueMap & vals);
105
/// Collect the input / output variables into a ValueMap (variable name -> Tensor).
/// NOTE(review): presumably the variables' current values -- confirm in VariableStore.cxx.
108 ValueMap collect_input() const;
109 ValueMap collect_output() const;
115
116protected:
118 template <typename T, typename S>
119 const Variable<T> &
121 {
122 if constexpr (!std::is_same_v<T, Tensor>)
123 neml_assert(base_shape.empty(),
124 "Creating a Variable of primitive tensor type does not require a base shape.");
125
126 const auto var_name = variable_name(std::forward<S>(name));
127 const auto list_sz = utils::storage_size(list_shape);
128 const auto base_sz =
129 std::is_same_v<T, Tensor> ? utils::storage_size(base_shape) : T::const_base_storage;
130 const auto sz = list_sz * base_sz;
131
132 _input_axis.add_variable(var_name, sz);
133 return *create_variable<T>(_input_variables, var_name, list_shape, base_shape);
134 }
135
137 template <typename T, typename S>
138 Variable<T> &
140 {
141 if constexpr (!std::is_same_v<T, Tensor>)
142 neml_assert(base_shape.empty(),
143 "Creating a Variable of primitive tensor type does not require a base shape.");
144
145 const auto var_name = variable_name(std::forward<S>(name));
146 const auto list_sz = utils::storage_size(list_shape);
147 const auto base_sz =
148 std::is_same_v<T, Tensor> ? utils::storage_size(base_shape) : T::const_base_storage;
149 const auto sz = list_sz * base_sz;
150
151 _output_axis.add_variable(var_name, sz);
152 return *create_variable<T>(_output_variables, var_name, list_shape, base_shape);
153 }
154
157 const VariableName & new_name = {})
158 {
159 neml_assert(&var.owner() != _object, "Trying to clone a variable from the same model.");
160
161 const auto var_name = new_name.empty() ? var.name() : new_name;
163 !_input_variables.query_value(var_name), "Input variable ", var_name, " already exists.");
164 auto var_clone = var.clone(var_name, _object);
165
166 _input_axis.add_variable(var_name, var_clone->assembly_storage());
167 return _input_variables.set_pointer(var_name, std::move(var_clone));
168 }
169
172 {
173 neml_assert(&var.owner() != _object, "Trying to clone a variable from the same model.");
174
175 const auto var_name = new_name.empty() ? var.name() : new_name;
177 !_output_variables.query_value(var_name), "Output variable ", var_name, " already exists.");
178 auto var_clone = var.clone(var_name, _object);
179
180 _output_axis.add_variable(var_name, var_clone->assembly_storage());
181 return _output_variables.set_pointer(var_name, std::move(var_clone));
182 }
183
184private:
185 // Helper method to construct variable name
186 template <typename S>
187 VariableName variable_name(S && name) const
188 {
189 if constexpr (std::is_convertible_v<S, std::string>)
190 if (_object_options.contains<VariableName>(name))
191 return _object_options.get<VariableName>(name);
192
193 return name;
194 }
195
196 // Create a variable
197 template <typename T>
198 Variable<T> * create_variable(Storage<VariableName, VariableBase> & variables,
199 const VariableName & name,
200 TensorShapeRef list_shape,
201 TensorShapeRef base_shape)
202 {
203 // Make sure we don't duplicate variables
204 VariableBase * var_base_ptr = variables.query_value(name);
205 neml_assert(!var_base_ptr,
206 "Trying to create variable ",
207 name,
208 ", but a variable with the same name already exists.");
209
210 // Allocate
211 if constexpr (std::is_same_v<T, Tensor>)
212 {
213 auto var = std::make_unique<Variable<Tensor>>(name, _object, list_shape, base_shape);
214 var_base_ptr = variables.set_pointer(name, std::move(var));
215 }
216 else
217 {
218 auto var = std::make_unique<Variable<T>>(name, _object, list_shape);
219 var_base_ptr = variables.set_pointer(name, std::move(var));
220 }
221
222 // Cast it to the concrete type
223 auto var_ptr = dynamic_cast<Variable<T> *>(var_base_ptr);
225 var_ptr, "Internal error: Failed to cast variable ", name, " to its concrete type.");
226
227 return var_ptr;
228 }
229
/// Model passed as the owner to every variable this store creates or clones.
/// NOTE(review): raw pointer; presumably non-owning -- confirm.
231 Model * _object;
232
/// Options consulted by `variable_name` to resolve a string-like name to a VariableName.
/// NOTE(review): presumably the owning object's options -- confirm.
239 const OptionSet _object_options;
240
/// Declared axes, keyed by name. NOTE(review): presumably populated by declare_axis.
242 Storage<std::string, LabeledAxis> _axes;
243
/// Axis on which input variables are registered (declare_input_variable and
/// clone_input_variable call add_variable on it). NOTE(review): presumably refers into `_axes`.
245 LabeledAxis & _input_axis;
246
/// Axis on which output variables are registered (declare_output_variable and
/// clone_output_variable call add_variable on it). NOTE(review): presumably refers into `_axes`.
248 LabeledAxis & _output_axis;
249
/// Input variables keyed by name; filled by create_variable and clone_input_variable.
251 Storage<VariableName, VariableBase> _input_variables;
252
/// Output variables keyed by name; filled by create_variable and clone_output_variable.
254 Storage<VariableName, VariableBase> _output_variables;
255
/// Tensor options exposed through tensor_options().
257 torch::TensorOptions _tensor_options;
258};
259} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:58
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:47
void add_variable(const LabeledAxisAccessor &name, Size sz)
Add a variable with known storage size.
Definition LabeledAxis.cxx:54
The base class for all constitutive models.
Definition Model.h:51
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:85
Base class of all variables.
Definition Variable.h:47
Definition VariableStore.h:40
void assign_output_derivatives(const DerivMap &derivs)
Assign variable derivatives.
Definition VariableStore.cxx:138
VariableBase & output_variable(const VariableName &)
Definition VariableStore.cxx:75
Variable< T > & declare_output_variable(S &&name, TensorShapeRef list_shape={}, TensorShapeRef base_shape={})
Declare an output variable.
Definition VariableStore.h:139
virtual void zero_input()
Definition VariableStore.cxx:107
VariableBase * clone_output_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the output axis.
Definition VariableStore.h:171
ValueMap collect_output() const
Definition VariableStore.cxx:154
void assign_output(const ValueMap &vals)
Definition VariableStore.cxx:131
virtual ~VariableStore()=default
LabeledAxis & output_axis()
Definition VariableStore.h:63
SecDerivMap collect_output_second_derivatives() const
Collect variable second derivatives.
Definition VariableStore.cxx:172
Storage< VariableName, VariableBase > & input_variables()
Definition VariableStore.h:69
virtual void clear_input()
Definition VariableStore.cxx:91
Storage< VariableName, VariableBase > & output_variables()
Definition VariableStore.h:71
VariableStore(OptionSet options, Model *object)
Definition VariableStore.cxx:30
const Storage< VariableName, VariableBase > & input_variables() const
Definition VariableStore.h:70
const Storage< VariableName, VariableBase > & output_variables() const
Definition VariableStore.h:72
ValueMap collect_input() const
Definition VariableStore.cxx:145
VariableBase & input_variable(const VariableName &)
Definition VariableStore.cxx:59
VariableStore(const VariableStore &)=delete
virtual void zero_output()
Definition VariableStore.cxx:115
const torch::TensorOptions & tensor_options() const
Current tensor options.
Definition VariableStore.h:84
DerivMap collect_output_derivatives() const
Collect variable derivatives.
Definition VariableStore.cxx:163
const LabeledAxis & input_axis() const
Definition VariableStore.h:58
virtual void clear_output()
Definition VariableStore.cxx:99
virtual void setup_layout()
Set up the layout of all the registered axes.
Definition VariableStore.cxx:52
const Variable< T > & declare_input_variable(S &&name, TensorShapeRef list_shape={}, TensorShapeRef base_shape={})
Declare an input variable.
Definition VariableStore.h:120
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:40
VariableStore & operator=(const VariableStore &)=delete
void assign_input(const ValueMap &vals)
Definition VariableStore.cxx:123
VariableStore & operator=(VariableStore &&)=delete
const LabeledAxis & output_axis() const
Definition VariableStore.h:64
VariableStore(VariableStore &&)=delete
LabeledAxis & input_axis()
Definition VariableStore.h:57
const VariableBase * clone_input_variable(const VariableBase &var, const VariableName &new_name={})
Clone a variable and put it on the input axis.
Definition VariableStore.h:156
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types.h:35
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
std::map< LabeledAxisAccessor, DerivMap > SecDerivMap
Definition map_types.h:36
torch::IntArrayRef TensorShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64