29#include "neml2/models/map_types_fwd.h"
30#include "neml2/base/LabeledAxisAccessor.h"
31#include "neml2/misc/types.h"
39template <
typename,
typename>
152 virtual void set(
const ATensor & val,
bool force =
false) = 0;
178 void request_AD(
const std::vector<const VariableBase *> & us);
184 void request_AD(
const std::vector<const VariableBase *> & u1s,
185 const std::vector<const VariableBase *> & u2s);
197 virtual void clear();
238 template <
typename T2 = T,
typename =
typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
247 template <
typename T2 = T,
typename =
typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
264 Model *
owner =
nullptr)
const override;
266 void ref(
const VariableBase & var,
bool ref_is_mutable =
false)
override;
276 void set(
const ATensor & val,
bool force =
false)
override;
293 operator T()
const {
return value(); }
295 void clear()
override;
321 : _base_sizes(base_sizes),
328 template <
typename T>
354#define FWD_VARIABLE_BINARY_OP(op) \
355 template <typename T1, \
357 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
358 std::is_base_of_v<VariableBase, T2>>> \
359 auto op(const T1 & a, const T2 & b) \
361 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
362 return op(a.value(), b.value()); \
364 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
365 return op(a.value(), b); \
367 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
368 return op(a, b.value()); \
371FWD_VARIABLE_BINARY_OP(
operator+);
372FWD_VARIABLE_BINARY_OP(
operator-);
373FWD_VARIABLE_BINARY_OP(
operator*);
374FWD_VARIABLE_BINARY_OP(
operator/);
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from DependencyDefinition.
Definition DependencyResolver.h:46
Definition Variable.h:312
Derivative()
Definition Variable.h:314
Derivative & operator=(const Tensor &val)
Definition Variable.cxx:588
Derivative & operator=(const Variable< T > &var)
Definition Variable.h:329
Derivative(TensorShapeRef base_sizes, Tensor *deriv)
Definition Variable.h:320
The base class for all constitutive models.
Definition Model.h:97
virtual void set(const ATensor &val, bool force=false)=0
VariableBase(VariableBase &&)=delete
Device device() const
Device.
Definition Variable.cxx:119
TraceableTensorShape batch_sizes() const
Return the batch shape.
Definition Variable.cxx:167
virtual ~VariableBase()=default
bool is_old_force() const
Definition Variable.cxx:77
Model *const _owner
The model which declared this variable.
Definition Variable.h:209
bool requires_grad() const
Check if this variable is part of the AD function graph.
Definition Variable.cxx:209
bool is_parameter() const
Definition Variable.cxx:89
bool is_state() const
Definition Variable.cxx:59
const Model & owner() const
Definition Variable.cxx:45
void apply_second_order_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply second order chain rule.
Definition Variable.cxx:296
Derivative d(const VariableBase &var)
Wrapper for assigning partial derivative.
Definition Variable.cxx:215
virtual const VariableBase * ref() const =0
Get the referenced variable (returns this if this is a storing variable)
Dtype scalar_type() const
Scalar type.
Definition Variable.cxx:113
Size base_storage() const
Base storage of the variable.
Definition Variable.cxx:197
bool batched() const
Whether the tensor is batched.
Definition Variable.cxx:143
Size batch_dim() const
Return the number of batch dimensions.
Definition Variable.cxx:149
void request_AD(const VariableBase &u)
Definition Variable.cxx:242
Size dim() const
Number of tensor dimensions.
Definition Variable.cxx:125
TensorShapeRef sizes() const
Tensor shape.
Definition Variable.cxx:131
Size base_size(Size dim) const
Return the size of a base axis.
Definition Variable.cxx:185
TensorShapeRef list_sizes() const
Return the list shape.
Definition Variable.cxx:173
DerivMap & second_derivatives()
Definition Variable.h:194
Size list_dim() const
Return the number of list dimensions.
Definition Variable.cxx:155
virtual Tensor get() const =0
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
virtual void set(const Tensor &val)=0
Set the variable value.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
bool is_solve_dependent() const
Definition Variable.cxx:95
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
VariableBase(const VariableBase &)=delete
bool is_residual() const
Definition Variable.cxx:83
const ValueMap & derivatives() const
Partial derivatives.
Definition Variable.h:189
ValueMap & derivatives()
Definition Variable.h:190
TraceableSize batch_size(Size dim) const
Return the size of a batch axis.
Definition Variable.cxx:179
VariableBase & operator=(VariableBase &&)=delete
const VariableName & name() const
Name of this variable.
Definition Variable.h:65
virtual void operator=(const Tensor &val)=0
Assignment operator.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition Variable.cxx:101
VariableBase & operator=(const VariableBase &)=delete
const VariableName _name
Name of the variable.
Definition Variable.h:206
virtual TensorShapeRef base_sizes() const =0
Return the base shape.
const DerivMap & second_derivatives() const
Partial second derivatives.
Definition Variable.h:193
Size size(Size dim) const
Size of a dimension.
Definition Variable.cxx:137
Size list_size(Size dim) const
Return the size of a list axis.
Definition Variable.cxx:191
virtual bool owning() const =0
Check if this is an owning variable.
virtual void clear()
Clear the variable value and derivatives.
Definition Variable.cxx:277
Size base_dim() const
Return the number of base dimensions.
Definition Variable.cxx:161
bool is_force() const
Definition Variable.cxx:71
virtual TensorType type() const =0
Variable tensor type.
void apply_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply first order chain rule.
Definition Variable.cxx:285
virtual Tensor tensor() const =0
Get the variable value.
Size assembly_storage() const
Assembly storage of the variable.
Definition Variable.cxx:203
TensorOptions options() const
Definition Variable.cxx:107
bool is_old_state() const
Definition Variable.cxx:65
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in the tracing function graph for AD.
Concrete definition of a variable.
Definition Variable.h:236
Tensor get() const override
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
Definition Variable.cxx:509
const Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable)
Definition Variable.h:302
Tensor tensor() const override
Get the variable value.
Definition Variable.cxx:516
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in the tracing function graph for AD.
Definition Variable.cxx:530
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape)
Definition Variable.h:248
bool _ref_is_mutable
Whether mutating the referenced variable is allowed.
Definition Variable.h:305
void operator=(const Tensor &val) override
Assignment operator.
Definition Variable.cxx:541
const VariableBase * ref() const override
Get the referenced variable (returns this if this is a storing variable)
Definition Variable.h:268
void zero(const TensorOptions &options) override
Set the variable value to zero.
Definition Variable.cxx:435
T operator-() const
Negation.
Definition Variable.h:290
const TensorShape _base_sizes
Base shape of the variable.
Definition Variable.h:299
void set(const Tensor &val) override
Set the variable value.
Definition Variable.cxx:461
const T & value() const
Variable value.
Definition Variable.h:287
TensorType type() const override
Variable tensor type.
Definition Variable.cxx:384
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:270
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
Definition Variable.cxx:391
T _value
Variable value (undefined if this is a referencing variable)
Definition Variable.h:308
Variable(VariableName name_in, Model *owner, TensorShapeRef list_shape)
Definition Variable.h:239
TensorShapeRef base_sizes() const override
Return the base shape.
Definition Variable.h:261
void clear() override
Clear the variable value and derivatives.
Definition Variable.cxx:562
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:66
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:71
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
at::Tensor ATensor
Definition types.h:42
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
std::string name(ElasticConstant p)
Definition ElasticityConverter.cxx:30
int64_t Size
Definition types.h:69
TensorType
Definition tensors.h:61
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
c10::TensorOptions TensorOptions
Definition types.h:63
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:72
c10::ScalarType Dtype
Definition types.h:64
Traceable size.
Definition TraceableSize.h:40
Traceable tensor shape.
Definition TraceableTensorShape.h:38