NEML2 2.0.0
Loading...
Searching...
No Matches
VariableBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/map_types_fwd.h"
28#include "neml2/models/utils.h"
29#include "neml2/base/LabeledAxisAccessor.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
32
33namespace neml2
34{
35// Forward declarations
36class Model;
37template <std::size_t N>
39enum class TensorType : int8_t;
40template <typename, typename>
42struct TraceableSize;
44
53{
54public:
55 VariableBase() = default;
56
57 VariableBase(const VariableBase &) = delete;
59 VariableBase & operator=(const VariableBase &) = delete;
61 virtual ~VariableBase() = default;
62
64 Model * owner,
65 TensorShapeRef base_shape,
67
69 const VariableName & name() const { return _name; }
70
73 const Model & owner() const;
74 Model & owner();
76
78 virtual TensorType type() const = 0;
79
82 bool is_state() const;
83 bool is_old_state() const;
84 bool is_force() const;
85 bool is_old_force() const;
86 bool is_residual() const;
87 bool is_parameter() const;
88 bool is_solve_dependent() const;
90 // Note that the check depends on whether we are currently solving nonlinear system
91 bool is_dependent() const;
93
95 // These methods mirror TensorBase
98 virtual bool defined() const = 0;
100 virtual TensorOptions options() const = 0;
102 virtual Dtype scalar_type() const = 0;
104 virtual Device device() const = 0;
106
109 Size dim() const;
110 Size batch_dim() const;
111 Size base_dim() const;
112 Size dynamic_dim() const;
113 Size static_dim() const;
114 Size intmd_dim() const;
116
119 TensorShapeRef sizes() const;
122 virtual const TraceableTensorShape & dynamic_sizes() const = 0;
126
129 Size size(Size i) const;
131 Size base_size(Size i) const;
132 const TraceableSize & dynamic_size(Size i) const;
133 Size static_size(Size i) const;
134 Size intmd_size(Size i) const;
136
138 virtual std::unique_ptr<VariableBase> clone(const VariableName & name = {},
139 Model * owner = nullptr) const = 0;
140
142 virtual void ref(const VariableBase & other, bool ref_is_mutable = false) = 0;
143
145 virtual const VariableBase * ref() const = 0;
146
149
152
154 virtual bool owning() const = 0;
155
157 Tensor zeros(const TensorOptions & options) const;
158
160 virtual void zero(const TensorOptions & options) = 0;
161
163 virtual void set(const Tensor & val, std::optional<TracerPrivilege> key = std::nullopt) = 0;
164
166 virtual Tensor get() const = 0;
167
169 virtual Tensor tensor() const = 0;
170
172 bool requires_grad() const;
173
175 virtual void requires_grad_(bool req = true) = 0;
176
178 virtual void operator=(const Tensor & val) = 0;
179
181 bool has_derivative(const VariableName & vname) const;
182
184 bool has_derivative(const VariableName & v1name, const VariableName & v2name) const;
185
187 Derivative<1> & d(const VariableBase & var, ArrayRef<Size> dep_dims = {});
188 const Derivative<1> & d(const VariableBase & var) const;
189
192 d2(const VariableBase & var1, const VariableBase & var2, ArrayRef<Size> dep_dims = {});
193 const Derivative<2> & d2(const VariableBase & var1, const VariableBase & var2) const;
194
197 void request_AD(const VariableBase & u);
198 void request_AD(const std::vector<const VariableBase *> & us);
200
203 void request_AD(const VariableBase & u1, const VariableBase & u2);
204 void request_AD(const std::vector<const VariableBase *> & u1s,
205 const std::vector<const VariableBase *> & u2s);
207
209 const std::vector<Derivative<1>> & derivatives() const { return _derivs; }
210 std::vector<Derivative<1>> & derivatives() { return _derivs; }
211
213 const std::vector<Derivative<2>> & second_derivatives() const { return _sec_derivs; }
214 std::vector<Derivative<2>> & second_derivatives() { return _sec_derivs; }
215
217 virtual void clear();
218
220 void clear_derivatives();
221
224
227
229 const VariableName _name = {};
230
232 Model * const _owner = nullptr;
233
237
240
243
244private:
245 ValueMap total_derivatives(const DependencyResolver<Model, VariableName> & dep,
246 Model * model,
247 const VariableBase & yvar) const;
248
249 DerivMap total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
250 Model * model,
251 const VariableBase & yvar) const;
252
254 std::vector<Derivative<1>> _derivs;
255
257 std::vector<Derivative<2>> _sec_derivs;
258};
259
260// Everything below is just for convenience: We just forward operations to the the variable values
261// so that we can do
262//
263// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
264// var4 = (var1 - var2) * var3
265// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
266//
267// instead of the (ugly?) expression below
268//
269// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
270// var4 = (var1.v - var2.v) * var3.v
271// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
272#define FWD_VARIABLE_BINARY_OP(op) \
273 template <typename T1, \
274 typename T2, \
275 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
276 std::is_base_of_v<VariableBase, T2>>> \
277 auto op(const T1 & a, const T2 & b) \
278 { \
279 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
280 return op(a(), b()); \
281 \
282 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
283 return op(a(), b); \
284 \
285 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
286 return op(a, b()); \
287 } \
288 static_assert(true)
289
290FWD_VARIABLE_BINARY_OP(operator+);
291FWD_VARIABLE_BINARY_OP(operator-);
292FWD_VARIABLE_BINARY_OP(operator*);
293FWD_VARIABLE_BINARY_OP(operator/);
294
295FWD_VARIABLE_BINARY_OP(operator>);
296FWD_VARIABLE_BINARY_OP(operator<);
297FWD_VARIABLE_BINARY_OP(operator>=);
298FWD_VARIABLE_BINARY_OP(operator<=);
299FWD_VARIABLE_BINARY_OP(operator&&);
300FWD_VARIABLE_BINARY_OP(operator||);
301FWD_VARIABLE_BINARY_OP(operator==);
302FWD_VARIABLE_BINARY_OP(operator!=);
303}
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition VariableBase.h:41
Derivative wrapper.
Definition VariableBase.h:38
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
The base class for all constitutive models.
Definition Model.h:70
Definition Tensor.h:47
Base class of variable.
Definition VariableBase.h:53
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
Definition VariableBase.cxx:155
virtual ~VariableBase()=default
bool is_old_force() const
Definition VariableBase.cxx:83
Size static_size(Size i) const
Definition VariableBase.cxx:213
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:232
const std::vector< Derivative< 1 > > & derivatives() const
Partial derivatives.
Definition VariableBase.h:209
Size intmd_size(Size i) const
Definition VariableBase.cxx:222
bool requires_grad() const
Check if this variable is part of the AD function graph.
Definition VariableBase.cxx:249
bool is_parameter() const
Definition VariableBase.cxx:95
bool is_state() const
Definition VariableBase.cxx:65
const Model & owner() const
Definition VariableBase.cxx:51
Size dynamic_dim() const
Definition VariableBase.cxx:131
void apply_second_order_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply second order chain rule.
Definition VariableBase.cxx:381
virtual const VariableBase * ref() const =0
Get the referenced variable (returns this if this is a storing variable)
TensorShapeRef intmd_sizes() const
Definition VariableBase.cxx:173
const TensorShape _dep_intmd_dims
Dependent intermediate dimensions for derivative calculation.
Definition VariableBase.h:242
void clear_derivatives()
Clear only the derivatives.
Definition VariableBase.cxx:360
const TraceableSize & dynamic_size(Size i) const
Definition VariableBase.cxx:206
Tensor zeros(const TensorOptions &options) const
Make zeros tensor with the shape of this variable.
Definition VariableBase.cxx:243
Derivative< 2 > & d2(const VariableBase &var1, const VariableBase &var2, ArrayRef< Size > dep_dims={})
Wrapper for assigning second partial derivative.
Definition VariableBase.cxx:302
Size batch_dim() const
Definition VariableBase.cxx:119
void request_AD(const VariableBase &u)
Definition VariableBase.cxx:318
TraceableSize batch_size(Size i) const
Definition VariableBase.cxx:190
Size dim() const
Definition VariableBase.cxx:113
VariableBase()=default
TensorShapeRef sizes() const
Definition VariableBase.cxx:149
TensorShapeRef static_sizes() const
Definition VariableBase.cxx:167
virtual TensorOptions options() const =0
Tensor options.
virtual Tensor get() const =0
Get the variable value in assembly format.
virtual Device device() const =0
Device.
Derivative< 1 > & d(const VariableBase &var, ArrayRef< Size > dep_dims={})
Wrapper for assigning partial derivative.
Definition VariableBase.cxx:286
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
Size base_size(Size i) const
Definition VariableBase.cxx:199
bool is_solve_dependent() const
Definition VariableBase.cxx:101
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
const std::vector< Derivative< 2 > > & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:213
VariableBase(const VariableBase &)=delete
bool is_residual() const
Definition VariableBase.cxx:89
virtual bool defined() const =0
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:239
VariableBase & operator=(VariableBase &&)=delete
std::vector< Derivative< 1 > > & derivatives()
Definition VariableBase.h:210
virtual void set(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Set the variable value from a Tensor in assembly format.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:69
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:236
std::vector< Derivative< 2 > > & second_derivatives()
Definition VariableBase.h:214
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition VariableBase.cxx:107
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
Size intmd_dim() const
Definition VariableBase.cxx:143
const VariableName _name
Name of the variable.
Definition VariableBase.h:229
virtual bool owning() const =0
Check if this is an owning variable.
virtual void clear()
Clear the variable value and derivatives.
Definition VariableBase.cxx:353
Size base_dim() const
Definition VariableBase.cxx:125
bool is_force() const
Definition VariableBase.cxx:77
ArrayRef< Size > dep_intmd_dims() const
Get dependent intermediate dimensions for derivative calculation.
Definition VariableBase.cxx:237
Size static_dim() const
Definition VariableBase.cxx:137
TensorShapeRef base_sizes() const
Definition VariableBase.cxx:161
virtual TensorType type() const =0
Variable tensor type.
void apply_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply first order chain rule.
Definition VariableBase.cxx:367
bool has_derivative(const VariableName &vname) const
Whether the variable has non-zero derivative with respect to another variable.
Definition VariableBase.cxx:255
void set_intmd_sizes(TensorShapeRef shape)
Set the intermediate shape.
Definition VariableBase.cxx:229
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
bool is_old_state() const
Definition VariableBase.cxx:71
Size size(Size i) const
Definition VariableBase.cxx:179
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in the tracing function graph for AD.
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
c10::ArrayRef< T > ArrayRef
Definition types.h:59
int64_t Size
Definition types.h:65
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
c10::ScalarType Dtype
Definition types.h:61
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38