27#include "neml2/models/map_types_fwd.h"
28#include "neml2/models/utils.h"
29#include "neml2/base/LabeledAxisAccessor.h"
30#include "neml2/misc/types.h"
31#include "neml2/tensors/Tensor.h"
37template <std::size_t N>
40template <typename, typename>
163 virtual void set(const Tensor & val, std::optional<TracerPrivilege> key = std::nullopt) = 0;
198 void request_AD(const std::vector<const VariableBase *> & us);
204 void request_AD(const std::vector<const VariableBase *> & u1s,
205 const std::vector<const VariableBase *> & u2s);
209 const std::vector<Derivative<1>> & derivatives() const { return _derivs; }
217 virtual void clear();
254 std::vector<Derivative<1>> _derivs;
257 std::vector<Derivative<2>> _sec_derivs;
272#define FWD_VARIABLE_BINARY_OP(op) \
273 template <typename T1, \
275 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
276 std::is_base_of_v<VariableBase, T2>>> \
277 auto op(const T1 & a, const T2 & b) \
279 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
280 return op(a(), b()); \
282 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
285 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
290FWD_VARIABLE_BINARY_OP(operator+);
291FWD_VARIABLE_BINARY_OP(operator-);
292FWD_VARIABLE_BINARY_OP(operator*);
293FWD_VARIABLE_BINARY_OP(operator/);
295FWD_VARIABLE_BINARY_OP(operator>);
296FWD_VARIABLE_BINARY_OP(operator<);
297FWD_VARIABLE_BINARY_OP(operator>=);
298FWD_VARIABLE_BINARY_OP(operator<=);
299FWD_VARIABLE_BINARY_OP(operator&&);
300FWD_VARIABLE_BINARY_OP(operator||);
301FWD_VARIABLE_BINARY_OP(operator==);
302FWD_VARIABLE_BINARY_OP(operator!=);
The DependencyResolver identifies and resolves the dependencies among a set of objects derived from DependencyDefinition.
Definition VariableBase.h:41
Derivative wrapper.
Definition VariableBase.h:38
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
The base class for all constitutive models.
Definition Model.h:70
Base class of variable.
Definition VariableBase.h:53
VariableBase(VariableBase &&)=delete
TraceableTensorShape batch_sizes() const
Definition VariableBase.cxx:155
virtual ~VariableBase()=default
bool is_old_force() const
Definition VariableBase.cxx:83
Size static_size(Size i) const
Definition VariableBase.cxx:213
Model *const _owner
The model which declared this variable.
Definition VariableBase.h:232
const std::vector< Derivative< 1 > > & derivatives() const
Partial derivatives.
Definition VariableBase.h:209
Size intmd_size(Size i) const
Definition VariableBase.cxx:222
bool requires_grad() const
Check if this variable is part of the AD function graph.
Definition VariableBase.cxx:249
bool is_parameter() const
Definition VariableBase.cxx:95
bool is_state() const
Definition VariableBase.cxx:65
const Model & owner() const
Definition VariableBase.cxx:51
Size dynamic_dim() const
Definition VariableBase.cxx:131
void apply_second_order_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply second order chain rule.
Definition VariableBase.cxx:381
virtual const VariableBase * ref() const =0
Get the referencing variable (returns this if this is a storing variable)
TensorShapeRef intmd_sizes() const
Definition VariableBase.cxx:173
const TensorShape _dep_intmd_dims
Dependent intermediate dimensions for derivative calculation.
Definition VariableBase.h:242
void clear_derivatives()
Clear only the derivatives.
Definition VariableBase.cxx:360
const TraceableSize & dynamic_size(Size i) const
Definition VariableBase.cxx:206
Tensor zeros(const TensorOptions &options) const
Make zeros tensor with the shape of this variable.
Definition VariableBase.cxx:243
Derivative< 2 > & d2(const VariableBase &var1, const VariableBase &var2, ArrayRef< Size > dep_dims={})
Wrapper for assigning second partial derivative.
Definition VariableBase.cxx:302
Size batch_dim() const
Definition VariableBase.cxx:119
void request_AD(const VariableBase &u)
Definition VariableBase.cxx:318
TraceableSize batch_size(Size i) const
Definition VariableBase.cxx:190
Size dim() const
Definition VariableBase.cxx:113
TensorShapeRef sizes() const
Definition VariableBase.cxx:149
TensorShapeRef static_sizes() const
Definition VariableBase.cxx:167
virtual TensorOptions options() const =0
Tensor options.
virtual Tensor get() const =0
Get the variable value in assembly format.
virtual Device device() const =0
Device.
Derivative< 1 > & d(const VariableBase &var, ArrayRef< Size > dep_dims={})
Wrapper for assigning partial derivative.
Definition VariableBase.cxx:286
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
virtual std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const =0
Clone this variable.
Size base_size(Size i) const
Definition VariableBase.cxx:199
bool is_solve_dependent() const
Definition VariableBase.cxx:101
virtual void zero(const TensorOptions &options)=0
Set the variable value to zero.
const std::vector< Derivative< 2 > > & second_derivatives() const
Partial second derivatives.
Definition VariableBase.h:213
VariableBase(const VariableBase &)=delete
bool is_residual() const
Definition VariableBase.cxx:89
virtual bool defined() const =0
const TensorShape _base_sizes
Base shape of the variable.
Definition VariableBase.h:239
VariableBase & operator=(VariableBase &&)=delete
std::vector< Derivative< 1 > > & derivatives()
Definition VariableBase.h:210
virtual void set(const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
Set the variable value from a Tensor in assembly format.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:69
virtual void operator=(const Tensor &val)=0
Assignment operator.
TensorShape _cached_intmd_sizes
Definition VariableBase.h:236
std::vector< Derivative< 2 > > & second_derivatives()
Definition VariableBase.h:214
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition VariableBase.cxx:107
VariableBase & operator=(const VariableBase &)=delete
virtual Dtype scalar_type() const =0
Scalar type.
Size intmd_dim() const
Definition VariableBase.cxx:143
const VariableName _name
Name of the variable.
Definition VariableBase.h:229
virtual bool owning() const =0
Check if this is an owning variable.
virtual void clear()
Clear the variable value and derivatives.
Definition VariableBase.cxx:353
Size base_dim() const
Definition VariableBase.cxx:125
bool is_force() const
Definition VariableBase.cxx:77
ArrayRef< Size > dep_intmd_dims() const
Get dependent intermediate dimensions for derivative calculation.
Definition VariableBase.cxx:237
Size static_dim() const
Definition VariableBase.cxx:137
TensorShapeRef base_sizes() const
Definition VariableBase.cxx:161
virtual TensorType type() const =0
Variable tensor type.
void apply_chain_rule(const DependencyResolver< Model, VariableName > &)
Apply first order chain rule.
Definition VariableBase.cxx:367
bool has_derivative(const VariableName &vname) const
Whether the variable has non-zero derivative with respect to another variable.
Definition VariableBase.cxx:255
void set_intmd_sizes(TensorShapeRef shape)
Set the intermediate shape.
Definition VariableBase.cxx:229
virtual Tensor tensor() const =0
Get the variable value cast to Tensor.
virtual const TraceableTensorShape & dynamic_sizes() const =0
bool is_old_state() const
Definition VariableBase.cxx:71
Size size(Size i) const
Definition VariableBase.cxx:179
virtual void requires_grad_(bool req=true)=0
Mark this variable as a leaf variable in tracing function graph for AD.
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
c10::SmallVector< Size, 8 > TensorShape
Definition types.h:66
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types_fwd.h:33
std::map< LabeledAxisAccessor, ValueMap > DerivMap
Definition map_types_fwd.h:34
c10::ArrayRef< T > ArrayRef
Definition types.h:59
int64_t Size
Definition types.h:65
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
c10::ScalarType Dtype
Definition types.h:61
Traceable size.
Definition TraceableSize.h:42
Traceable tensor shape.
Definition TraceableTensorShape.h:38