27#include "neml2/models/VariableBase.h"
36class Variable :
public VariableBase
63 Model *
owner =
nullptr)
const override;
65 void ref(
const VariableBase & var,
bool ref_is_mutable =
false)
override;
73 void set(
const Tensor & val, std::optional<TracerPrivilege> key)
override;
89 void clear()
override;
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:56
The base class for all constitutive models.
Definition Model.h:70
Base class of variable.
Definition VariableBase.h:53
const Model & owner() const
Definition VariableBase.cxx:51
virtual void ref(const VariableBase &other, bool ref_is_mutable=false)=0
Reference another variable.
const VariableName & name() const
Name of this variable.
Definition VariableBase.h:69
ArrayRef< Size > dep_intmd_dims() const
Get dependent intermediate dimensions for derivative calculation.
Definition VariableBase.cxx:237
Concrete definition of a variable.
Definition VariableStore.h:41
bool defined() const override
Definition Variable.cxx:45
Dtype scalar_type() const override
Scalar type.
Definition Variable.cxx:59
Tensor get() const override
Get the variable value in assembly format.
Definition Variable.cxx:167
const Variable< T > * _ref
The variable referenced by this (nullptr if this is a storing variable)
Definition Variable.h:93
Tensor tensor() const override
Get the variable value cast to Tensor.
Definition Variable.cxx:178
void requires_grad_(bool req=true) override
Mark this variable as a leaf variable in tracing function graph for AD.
Definition Variable.cxx:188
Variable(VariableName name_in, Model *owner, TensorShapeRef dep_intmd_dims={})
Definition Variable.h:39
bool _ref_is_mutable
Whether mutating the referenced variable is allowed.
Definition Variable.h:96
void operator=(const Tensor &val) override
Assignment operator.
Definition Variable.cxx:199
Device device() const override
Device.
Definition Variable.cxx:66
const VariableBase * ref() const override
Get the referencing variable (returns this if this is a storing variable)
Definition Variable.h:67
void zero(const TensorOptions &options) override
Set the variable value to zero.
Definition Variable.cxx:125
T operator-() const
Negation.
Definition Variable.h:87
TensorOptions options() const override
Tensor options.
Definition Variable.cxx:52
TensorType type() const override
Variable tensor type.
Definition Variable.cxx:38
const TraceableTensorShape & dynamic_sizes() const override
Definition Variable.cxx:73
bool owning() const override
Check if this is an owning variable.
Definition Variable.h:69
void set(const Tensor &val, std::optional< TracerPrivilege > key) override
Set the variable value from a Tensor in assembly format.
Definition Variable.cxx:146
std::unique_ptr< VariableBase > clone(const VariableName &name={}, Model *owner=nullptr) const override
Clone this variable.
Definition Variable.cxx:80
T _value
Variable value (undefined if this is a referencing variable)
Definition Variable.h:99
void clear() override
Clear the variable value and derivatives.
Definition Variable.cxx:231
const T & operator()() const
Variable value.
Definition Variable.h:84
Definition DiagnosticsInterface.cxx:30
c10::Device Device
Definition types.h:63
LabeledAxisAccessor VariableName
Definition LabeledAxisAccessor.h:185
TensorType
Definition tensors.h:56
c10::TensorOptions TensorOptions
Definition types.h:60
c10::ArrayRef< Size > TensorShapeRef
Definition types.h:67
c10::ScalarType Dtype
Definition types.h:61