NEML2 2.0.0
|
Concrete definition of a variable. More...
Concrete definition of a variable.
#include <Variable.h>
Public Member Functions | |
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>> | |
Variable (VariableName name_in, Model *owner, TensorShapeRef list_shape) | |
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>> | |
Variable (VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape) | |
TensorType | type () const override |
Variable tensor type. | |
TensorShapeRef | base_sizes () const override |
Return the base shape. | |
std::unique_ptr< VariableBase > | clone (const VariableName &name={}, Model *owner=nullptr) const override |
Clone this variable. | |
void | ref (const VariableBase &var, bool ref_is_mutable=false) override |
Reference another variable. | |
const VariableBase * | ref () const override |
Get the referenced variable (returns `this` if this is a storing variable). | |
bool | owning () const override |
Check if this is an owning variable. | |
void | zero (const torch::TensorOptions &options) override |
Set the variable value to zero. | |
void | set (const Tensor &val) override |
Set the variable value. | |
Tensor | get () const override |
Get the variable value (with flattened base dimensions, i.e., for assembly purposes) | |
Tensor | tensor () const override |
Get the variable value. | |
void | requires_grad_ (bool req=true) override |
Mark this variable as a leaf variable in the traced function graph for automatic differentiation (AD). | |
void | operator= (const Tensor &val) override |
Assignment operator. | |
const T & | value () const |
Variable value. | |
T | operator- () const |
Negation. | |
operator T () const | |
Convert to the underlying tensor type. | |
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, Tensor>>> | |
operator Tensor () const | |
Convert to Tensor. | |
void | clear () override |
Clear the variable value and derivatives. | |
Public Member Functions inherited from VariableBase | |
VariableBase ()=default | |
VariableBase (const VariableBase &)=delete | |
VariableBase (VariableBase &&)=delete | |
VariableBase & | operator= (const VariableBase &)=delete |
VariableBase & | operator= (VariableBase &&)=delete |
virtual | ~VariableBase ()=default |
VariableBase (VariableName name_in, Model *owner, TensorShapeRef list_shape) | |
const VariableName & | name () const |
Name of this variable. | |
bool | requires_grad () const |
Check if this variable is part of the AD function graph. | |
Derivative | d (const VariableBase &var) |
Wrapper for assigning a partial derivative. | |
Derivative | d (const VariableBase &var1, const VariableBase &var2) |
Wrapper for assigning a second partial derivative. | |
const ValueMap & | derivatives () const |
Partial derivatives. | |
ValueMap & | derivatives () |
const DerivMap & | second_derivatives () const |
Partial second derivatives. | |
DerivMap & | second_derivatives () |
void | apply_chain_rule (const DependencyResolver< Model, VariableName > &) |
Apply first order chain rule. | |
void | apply_second_order_chain_rule (const DependencyResolver< Model, VariableName > &) |
Apply second order chain rule. | |
const Model & | owner () const |
Model & | owner () |
bool | is_state () const |
bool | is_old_state () const |
bool | is_force () const |
bool | is_old_force () const |
bool | is_residual () const |
bool | is_parameter () const |
bool | is_solve_dependent () const |
bool | is_dependent () const |
Check if the derivative with respect to this variable should be evaluated. | |
torch::TensorOptions | options () const |
torch::Dtype | scalar_type () const |
Scalar type. | |
torch::Device | device () const |
Device. | |
Size | dim () const |
Number of tensor dimensions. | |
TensorShapeRef | sizes () const |
Tensor shape. | |
Size | size (Size dim) const |
Size of a dimension. | |
bool | batched () const |
Whether the tensor is batched. | |
Size | batch_dim () const |
Return the number of batch dimensions. | |
Size | list_dim () const |
Return the number of list dimensions. | |
Size | base_dim () const |
Return the number of base dimensions. | |
TraceableTensorShape | batch_sizes () const |
Return the batch shape. | |
TensorShapeRef | list_sizes () const |
Return the list shape. | |
TraceableSize | batch_size (Size dim) const |
Return the size of a batch axis. | |
Size | base_size (Size dim) const |
Return the size of a base axis. | |
Size | list_size (Size dim) const |
Return the size of a list axis. | |
Size | base_storage () const |
Base storage of the variable. | |
Size | assembly_storage () const |
Assembly storage of the variable. | |
void | request_AD (const VariableBase &u) |
void | request_AD (const std::vector< const VariableBase * > &us) |
void | request_AD (const VariableBase &u1, const VariableBase &u2) |
void | request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s) |
Protected Attributes | |
const TensorShape | _base_sizes |
Base shape of the variable. | |
const Variable< T > * | _ref |
The variable referenced by this (nullptr if this is a storing variable) | |
bool | _ref_is_mutable |
Whether mutating the referenced variable is allowed. | |
T | _value |
Variable value (undefined if this is a referencing variable) | |
Additional Inherited Members | |
Public Attributes inherited from VariableBase | |
const VariableName | _name = {} |
Name of the variable. | |
Model *const | _owner = nullptr |
The model which declared this variable. | |
|
inline |
|
inline |
|
inline override virtual |
Return the base shape.
Implements VariableBase.
Clear the variable value and derivatives.
Reimplemented from VariableBase.
|
override virtual |
Clone this variable.
Implements VariableBase.
Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
Implements VariableBase.
Assignment operator.
Implements VariableBase.
Check if this is an owning variable.
Implements VariableBase.
|
inline override virtual |
Get the referenced variable (returns `this` if this is a storing variable).
Implements VariableBase.
|
override virtual |
Reference another variable.
Implements VariableBase.
Mark this variable as a leaf variable in tracing function graph for AD.
Implements VariableBase.
Set the variable value.
Implements VariableBase.
Get the variable value.
Implements VariableBase.
|
override virtual |
Variable tensor type.
Implements VariableBase.
Set the variable value to zero.
Implements VariableBase.
|
protected |
Base shape of the variable.
The variable referenced by this (nullptr if this is a storing variable)
Whether mutating the referenced variable is allowed.