NEML2 2.1.0
Loading...
Searching...
No Matches
VariableBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/map_types_fwd.h"
28#include "neml2/models/utils.h"
29#include "neml2/base/LabeledAxisAccessor.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
32
33namespace neml2
34{
35// Forward declarations
36class Model;
37template <std::size_t N>
38class Derivative;
39enum class TensorType : int8_t;
40template <typename, typename>
42struct TraceableSize;
44
// NOTE(review): the `class VariableBase` line that opens this definition, together with several
// other source lines (noted below), is missing from this listing; line numbers skip 44 -> 53.
53{
54public:
  /// Default constructor.
55 VariableBase() = default;
56
  /// Variables cannot be copied. NOTE(review): the move constructor and move assignment operator,
  /// also `= delete` per the member index, are missing from this listing (source lines 58 and 60).
57 VariableBase(const VariableBase &) = delete;
59 VariableBase & operator=(const VariableBase &) = delete;
61 virtual ~VariableBase();
62
71
  /// Name of this variable.
73 const VariableName & name() const { return _name; }
74
  /// Model which declared this variable (presumably returns `*_owner`; TODO confirm).
77 const Model & owner() const;
80
  /// Variable tensor type.
82 virtual TensorType type() const = 0;
83
  /// Category predicates; the classification criteria are defined out of line and not visible here.
86 bool is_state() const;
87 bool is_old_state() const;
88 bool is_force() const;
89 bool is_old_force() const;
90 bool is_residual() const;
91 bool is_parameter() const;
92 bool is_solve_dependent() const;
  /// Check if the derivative with respect to this variable should be evaluated.
94 // Note that the check depends on whether we are currently solving a nonlinear system
95 bool is_dependent() const;
97
99 // These methods mirror TensorBase
102 virtual bool defined() const = 0;
  /// Tensor options.
104 virtual TensorOptions options() const = 0;
  /// Scalar type.
106 virtual Dtype scalar_type() const = 0;
  /// Device.
108 virtual Device device() const = 0;
110
  // NOTE(review): the member index lists further shape queries (e.g. batch_dim(), intmd_dim(),
  // static_dim(), dynamic_dim(), sizes(), base_sizes()) whose declarations are not visible here.
113 Size dim() const;
115 Size base_dim() const;
120
126 virtual const TraceableTensorShape & dynamic_sizes() const = 0;
130
133 Size size(Size i) const;
140
  /// Clone this variable.
  /// NOTE(review): `name` and `owner` presumably override the clone's name/owner when provided;
  /// confirm against the implementation.
142 virtual std::unique_ptr<VariableBase> clone(const VariableName & name = {},
143 Model * owner = nullptr) const = 0;
144
  /// Reference another variable.
146 virtual void ref(VariableBase & other) = 0;
147
  /// Get the referencing variable (returns this if this is a storing variable).
149 virtual const VariableBase * ref() const = 0;
150 virtual VariableBase * ref() = 0;
151
  /// Get the direct referencing variable (returns nullptr if this is a storing variable).
153 virtual const VariableBase * direct_ref() const = 0;
154 virtual VariableBase * direct_ref() = 0;
155
  /// Check if this is an owning variable.
157 virtual bool owning() const = 0;
158
  /// Whether this variable is mutable when it is referenced by another variable.
160 bool is_mutable() const;
161
  /// Allow/disable mutation of this variable when it is referenced by another variable.
163 void set_mutable(bool m);
164
167
  /// Set the variable value to zero.
169 virtual void zero(const TensorOptions & options) = 0;
170
  /// Get the variable value cast to Tensor.
172 virtual Tensor tensor() const = 0;
173
  /// Check if this variable is part of the AD function graph.
175 bool requires_grad() const;
176
  /// Mark this variable as a leaf variable in the tracing function graph for AD.
178 virtual void requires_grad_(bool req = true) = 0;
179
  /// Assignment operator (with TracerPrivilege).
181 virtual void assign(const Tensor & val,
182 [[maybe_unused]] std::optional<TracerPrivilege> key = std::nullopt) = 0;
183
  /// Assignment operator.
185 virtual void operator=(const Tensor & val) = 0;
186
  /// Whether the variable has a non-zero derivative with respect to another variable.
188 bool has_derivative(const VariableName & vname) const;
189
  /// Whether the variable has a non-zero second derivative with respect to another variable.
191 bool has_derivative(const VariableName & v1name, const VariableName & v2name) const;
192
  /// Wrapper for assigning partial derivative.
  /// NOTE(review): the opening line of this declaration,
  /// `Derivative<1> & d(const VariableBase & arg,`, is missing from this listing.
195 std::size_t deriv_intrsc_intmd_dim = 0,
196 std::size_t var_intrsc_intmd_dim = 0,
197 std::size_t arg_intrsc_intmd_dim = 0);
  /// Read-only counterpart of d() (presumably returns the stored derivative w.r.t. `arg`).
198 const Derivative<1> & d(const VariableBase & arg) const;
199
  /// Wrapper for assigning second partial derivative.
  /// NOTE(review): the opening line of this declaration,
  /// `Derivative<2> & d2(const VariableBase & arg1,`, is missing from this listing.
202 const VariableBase & arg2,
203 std::size_t deriv_intrsc_intmd_dim = 0,
204 std::size_t var_intrsc_intmd_dim = 0,
205 std::size_t arg1_intrsc_intmd_dim = 0,
206 std::size_t arg2_intrsc_intmd_dim = 0);
  /// Read-only counterpart of d2() (presumably returns the stored second derivative).
207 const Derivative<2> & d2(const VariableBase & arg1, const VariableBase & arg2) const;
  /// Single-argument AD requests. NOTE(review): defined out of line; presumably these ask for the
  /// derivative with respect to `u` (or each entry of `us`) to be obtained via AD. Confirm.
210 void request_AD(const VariableBase & u);
211 void request_AD(const std::vector<const VariableBase *> & us);
213
  /// Two-argument counterparts of request_AD() (presumably for second derivatives; confirm).
216 void request_AD(const VariableBase & u1, const VariableBase & u2);
217 void request_AD(const std::vector<const VariableBase *> & u1s,
218 const std::vector<const VariableBase *> & u2s);
220
  // A first-derivative entry pairs a Derivative<1> with one variable pointer (presumably the
  // argument it is taken with respect to); a second-derivative entry pairs a Derivative<2> with two.
221 using DerivTuple = std::tuple<Derivative<1>, const VariableBase *>;
222 using DerivContainer = std::vector<DerivTuple>;
223 using SecDerivTuple = std::tuple<Derivative<2>, const VariableBase *, const VariableBase *>;
224 using SecDerivContainer = std::vector<SecDerivTuple>;
225
  /// Partial derivatives.
227 const DerivContainer & derivatives() const { return _derivs; }
228 DerivContainer & derivatives() { return _derivs; }
229
  /// Partial second derivatives.
231 const SecDerivContainer & second_derivatives() const { return _sec_derivs; }
232 SecDerivContainer & second_derivatives() { return _sec_derivs; }
233
  /// Clear the variable value and derivatives.
235 virtual void clear();
236
239
  // NOTE(review): most of this declaration is missing from the listing. Judging by its return type
  // and the member index, it is presumably `const SecDerivContainer & total_second_derivatives(
  // const DependencyResolver<Model, VariableName> &) const`. Other indexed members such as
  // total_derivatives() and clear_derivatives() are not visible here either.
248 const SecDerivContainer &
253
254protected:
  /// Name of the variable.
256 const VariableName _name = {};
257
  /// The model which declared this variable.
259 Model * const _owner = nullptr;
260
  // NOTE(review): source lines 263 (`TensorShape _cached_intmd_sizes`) and 266
  // (`const TensorShape _base_sizes`, the base shape of the variable) are missing from this listing.
264
267
  /// When referenced by another variable, whether to allow the referencing variable to mutate
  /// this variable's value.
269 bool _mutable = false;
270
271private:
  /// Storage returned by derivatives().
274 DerivContainer _derivs;
  /// Storage returned by second_derivatives().
276 SecDerivContainer _sec_derivs;
  // NOTE(review): `mutable` storage, presumably caches backing total_derivatives() and
  // total_second_derivatives(); confirm.
278 mutable DerivContainer _total_derivs;
280 mutable SecDerivContainer _total_sec_derivs;
282};
283
284// Everything below is just for convenience: we simply forward operations to the variable values
285// so that we can do
286//
287// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
288// var4 = (var1 - var2) * var3
289// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
290//
291// instead of the (ugly?) expression below
292//
293// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
294// var4 = (var1.v - var2.v) * var3.v
295// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Define a function template `op` that forwards a binary operation on variables to their values.
///
/// The generated overload participates in overload resolution only when at least one operand
/// derives from VariableBase. Every operand that is a variable is unwrapped through its call
/// operator, and `op` is then applied to the unwrapped operands. The three cases are mutually
/// exclusive, so exactly one `return` is instantiated. The trailing `static_assert(true)` exists
/// only so that each use of the macro can (and must) be terminated with a semicolon.
///
/// NOTE: no `//` comments may appear inside the macro body, because line splicing happens before
/// comment removal and would swallow the continued lines.
#define FWD_VARIABLE_BINARY_OP(op) \
  template <typename T1, \
            typename T2, \
            typename = std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
                                        std::is_base_of_v<VariableBase, T2>>> \
  auto op(const T1 & a, const T2 & b) \
  { \
    constexpr bool a_is_var = std::is_base_of_v<VariableBase, T1>; \
    constexpr bool b_is_var = std::is_base_of_v<VariableBase, T2>; \
    if constexpr (a_is_var && b_is_var) \
      return op(a(), b()); \
    else if constexpr (a_is_var) \
      return op(a(), b); \
    else /* only b is a variable here, as required by the enable_if constraint */ \
      return op(a, b()); \
  } \
  static_assert(true)
313
314FWD_VARIABLE_BINARY_OP(operator+);
315FWD_VARIABLE_BINARY_OP(operator-);
316FWD_VARIABLE_BINARY_OP(operator*);
317FWD_VARIABLE_BINARY_OP(operator/);
318
319FWD_VARIABLE_BINARY_OP(operator>);
320FWD_VARIABLE_BINARY_OP(operator<);
321FWD_VARIABLE_BINARY_OP(operator>=);
322FWD_VARIABLE_BINARY_OP(operator<=);
323FWD_VARIABLE_BINARY_OP(operator&&);
324FWD_VARIABLE_BINARY_OP(operator||);
325FWD_VARIABLE_BINARY_OP(operator==);
326FWD_VARIABLE_BINARY_OP(operator!=);
327}
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition DependencyResolver.h:47
Derivative wrapper.
Definition Derivative.h:66
The base class for all constitutive models.
Definition Model.h:83
Definition Tensor.h:49
void request_AD(const VariableBase &u1, const VariableBase &u2)
const Derivative< 1 > & d(const VariableBase &arg) const
DerivContainer & derivatives()
Definition VariableBase.h:228
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
bool is_old_force() const
Size static_size(Size i) const
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:259
Size intmd_size(Size i) const
virtual void ref(VariableBase &other)=0
Reference another variable.
VariableBase(VariableName name_in, Model *owner, TensorShapeRef base_shape)
The canonical constructor.
bool requires_grad() const
Check if this variable is part of the AD function graph.
bool is_parameter() const
bool is_state() const
bool _mutable
When referenced by another variable, whether to allow the referencing variable to mutate this variable's value.
Definition VariableBase.h:269
const Model & owner() const
Size dynamic_dim() const
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable).
TensorShapeRef intmd_sizes() const
void clear_derivatives()
Clear only the derivatives.
const TraceableSize & dynamic_size(Size i) const
Tensor zeros(const TensorOptions &options) const
Make zeros tensor with the shape of this variable.
virtual ~VariableBase()
virtual VariableBase * ref()=0
Size batch_dim() const
void request_AD(const VariableBase &u)
TraceableSize batch_size(Size i) const
Size dim() const
VariableBase()=default
std::vector< DerivTuple > DerivContainer
Definition VariableBase.h:222
TensorShapeRef sizes() const
std::tuple< Derivative< 2 >, const VariableBase *, const VariableBase * > SecDerivTuple
Definition VariableBase.h:223
TensorShapeRef static_sizes() const
const VariableBase & provider(const DependencyResolver< Model, VariableName > &) const
Get the provider in the dependency graph.
virtual TensorOptions options() const =0
Tensor options.
virtual const VariableBase * direct_ref() const =0
Get the direct referencing variable (returns nullptr if this is a storing variable).
virtual Device device() const =0
Device.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
Size base_size(Size i) const
bool is_solve_dependent() const
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
bool is_residual() const
virtual VariableBase * direct_ref()=0
virtual bool defined() const =0
const DerivContainer & derivatives() const
Partial derivatives.
Definition VariableBase.h:227
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:266
VariableBase & operator=(VariableBase &&)=delete
const DerivContainer & total_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total derivatives with respect to leaf variables.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:73
bool is_mutable() const
Whether this variable is mutable when it is referenced by another variable.
virtual void clear()
Clear the variable value and derivatives.
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:263
bool has_derivative(const VariableName &v1name, const VariableName &v2name) const
Whether the variable has a non-zero second derivative with respect to another variable.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
virtual void assign(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Assignment operator (with TracerPrivilege).
void set_mutable(bool m)
Allow/disable mutation of this variable when it is referenced by another variable.
Size intmd_dim() const
const VariableName _name
Name of the variable.
Definition VariableBase.h:256
SecDerivContainer & second_derivatives()
Definition VariableBase.h:232
void request_AD(const std::vector< const VariableBase * > &us)
virtual bool owning() const =0
Check if this is an owning variable.
Size base_dim() const
bool is_force() const
std::vector< SecDerivTuple > SecDerivContainer
Definition VariableBase.h:224
void clear_chain_rule_cache(const DependencyResolver< Model, VariableName > &) const
Clear chain rule cache.
void request_AD(const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
Size static_dim() const
TensorShapeRef base_sizes() const
Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
Wrapper for assigning second partial derivative.
virtual TensorType type() const =0
Variable tensor type.
const SecDerivContainer & total_second_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total second derivatives with respect to leaf variables.
Derivative< 1 > & d(const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
Wrapper for assigning partial derivative.
bool has_derivative(const VariableName &vname) const
Whether the variable has a non-zero derivative with respect to another variable.
bool is_leaf(const DependencyResolver< Model, VariableName > &) const
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
const SecDerivContainer & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:231
std::tuple< Derivative< 1 >, const VariableBase * > DerivTuple
Definition VariableBase.h:221
bool is_old_state() const
Size size(Size i) const
const Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2) const
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in the tracing function graph for AD.
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
int64_t Size
Definition types.h:71
TensorType
Definition tensors.h:56
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
c10::ScalarType Dtype
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38