NEML2 2.0.0
Loading...
Searching...
No Matches
Variable.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/models/LabeledAxisAccessor.h"
28#include "neml2/base/DependencyResolver.h"
29#include "neml2/tensors/Tensor.h"
30#include "neml2/models/map_types.h"
31
32namespace neml2
33{
34// Forward declarations
35class Model;
36class Derivative;
37enum class TensorType : int8_t;
38
47{
48public:
49 VariableBase() = default;
50
51 VariableBase(const VariableBase &) = delete;
53 VariableBase & operator=(const VariableBase &) = delete;
55 virtual ~VariableBase() = default;
56
58
60 const VariableName & name() const { return _name; }
61
64 const Model & owner() const;
65 Model & owner();
67
69 virtual TensorType type() const = 0;
70
73 bool is_state() const;
74 bool is_old_state() const;
75 bool is_force() const;
76 bool is_old_force() const;
77 bool is_residual() const;
78 bool is_parameter() const;
79 bool is_solve_dependent() const;
81 // Note that the check depends on whether we are currently solving a nonlinear system
82 bool is_dependent() const;
84
86 // These methods mirror TensorBase
89 torch::TensorOptions options() const { return tensor().options(); }
91 torch::Dtype scalar_type() const { return tensor().scalar_type(); }
93 torch::Device device() const { return tensor().device(); }
95 Size dim() const { return tensor().dim(); }
97 TensorShapeRef sizes() const { return tensor().sizes(); }
99 Size size(Size dim) const { return tensor().size(dim); }
101 bool batched() const { return tensor().batched(); }
103 Size batch_dim() const { return tensor().batch_dim(); }
105 Size list_dim() const { return Size(list_sizes().size()); }
107 Size base_dim() const { return Size(base_sizes().size()); }
109 TraceableTensorShape batch_sizes() const { return tensor().batch_sizes(); }
111 TensorShapeRef list_sizes() const { return _list_sizes; }
113 virtual TensorShapeRef base_sizes() const = 0;
115 TraceableSize batch_size(Size dim) const { return tensor().batch_size(dim); }
117 Size base_size(Size dim) const { return base_sizes()[dim]; }
119 Size list_size(Size dim) const { return list_sizes()[dim]; }
128
130 virtual std::unique_ptr<VariableBase> clone(const VariableName & name = {},
131 Model * owner = nullptr) const = 0;
132
134 virtual void ref(const VariableBase & other, bool ref_is_mutable = false) = 0;
135
137 virtual const VariableBase * ref() const = 0;
138
140 virtual bool owning() const = 0;
141
143 virtual void zero(const torch::TensorOptions & options) = 0;
144
146 virtual void set(const Tensor & val) = 0;
147
149 virtual Tensor get() const = 0;
150
152 virtual Tensor tensor() const = 0;
153
155 bool requires_grad() const { return tensor().requires_grad(); }
156
158 virtual void requires_grad_(bool req = true) = 0;
159
161 virtual void operator=(const Tensor & val) = 0;
162
164 Derivative d(const VariableBase & var);
165
167 Derivative d(const VariableBase & var1, const VariableBase & var2);
168
171 void request_AD(const VariableBase & u);
172 void request_AD(const std::vector<const VariableBase *> & us);
174
177 void request_AD(const VariableBase & u1, const VariableBase & u2);
178 void request_AD(const std::vector<const VariableBase *> & u1s,
179 const std::vector<const VariableBase *> & u2s);
181
183 const ValueMap & derivatives() const { return _derivs; }
184 ValueMap & derivatives() { return _derivs; }
185
187 const DerivMap & second_derivatives() const { return _sec_derivs; }
188 DerivMap & second_derivatives() { return _sec_derivs; }
189
191 virtual void clear();
192
195
198
200 const VariableName _name = {};
201
203 Model * const _owner = nullptr;
204
205private:
206 ValueMap total_derivatives(const DependencyResolver<Model, VariableName> & dep,
207 Model * model,
208 const VariableName & yvar) const;
209
210 DerivMap total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
211 Model * model,
212 const VariableName & yvar) const;
213
215 const TensorShape _list_sizes = {};
216
218 ValueMap _derivs;
219
221 DerivMap _sec_derivs;
222};
223
/// Concrete definition of a variable.
///
/// Derives from VariableBase and is parameterized on the value type T. A variable either stores
/// its own value (`_ref` is null) or references another Variable<T> through `_ref`.
///
/// NOTE(review): this listing is an extracted documentation page; the constructor bodies and the
/// protected data members are absent from it (see the notes below). Confirm against the real
/// header before relying on anything that is not visible here.
228template <typename T>
229class Variable : public VariableBase
230{
231public:
// Template header of a constructor that is enabled only when T is not Tensor.
// NOTE(review): the constructor itself (original lines 233-239) is missing from this listing;
// the page index names it as
// `Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape)` -- TODO confirm.
232 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
240
// Template header of a constructor that is enabled only when T is Tensor.
// NOTE(review): the constructor itself (original lines 242-251) is missing from this listing;
// the page index names it as
// `Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape)`
// -- TODO confirm.
241 template <typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
252
// Variable tensor type
253 TensorType type() const override;
254
// Return the base shape (the stored `_base_sizes`)
255 TensorShapeRef base_sizes() const override { return _base_sizes; }
256
// Clone this variable; both the name and the owner arguments are optional
257 std::unique_ptr<VariableBase> clone(const VariableName & name = {},
258 Model * owner = nullptr) const override;
259
// Reference another variable
260 void ref(const VariableBase & var, bool ref_is_mutable = false) override;
261
// Get the referencing variable: follows the `_ref` chain and returns `this` when `_ref` is null
// (i.e., when this is a storing variable)
262 const VariableBase * ref() const override { return _ref ? _ref->ref() : this; }
263
// Check if this is an owning variable, i.e., it does not reference another variable
264 bool owning() const override { return !_ref; }
265
// Set the variable value to zero
266 void zero(const torch::TensorOptions & options) override;
267
// Set the variable value
268 void set(const Tensor & val) override;
269
// Get the variable value with flattened base dimensions (i.e., for assembly purposes)
270 Tensor get() const override { return tensor().base_flatten(); }
271
// Get the variable value
272 Tensor tensor() const override;
273
// Mark this variable as a leaf variable in tracing function graph for AD
274 void requires_grad_(bool req = true) override;
275
// Assignment operator
276 void operator=(const Tensor & val) override;
277
// Variable value: the locally stored `_value` when owning, otherwise the value of the
// referenced variable
279 const T & value() const { return owning() ? _value : _ref->value(); }
280
// Negation of the variable value
282 T operator-() const { return -value(); }
283
// Implicit conversion to the value type (returns a copy of value())
285 operator T() const { return value(); }
286
// Implicit conversion to Tensor, only enabled when T is not already Tensor
288 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, Tensor>>>
289 operator Tensor() const
290 {
291 return value();
292 }
293
// Clear the variable value and derivatives
294 void clear() override;
295
296protected:
// NOTE(review): the member declarations are missing from this listing. The page index lists
// `_base_sizes` (base shape of the variable), `_ref` (the referenced variable, nullptr if this
// is a storing variable), `_ref_is_mutable` (whether mutating the referenced variable is
// allowed) and `_value` (the variable value) at original lines 298/301/304/307 -- TODO confirm.
299
302
305
308};
309
311{
312public:
314 : _base_sizes({}),
315 _deriv(nullptr)
316 {
317 }
318
320 : _base_sizes(base_sizes),
321 _deriv(deriv)
322 {
323 }
324
325 void operator=(const Tensor & val);
326
327private:
329 const TensorShape _base_sizes;
330
332 Tensor * const _deriv;
333};
334
335// Everything below is just for convenience: We just forward operations to the variable values
336// so that we can do
337//
338// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
339// var4 = (var1 - var2) * var3
340// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
341//
342// instead of the (ugly?) expression below
343//
344// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
345// var4 = (var1.value() - var2.value()) * var3.value()
346// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Define a free function template `op(a, b)` that forwards to the same `op` applied to the
// underlying values of its operands.
//
// The overload only participates when at least one operand type derives from VariableBase. Each
// such operand is replaced by the result of its `value()` accessor; an operand that is not a
// variable is passed through unchanged. The return type is whatever the forwarded call yields.
//
// The trailing `static_assert(true)` lets every use of the macro be terminated with a semicolon.
#define FWD_VARIABLE_BINARY_OP(op)                                                                 \
  template <typename T1,                                                                           \
            typename T2,                                                                           \
            typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> ||            \
                                                 std::is_base_of_v<VariableBase, T2>>>             \
  auto op(const T1 & a, const T2 & b)                                                              \
  {                                                                                                \
    constexpr bool lhs_is_var = std::is_base_of_v<VariableBase, T1>;                               \
    constexpr bool rhs_is_var = std::is_base_of_v<VariableBase, T2>;                               \
                                                                                                   \
    if constexpr (lhs_is_var && rhs_is_var)                                                        \
      return op(a.value(), b.value());                                                             \
    else if constexpr (lhs_is_var)                                                                 \
      return op(a.value(), b);                                                                     \
    else if constexpr (rhs_is_var)                                                                 \
      return op(a, b.value());                                                                     \
  }                                                                                                \
  static_assert(true)
// Generate the value-forwarding overloads for the four arithmetic operators, so that variables
// can be combined with each other (or with non-variable operands) directly in expressions.
364FWD_VARIABLE_BINARY_OP(operator+);
365FWD_VARIABLE_BINARY_OP(operator-);
366FWD_VARIABLE_BINARY_OP(operator*);
367FWD_VARIABLE_BINARY_OP(operator/);
368}
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
Definition Variable.h:311
Derivative()
Definition Variable.h:313
void operator=(const Tensor &val)
Definition Variable.cxx:421
Derivative(TensorShapeRef base_sizes, Tensor *deriv)
Definition Variable.h:319
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:58
The base class for all constitutive models.
Definition Model.h:51
neml2::Tensor base_flatten() const
Flatten base dimensions.
Definition TensorBaseImpl.h:413
Definition Tensor.h:47
Base class of variable.
Definition Variable.h:47
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
Return the batch shape.
Definition Variable.h:109
virtual ~VariableBase()=default
bool is_old_force() const
Definition Variable.cxx:71
Model *const _owner
The model which declared this variable.
Definition Variable.h:203
virtual void zero(const torch::TensorOptions &options)=0
Set the variable value to zero.
bool requires_grad() const
Check if this variable is part of the AD function graph.
Definition Variable.h:155
bool is_parameter() const
Definition Variable.cxx:83
bool is_state() const
Definition Variable.cxx:53
const Model & owner() const
Definition Variable.cxx:39
void apply_second_order_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply second order chain rule.
Definition Variable.cxx:182
Derivative d(const VariableBase &var)
Wrapper for assigning partial derivative.
Definition Variable.cxx:101
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable)
torch::Dtype scalar_type() const
Scalar type.
Definition Variable.h:91
Size base_storage() const
Base storage of the variable.
Definition Variable.h:121
bool batched() const
Whether the tensor is batched.
Definition Variable.h:101
Size batch_dim() const
Return the number of batch dimensions.
Definition Variable.h:103
void request_AD(const VariableBase &u)
Definition Variable.cxx:128
Size dim() const
Number of tensor dimensions.
Definition Variable.h:95
VariableBase()=default
TensorShapeRef sizes() const
Tensor shape.
Definition Variable.h:97
Size base_size(Size dim) const
Return the size of a base axis.
Definition Variable.h:117
TensorShapeRef list_sizes() const
Return the list shape.
Definition Variable.h:111
DerivMap & second_derivatives()
Definition Variable.h:188
Size list_dim() const
Return the number of list dimensions.
Definition Variable.h:105
torch::Device device() const
Device.
Definition Variable.h:93
virtual Tensor get() const =0
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
torch::TensorOptions options() const
Definition Variable.h:89
virtual void set(const Tensor &val)=0
Set the variable value.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
bool is_solve_dependent() const
Definition Variable.cxx:89
VariableBase(const VariableBase &)=delete
bool is_residual() const
Definition Variable.cxx:77
const ValueMap & derivatives() const
Partial derivatives.
Definition Variable.h:183
ValueMap & derivatives()
Definition Variable.h:184
TraceableSize batch_size(Size dim) const
Return the size of a batch axis.
Definition Variable.h:115
VariableBase & operator=(VariableBase &&)=delete
const VariableName & name() const
Name of this variable.
Definition Variable.h:60
virtual void operator=(const Tensor &val)=0
Assignment operator.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition Variable.cxx:95
VariableBase & operator=(const VariableBase &)=delete
const VariableName _name
Name of the variable.
Definition Variable.h:200
virtual TensorShapeRef base_sizes() const =0
Return the base shape.
const DerivMap & second_derivatives() const
Partial second derivatives.
Definition Variable.h:187
Size size(Size dim) const
Size of a dimension.
Definition Variable.h:99
Size list_size(Size dim) const
Return the size of a list axis.
Definition Variable.h:119
virtual bool owning() const =0
Check if this is an owning variable.
virtual void clear()
Clear the variable value and derivatives.
Definition Variable.cxx:163
Size base_dim() const
Return the number of base dimensions.
Definition Variable.h:107
bool is_force() const
Definition Variable.cxx:65
virtual TensorType type() const =0
Variable tensor type.
void apply_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply first order chain rule.
Definition Variable.cxx:171
virtual Tensor tensor() const =0
Get the variable value.
Size assembly_storage() const
Assembly storage of the variable.
Definition Variable.h:123
bool is_old_state() const
Definition Variable.cxx:59
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in tracing function graph for AD.
Concrete definition of a variable.
Definition Variable.h:230
Tensor get() const override
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
Definition Variable.h:270
void zero(const torch::TensorOptions &options) override
Set the variable value to zero.
Definition Variable.cxx:321
const Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable)
Definition Variable.h:301
Tensor tensor() const override
Get the variable value.
Definition Variable.cxx:359
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in tracing function graph for AD.
Definition Variable.cxx:373
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape)
Definition Variable.h:242
bool _ref_is_mutable
Whether mutating the referenced variable is allowed.
Definition Variable.h:304
void operator=(const Tensor &val) override
Assignment operator.
Definition Variable.cxx:384
const VariableBase * ref() const override
Get the referencing variable (returns this if this is a storing variable)
Definition Variable.h:262
T operator-() const
Negation.
Definition Variable.h:282
const TensorShape _base_sizes
Base shape of the variable.
Definition Variable.h:298
void set(const Tensor &val) override
Set the variable value.
Definition Variable.cxx:342
const T & value() const
Variable value.
Definition Variable.h:279
TensorType type() const override
Variable tensor type.
Definition Variable.cxx:270
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:264
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
Definition Variable.cxx:277
T _value
Variable value (undefined if this is a referencing variable)
Definition Variable.h:307
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape)
Definition Variable.h:233
TensorShapeRef base_sizes() const override
Return the base shape.
Definition Variable.h:255
void clear() override
Clear the variable value and derivatives.
Definition Variable.cxx:400
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:55
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types.h:35
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
TensorType
Definition tensors.h:59
torch::IntArrayRef TensorShapeRef
Definition types.h:35
Traceable size.
Definition types.h:52
Traceable tensor shape.
Definition types.h:81