NEML2 2.1.0
Loading...
Searching...
No Matches
VariableBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <optional>
28
29#include "neml2/models/utils.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
32
33namespace neml2
34{
35// Forward declarations
36class Model;
37template <std::size_t N>
38class Derivative;
39enum class TensorType : int8_t;
40template <typename, typename>
42struct TraceableSize;
44
47std::pair<VariableName, std::size_t> parse_history(const VariableName & name,
48 const std::string & sep); // NOTE(review): presumably returns (base name, history order) parsed from `name` using the separator `sep` -- confirm against the definition
49
58{
59public:
60 VariableBase() = default;
61
62 VariableBase(const VariableBase &) = delete;
64 VariableBase & operator=(const VariableBase &) = delete;
66 virtual ~VariableBase();
67
76
78 const VariableName & name() const { return _name; } ///< Name of this variable.
79
81 std::size_t history_order() const { return _history_order; } ///< History order: 0 for current variables, 1 for "foo~1", 2 for "foo~2", etc.
82
84 const VariableName & base_name() const { return _base_name; } ///< Base name without the history suffix (e.g. "stress" for "stress~1").
85
88 const Model & owner() const;
91
93 virtual TensorType type() const = 0; ///< Variable tensor type.
94
96 // These methods mirror TensorBase
99 virtual bool defined() const = 0;
101 virtual TensorOptions options() const = 0; ///< Tensor options.
103 virtual Dtype scalar_type() const = 0; ///< Scalar type.
105 virtual Device device() const = 0; ///< Device.
107
110 Size dim() const;
112 Size base_dim() const;
117
123 virtual const TraceableTensorShape & dynamic_sizes() const = 0;
127
130 Size size(Size i) const;
137
139 virtual std::unique_ptr<VariableBase> clone(std::optional<VariableName> name = std::nullopt,
140 Model * owner = nullptr) const = 0; ///< Clone this variable.
141
143 virtual void ref(VariableBase & other) = 0; ///< Reference another variable.
144
146 virtual const VariableBase * ref() const = 0; ///< Get the referencing variable (returns this if this is a storing variable).
147 virtual VariableBase * ref() = 0; ///< Non-const overload of ref().
148
150 virtual const VariableBase * direct_ref() const = 0; ///< Get the direct referencing variable (returns nullptr if this is a storing variable).
151 virtual VariableBase * direct_ref() = 0; ///< Non-const overload of direct_ref().
152
154 virtual bool owning() const = 0; ///< Check if this is an owning variable.
155
157 bool is_mutable() const; ///< Whether this variable is mutable when it is referenced by another variable.
158
160 void set_mutable(bool m); ///< Allow/disable mutation of this variable when it is referenced by another variable.
161
164
166 virtual void zero(const TensorOptions & options) = 0; ///< Set the variable value to zero.
167
169 virtual Tensor tensor() const = 0; ///< Get the variable value cast to Tensor.
170
172 bool requires_grad() const; ///< Check if this variable is part of the AD function graph.
173
175 virtual void requires_grad_(bool req = true) = 0; ///< Mark this variable as a leaf variable in tracing function graph for AD.
176
178 virtual void assign(const Tensor & val,
179 [[maybe_unused]] std::optional<TracerPrivilege> key = std::nullopt) = 0; ///< Assignment operator (with TracerPrivilege).
180
182 virtual void operator=(const Tensor & val) = 0; ///< Assignment operator.
183
185 bool has_derivative(const VariableName & vname) const; ///< Whether the variable has non-zero derivative with respect to another variable.
186
188 bool has_derivative(const VariableName & v1name, const VariableName & v2name) const; ///< Whether the variable has non-zero second derivative with respect to another variable.
189
192 std::size_t deriv_intrsc_intmd_dim = 0,
193 std::size_t var_intrsc_intmd_dim = 0,
194 std::size_t arg_intrsc_intmd_dim = 0);
195 const Derivative<1> & d(const VariableBase & arg) const;
196
199 const VariableBase & arg2,
200 std::size_t deriv_intrsc_intmd_dim = 0,
201 std::size_t var_intrsc_intmd_dim = 0,
202 std::size_t arg1_intrsc_intmd_dim = 0,
203 std::size_t arg2_intrsc_intmd_dim = 0);
204 const Derivative<2> & d2(const VariableBase & arg1, const VariableBase & arg2) const;
207 void request_AD(const VariableBase & u);
208 void request_AD(const std::vector<const VariableBase *> & us);
210
213 void request_AD(const VariableBase & u1, const VariableBase & u2);
214 void request_AD(const std::vector<const VariableBase *> & u1s,
215 const std::vector<const VariableBase *> & u2s);
217
218 using DerivTuple = std::tuple<Derivative<1>, const VariableBase *>;
219 using DerivContainer = std::vector<DerivTuple>;
220 using SecDerivTuple = std::tuple<Derivative<2>, const VariableBase *, const VariableBase *>;
221 using SecDerivContainer = std::vector<SecDerivTuple>;
222
224 const DerivContainer & derivatives() const { return _derivs; } ///< Partial derivatives.
225 DerivContainer & derivatives() { return _derivs; } ///< Partial derivatives (mutable access).
226
228 const SecDerivContainer & second_derivatives() const { return _sec_derivs; } ///< Partial second derivatives.
229 SecDerivContainer & second_derivatives() { return _sec_derivs; } ///< Partial second derivatives (mutable access).
230
232 virtual void clear(); ///< Clear the variable value and derivatives.
233
236
245 const SecDerivContainer &
250
251protected:
253 const VariableName _name = {}; ///< Name of the variable.
254
256 Model * const _owner = nullptr; ///< The model which declared this variable.
257
259 std::size_t _history_order = 0; ///< History order parsed from the variable name (0 = current, 1 = "foo~1", etc.).
260
263
267
270
272 bool _mutable = false; ///< When referenced by another variable, whether to allow the referencing variable to mutate my value.
273
274private:
277 DerivContainer _derivs;
279 SecDerivContainer _sec_derivs;
281 mutable DerivContainer _total_derivs;
283 mutable SecDerivContainer _total_sec_derivs;
285};
286
287// Everything below is just for convenience: We just forward operations to the variable values
288// so that we can do
289//
290// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
291// var4 = (var1 - var2) * var3
292// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
293//
294// instead of the (ugly?) expression below
295//
296// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
297// var4 = (var1.v - var2.v) * var3.v
298// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/**
 * Generate a free binary operator template `op(a, b)` that is enabled only when at least one
 * operand type derives from VariableBase. Each VariableBase-derived operand is replaced by the
 * result of calling it (`a()` / `b()`), any other operand is passed through untouched, and the
 * call is then forwarded to `op` on those values.
 *
 * The trailing `static_assert(true)` lets every invocation of the macro end with a semicolon.
 */
#define FWD_VARIABLE_BINARY_OP(op)                                                                 \
  template <typename T1,                                                                           \
            typename T2,                                                                           \
            typename = std::enable_if_t<std::is_base_of_v<VariableBase, T1> ||                     \
                                        std::is_base_of_v<VariableBase, T2>>>                      \
  auto op(const T1 & a, const T2 & b)                                                              \
  {                                                                                                \
    constexpr bool lhs_is_var = std::is_base_of_v<VariableBase, T1>;                               \
    constexpr bool rhs_is_var = std::is_base_of_v<VariableBase, T2>;                               \
                                                                                                   \
    if constexpr (lhs_is_var && rhs_is_var)                                                        \
      return op(a(), b());                                                                         \
    else if constexpr (lhs_is_var)                                                                 \
      return op(a(), b);                                                                           \
    else if constexpr (rhs_is_var)                                                                 \
      return op(a, b());                                                                           \
  }                                                                                                \
  static_assert(true)
316
317FWD_VARIABLE_BINARY_OP(operator+);
318FWD_VARIABLE_BINARY_OP(operator-);
319FWD_VARIABLE_BINARY_OP(operator*);
320FWD_VARIABLE_BINARY_OP(operator/);
321
322FWD_VARIABLE_BINARY_OP(operator>);
323FWD_VARIABLE_BINARY_OP(operator<);
324FWD_VARIABLE_BINARY_OP(operator>=);
325FWD_VARIABLE_BINARY_OP(operator<=);
326FWD_VARIABLE_BINARY_OP(operator&&);
327FWD_VARIABLE_BINARY_OP(operator||);
328FWD_VARIABLE_BINARY_OP(operator==);
329FWD_VARIABLE_BINARY_OP(operator!=);
330}
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition DependencyResolver.h:47
Derivative wrapper.
Definition Derivative.h:66
The base class for all constitutive models.
Definition Model.h:82
Definition Tensor.h:53
void request_AD(const VariableBase &u1, const VariableBase &u2)
const Derivative< 1 > & d(const VariableBase &arg) const
DerivContainer & derivatives()
Definition VariableBase.h:225
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
const VariableName & base_name() const
Base name without the history suffix (e.g. "stress" for "stress~1").
Definition VariableBase.h:84
Size static_size(Size i) const
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:256
Size intmd_size(Size i) const
virtual void ref(VariableBase &other)=0
Reference another variable.
VariableBase(VariableName name_in, Model *owner, TensorShapeRef base_shape)
The canonical constructor.
bool requires_grad() const
Check if this variable is part of the AD function graph.
bool _mutable
When referenced by another variable, whether to allow the referencing variable to mutate my value.
Definition VariableBase.h:272
const Model & owner() const
Size dynamic_dim() const
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable).
TensorShapeRef intmd_sizes() const
std::size_t _history_order
History order parsed from the variable name (0 = current, 1 = "foo~1", etc.).
Definition VariableBase.h:259
void clear_derivatives()
Clear only the derivatives.
const TraceableSize & dynamic_size(Size i) const
Tensor zeros(const TensorOptions &options) const
Make zeros tensor with the shape of this variable.
std::size_t history_order() const
History order: 0 for current variables, 1 for "foo~1", 2 for "foo~2", etc.
Definition VariableBase.h:81
virtual ~VariableBase()
virtual std::unique_ptr< VariableBase > clone(std::optional< VariableName > name=std::nullopt, Model *owner=nullptr) const =0
Clone this variable.
virtual VariableBase * ref()=0
Size batch_dim() const
void request_AD(const VariableBase &u)
TraceableSize batch_size(Size i) const
Size dim() const
VariableBase()=default
std::vector< DerivTuple > DerivContainer
Definition VariableBase.h:219
TensorShapeRef sizes() const
std::tuple< Derivative< 2 >, const VariableBase *, const VariableBase * > SecDerivTuple
Definition VariableBase.h:220
TensorShapeRef static_sizes() const
const VariableBase & provider(const DependencyResolver< Model, VariableName > &) const
Get the provider in the dependency graph.
virtual TensorOptions options() const =0
Tensor options.
virtual const VariableBase * direct_ref() const =0
Get the direct referencing variable (returns nullptr if this is a storing variable).
virtual Device device() const =0
Device.
VariableName _base_name
Base name without history suffix (equals _name when history_order == 0).
Definition VariableBase.h:262
Size base_size(Size i) const
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
virtual VariableBase * direct_ref()=0
virtual bool defined() const =0
const DerivContainer & derivatives() const
Partial derivatives.
Definition VariableBase.h:224
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:269
VariableBase & operator=(VariableBase &&)=delete
const DerivContainer & total_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total derivatives with respect to leaf variables.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:78
bool is_mutable() const
Whether this variable is mutable when it is referenced by another variable.
virtual void clear()
Clear the variable value and derivatives.
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:266
bool has_derivative(const VariableName &v1name, const VariableName &v2name) const
Whether the variable has non-zero second derivative with respect to another variable.
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
virtual void assign(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Assignment operator (with TracerPrivilege).
void set_mutable(bool m)
Allow/disable mutation of this variable when it is referenced by another variable.
Size intmd_dim() const
const VariableName _name
Name of the variable.
Definition VariableBase.h:253
SecDerivContainer & second_derivatives()
Definition VariableBase.h:229
void request_AD(const std::vector< const VariableBase * > &us)
virtual bool owning() const =0
Check if this is an owning variable.
Size base_dim() const
std::vector< SecDerivTuple > SecDerivContainer
Definition VariableBase.h:221
void clear_chain_rule_cache(const DependencyResolver< Model, VariableName > &) const
Clear chain rule cache.
void request_AD(const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
Size static_dim() const
TensorShapeRef base_sizes() const
Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
Wrapper for assigning second partial derivative.
virtual TensorType type() const =0
Variable tensor type.
const SecDerivContainer & total_second_derivatives(const DependencyResolver< Model, VariableName > &) const
Get total second derivatives with respect to leaf variables.
Derivative< 1 > & d(const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
Wrapper for assigning partial derivative.
bool has_derivative(const VariableName &vname) const
Whether the variable has non-zero derivative with respect to another variable.
bool is_leaf(const DependencyResolver< Model, VariableName > &) const
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
const SecDerivContainer & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:228
std::tuple< Derivative< 1 >, const VariableBase * > DerivTuple
Definition VariableBase.h:218
Size size(Size i) const
const Derivative< 2 > & d2(const VariableBase &arg1, const VariableBase &arg2) const
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in tracing function graph for AD.
Definition DiagnosticsInterface.h:31
c10::Device Device
Definition types.h:69
std::pair< VariableName, std::size_t > parse_history(const VariableName &name, const std::string &sep)
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:72
std::string name(ElasticConstant p)
int64_t Size
Definition types.h:71
std::string VariableName
Definition types.h:75
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:66
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:73
c10::ScalarType Dtype
Definition types.h:67
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38