NEML2 2.0.0
Loading...
Searching...
No Matches
Variable.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <memory>
28
29#include "neml2/models/map_types_fwd.h"
30#include "neml2/base/LabeledAxisAccessor.h"
31#include "neml2/misc/types.h"
32
33namespace neml2
34{
35// Forward declarations
36class Model;
37class Derivative;
38enum class TensorType : int8_t;
39template <typename, typename>
41struct TraceableSize;
43
// NOTE(review): this listing appears to be a Doxygen source-browser scrape in which lines carrying
// hyperlinks or `///` doc comments were dropped, and each remaining line is prefixed with its
// original line number. In particular the class head (`class VariableBase`, original line 51) and
// declarations that the tooltip index at the end of the page still lists -- the deleted move
// constructor / move assignment, batch_sizes(), batch_size(Size), list_sizes(),
// apply_chain_rule(...) and apply_second_order_chain_rule(...) -- are missing here and should be
// restored from the real header before this compiles.
//
// Base class of variable: the type-erased interface implemented by every Variable<T>.
52{
53public:
 /// Default constructor; _name and _owner keep their in-class defaults ({} and nullptr)
54 VariableBase() = default;
55
 /// Variables are neither copy-constructible nor copy-assignable
56 VariableBase(const VariableBase &) = delete;
58 VariableBase & operator=(const VariableBase &) = delete;
 /// Virtual destructor: variables are handled polymorphically (e.g. clone() returns a base pointer)
60 virtual ~VariableBase() = default;
61
 /// Construct a variable named @p name_in, declared by model @p owner, with list shape @p list_shape
62 VariableBase(VariableName name_in, Model * owner, TensorShapeRef list_shape);
63
 /// Name of this variable
65 const VariableName & name() const { return _name; }
66
 /// The model which declared this variable (const and non-const access)
69 const Model & owner() const;
70 Model & owner();
72
 /// Variable tensor type
74 virtual TensorType type() const = 0;
75
 /// Role queries. Presumably each checks which part of the owning model's axes this variable
 /// belongs to (state, old state, force, old force, residual, parameter) -- TODO confirm in
 /// Variable.cxx
78 bool is_state() const;
79 bool is_old_state() const;
80 bool is_force() const;
81 bool is_old_force() const;
82 bool is_residual() const;
83 bool is_parameter() const;
 /// NOTE(review): semantics of "solve dependent" are not visible here -- TODO confirm
84 bool is_solve_dependent() const;
 /// Check if the derivative with respect to this variable should be evaluated
86 // Note that the check depends on whether we are currently solving nonlinear system
87 bool is_dependent() const;
89
 /// Tensor information
91 // These methods mirror TensorBase
 /// Tensor options
94 TensorOptions options() const;
 /// Scalar type
96 Dtype scalar_type() const;
 /// Device
98 Device device() const;
 /// Number of tensor dimensions
100 Size dim() const;
 /// Tensor shape
102 TensorShapeRef sizes() const;
 /// Size of a dimension
104 Size size(Size dim) const;
 /// Whether the tensor is batched
106 bool batched() const;
 /// Return the number of batch dimensions
108 Size batch_dim() const;
 /// Return the number of list dimensions
110 Size list_dim() const;
 /// Return the number of base dimensions
112 Size base_dim() const;
 /// Return the base shape
118 virtual TensorShapeRef base_sizes() const = 0;
 /// Return the size of a base axis
122 Size base_size(Size dim) const;
 /// Return the size of a list axis
124 Size list_size(Size dim) const;
 /// Base storage of the variable
126 Size base_storage() const;
 /// Assembly storage of the variable
128 Size assembly_storage() const;
130
 /// Clone this variable. Presumably an empty @p name / null @p owner keeps the current ones --
 /// TODO confirm in the Variable<T> implementation
132 virtual std::unique_ptr<VariableBase> clone(const VariableName & name = {},
133 Model * owner = nullptr) const = 0;
134
 /// Reference another variable; @p ref_is_mutable controls whether mutating it is allowed
136 virtual void ref(const VariableBase & other, bool ref_is_mutable = false) = 0;
137
 /// Get the referencing variable (returns this if this is a storing variable)
139 virtual const VariableBase * ref() const = 0;
140
 /// Check if this is an owning variable
142 virtual bool owning() const = 0;
143
 /// Set the variable value to zero
145 virtual void zero(const TensorOptions & options) = 0;
146
 /// Set the variable value
148 virtual void set(const Tensor & val) = 0;
149
 /// Set the variable value from a raw ATensor. NOTE(review): the meaning of @p force is defined
 /// by the implementations -- TODO confirm
152 virtual void set(const ATensor & val, bool force = false) = 0;
153
 /// Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
155 virtual Tensor get() const = 0;
156
 /// Get the variable value
158 virtual Tensor tensor() const = 0;
159
 /// Check if this variable is part of the AD function graph
161 bool requires_grad() const;
162
 /// Mark this variable as a leaf variable in tracing function graph for AD
164 virtual void requires_grad_(bool req = true) = 0;
165
 /// Assignment operator
167 virtual void operator=(const Tensor & val) = 0;
168
 /// Wrapper for assigning partial derivative
170 Derivative d(const VariableBase & var);
171
 /// Presumably the second-derivative counterpart of d(var) (cf. second_derivatives()) -- TODO confirm
173 Derivative d(const VariableBase & var1, const VariableBase & var2);
174
 /// Request automatic differentiation w.r.t. @p u (or each of @p us) -- inferred from the name,
 /// TODO confirm in Variable.cxx
177 void request_AD(const VariableBase & u);
178 void request_AD(const std::vector<const VariableBase *> & us);
180
 /// Second-order counterpart of request_AD (pairs of variables) -- TODO confirm
183 void request_AD(const VariableBase & u1, const VariableBase & u2);
184 void request_AD(const std::vector<const VariableBase *> & u1s,
185 const std::vector<const VariableBase *> & u2s);
187
 /// Partial derivatives (backed by _derivs)
189 const ValueMap & derivatives() const { return _derivs; }
190 ValueMap & derivatives() { return _derivs; }
191
 /// Partial second derivatives (backed by _sec_derivs)
193 const DerivMap & second_derivatives() const { return _sec_derivs; }
194 DerivMap & second_derivatives() { return _sec_derivs; }
195
 /// Clear the variable value and derivatives
197 virtual void clear();
198
 /// Clear only the derivatives
200 void clear_derivatives();
201
204
207
 /// Name of the variable
209 const VariableName _name = {};
210
 /// The model which declared this variable
212 Model * const _owner = nullptr;
213
214private:
 /// Helper: presumably accumulates the total derivatives of @p yvar through the dependency graph
 /// @p dep starting at @p model -- TODO confirm in Variable.cxx
215 ValueMap total_derivatives(const DependencyResolver<Model, VariableName> & dep,
216 Model * model,
217 const VariableName & yvar) const;
218
 /// Second-order counterpart of total_derivatives -- TODO confirm
219 DerivMap total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
220 Model * model,
221 const VariableName & yvar) const;
222
 /// List shape (presumably the list_shape passed to the constructor -- TODO confirm)
224 const TensorShape _list_sizes = {};
225
 /// Partial derivatives, exposed via derivatives()
227 ValueMap _derivs;
228
 /// Partial second derivatives, exposed via second_derivatives()
230 DerivMap _sec_derivs;
231};
232
237template <typename T>
238class Variable : public VariableBase
239{
240public:
241 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
243 : VariableBase(std::move(name_in), owner, list_shape),
244 _base_sizes(T::const_base_sizes),
245 _ref(nullptr),
246 _ref_is_mutable(false)
247 {
248 }
249
250 template <typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
252 Model * owner,
253 TensorShapeRef list_shape,
254 TensorShapeRef base_shape)
255 : VariableBase(std::move(name_in), owner, list_shape),
256 _base_sizes(base_shape),
257 _ref(nullptr),
258 _ref_is_mutable(false)
259 {
260 }
261
262 TensorType type() const override;
263
264 TensorShapeRef base_sizes() const override { return _base_sizes; }
265
266 std::unique_ptr<VariableBase> clone(const VariableName & name = {},
267 Model * owner = nullptr) const override;
268
269 void ref(const VariableBase & var, bool ref_is_mutable = false) override;
270
271 const VariableBase * ref() const override { return _ref ? _ref->ref() : this; }
272
273 bool owning() const override { return !_ref; }
274
275 void zero(const TensorOptions & options) override;
276
277 void set(const Tensor & val) override;
278
279 void set(const ATensor & val, bool force = false) override;
280
281 Tensor get() const override;
282
283 Tensor tensor() const override;
284
285 void requires_grad_(bool req = true) override;
286
287 void operator=(const Tensor & val) override;
288
290 const T & value() const { return owning() ? _value : _ref->value(); }
291
293 T operator-() const { return -value(); }
294
296 operator T() const { return value(); }
297
298 void clear() override;
299
300protected:
303
306
309
312};
313
315{
316public:
318 : _base_sizes({}),
319 _deriv(nullptr)
320 {
321 }
322
323 Derivative(TensorShapeRef base_sizes, Tensor * deriv)
324 : _base_sizes(base_sizes),
325 _deriv(deriv)
326 {
327 }
328
329 Derivative & operator=(const Tensor & val);
330
331 template <typename T>
333 {
334 return operator=(var.value());
335 }
336
337private:
339 const TensorShape _base_sizes;
340
342 Tensor * const _deriv;
343};
344
345// Everything below is just for convenience: we forward the arithmetic operators to the variable
346// values so that we can write
347//
348// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
349// var4 = (var1 - var2) * var3
350// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
351//
352// instead of the more verbose expression below
353//
354// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
355// var4 = (var1.value() - var2.value()) * var3.value()
356// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
357#define FWD_VARIABLE_BINARY_OP(op) \
358 template <typename T1, \
359 typename T2, \
360 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
361 std::is_base_of_v<VariableBase, T2>>> \
362 auto op(const T1 & a, const T2 & b) \
363 { \
364 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
365 return op(a.value(), b.value()); \
366 \
367 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
368 return op(a.value(), b); \
369 \
370 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
371 return op(a, b.value()); \
372 } \
373 static_assert(true)
374FWD_VARIABLE_BINARY_OP(operator+);
375FWD_VARIABLE_BINARY_OP(operator-);
376FWD_VARIABLE_BINARY_OP(operator*);
377FWD_VARIABLE_BINARY_OP(operator/);
378}
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from D...
Definition Variable.h:40
Definition Variable.h:315
Derivative()
Definition Variable.h:317
Derivative & operator=(const Tensor &val)
Definition Variable.cxx:602
Derivative & operator=(const Variable< T > &var)
Definition Variable.h:332
Derivative(TensorShapeRef base_sizes, Tensor *deriv)
Definition Variable.h:323
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
The base class for all constitutive models.
Definition Model.h:68
Definition Tensor.h:46
Base class of variable.
Definition Variable.h:52
virtual void set(const ATensor &val, bool force=false)=0
VariableBase(VariableBase &&)=delete
Device device() const
Device.
Definition Variable.cxx:118
TraceableTensorShape batch_sizes() const
Return the batch shape.
Definition Variable.cxx:166
virtual ~VariableBase()=default
bool is_old_force() const
Definition Variable.cxx:76
Model *const _owner
The model which declared this variable.
Definition Variable.h:212
bool requires_grad() const
Check if this variable is part of the AD function graph.
Definition Variable.cxx:208
bool is_parameter() const
Definition Variable.cxx:88
bool is_state() const
Definition Variable.cxx:58
const Model & owner() const
Definition Variable.cxx:44
void apply_second_order_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply second order chain rule.
Definition Variable.cxx:301
Derivative d(const VariableBase &var)
Wrapper for assigning partial derivative.
Definition Variable.cxx:214
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable)
Dtype scalar_type() const
Scalar type.
Definition Variable.cxx:112
Size base_storage() const
Base storage of the variable.
Definition Variable.cxx:196
void clear_derivatives()
Clear only the derivatives.
Definition Variable.cxx:283
bool batched() const
Whether the tensor is batched.
Definition Variable.cxx:142
Size batch_dim() const
Return the number of batch dimensions.
Definition Variable.cxx:148
void request_AD(const VariableBase &u)
Definition Variable.cxx:241
Size dim() const
Number of tensor dimensions.
Definition Variable.cxx:124
VariableBase()=default
TensorShapeRef sizes() const
Tensor shape.
Definition Variable.cxx:130
Size base_size(Size dim) const
Return the size of a base axis.
Definition Variable.cxx:184
TensorShapeRef list_sizes() const
Return the list shape.
Definition Variable.cxx:172
DerivMap & second_derivatives()
Definition Variable.h:194
Size list_dim() const
Return the number of list dimensions.
Definition Variable.cxx:154
virtual Tensor get() const =0
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
virtual void set(const Tensor &val)=0
Set the variable value.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
bool is_solve_dependent() const
Definition Variable.cxx:94
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
bool is_residual() const
Definition Variable.cxx:82
const ValueMap & derivatives() const
Partial derivatives.
Definition Variable.h:189
ValueMap & derivatives()
Definition Variable.h:190
TraceableSize batch_size(Size dim) const
Return the size of a batch axis.
Definition Variable.cxx:178
VariableBase & operator=(VariableBase &&)=delete
const VariableName & name() const
Name of this variable.
Definition Variable.h:65
virtual void operator=(const Tensor &val)=0
Assignment operator.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition Variable.cxx:100
VariableBase & operator=(const VariableBase &)=delete
const VariableName _name
Name of the variable.
Definition Variable.h:209
virtual TensorShapeRef base_sizes() const =0
Return the base shape.
const DerivMap & second_derivatives() const
Partial second derivatives.
Definition Variable.h:193
Size size(Size dim) const
Size of a dimension.
Definition Variable.cxx:136
Size list_size(Size dim) const
Return the size of a list axis.
Definition Variable.cxx:190
virtual bool owning() const =0
Check if this is an owning variable.
virtual void clear()
Clear the variable value and derivatives.
Definition Variable.cxx:276
Size base_dim() const
Return the number of base dimensions.
Definition Variable.cxx:160
bool is_force() const
Definition Variable.cxx:70
virtual TensorType type() const =0
Variable tensor type.
void apply_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply first order chain rule.
Definition Variable.cxx:290
virtual Tensor tensor() const =0
Get the variable value.
Size assembly_storage() const
Assembly storage of the variable.
Definition Variable.cxx:202
TensorOptions options() const
Definition Variable.cxx:106
bool is_old_state() const
Definition Variable.cxx:64
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in tracing function graph for AD.
Concrete definition of a variable.
Definition VariableStore.h:41
Tensor get() const override
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
Definition Variable.cxx:523
const Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable)
Definition Variable.h:305
Tensor tensor() const override
Get the variable value.
Definition Variable.cxx:530
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in tracing function graph for AD.
Definition Variable.cxx:544
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape)
Definition Variable.h:251
bool _ref_is_mutable
Whether mutating the referenced variable is allowed.
Definition Variable.h:308
void operator=(const Tensor &val) override
Assignment operator.
Definition Variable.cxx:555
const VariableBase * ref() const override
Get the referencing variable (returns this if this is a storing variable)
Definition Variable.h:271
void zero(const TensorOptions &options) override
Set the variable value to zero.
Definition Variable.cxx:449
T operator-() const
Negation.
Definition Variable.h:293
const TensorShape _base_sizes
Base shape of the variable.
Definition Variable.h:302
void set(const Tensor &val) override
Set the variable value.
Definition Variable.cxx:475
const T & value() const
Variable value.
Definition Variable.h:290
TensorType type() const override
Variable tensor type.
Definition Variable.cxx:389
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:273
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
Definition Variable.cxx:396
T _value
Variable value (undefined if this is a referencing variable)
Definition Variable.h:311
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape)
Definition Variable.h:242
TensorShapeRef base_sizes() const override
Return the base shape.
Definition Variable.h:264
void clear() override
Clear the variable value and derivatives.
Definition Variable.cxx:576
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
int64_t Size
Definition types.h:65
TensorType
Definition tensors.h:61
at::Tensor ATensor
Definition types.h:38
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
c10::ScalarType Dtype
Definition types.h:61
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38