NEML2 2.1.0
Loading...
Searching...
No Matches
Variable< T > Class Template Reference

Concrete definition of a variable. More...

Detailed Description

template<typename T>
class neml2::Variable< T >

Concrete definition of a variable.

#include <Variable.h>

Inheritance diagram for Variable< T >:

Public Member Functions

 Variable (VariableName name_in, Model *owner)
TensorType type () const override
 Variable tensor type.
const TraceableTensorShape & dynamic_sizes () const override
std::unique_ptr< VariableBase > clone (const VariableName &name={}, Model *owner=nullptr) const override
 Clone this variable.
void ref (VariableBase &var) override
 Reference another variable.
const VariableBase * ref () const override
 Get the referencing variable (returns this if this is a storing variable).
VariableBase * ref () override
const VariableBase * direct_ref () const override
 Get the direct referencing variable (returns nullptr if this is a storing variable).
VariableBase * direct_ref () override
bool owning () const override
 Check if this is an owning variable.
void zero (const TensorOptions &options) override
 Set the variable value to zero.
Tensor tensor () const override
 Get the variable value cast to Tensor.
void requires_grad_ (bool req=true) override
 Mark this variable as a leaf variable in tracing function graph for AD.
void assign (const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt) override
 Assignment operator (with TracerPrivilege).
void operator= (const Tensor &val) override
 Assignment operator.
const T & operator() () const
 Variable value.
T operator- () const
 Negation.
void clear () override
 Clear the variable value and derivatives.
Tensor information
bool defined () const override
TensorOptions options () const override
 Tensor options.
Dtype scalar_type () const override
 Scalar type.
Device device () const override
 Device.
Public Member Functions inherited from VariableBase
 VariableBase ()=default
 VariableBase (const VariableBase &)=delete
 VariableBase (VariableBase &&)=delete
VariableBase & operator= (const VariableBase &)=delete
VariableBase & operator= (VariableBase &&)=delete
virtual ~VariableBase ()
 VariableBase (VariableName name_in, Model *owner, TensorShapeRef base_shape)
 The canonical constructor.
const VariableName & name () const
 Name of this variable.
bool is_mutable () const
 Whether this variable is mutable when it is referenced by another variable.
void set_mutable (bool m)
 Allow/disable mutation of this variable when it is referenced by another variable.
Tensor zeros (const TensorOptions &options) const
 Make zeros tensor with the shape of this variable.
bool requires_grad () const
 Check if this variable is part of the AD function graph.
bool has_derivative (const VariableName &vname) const
 Whether the variable has non-zero derivative with respect to another variable.
bool has_derivative (const VariableName &v1name, const VariableName &v2name) const
 Whether the variable has non-zero second derivative with respect to another variable.
Derivative< 1 > & d (const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
 Wrapper for assigning partial derivative.
const Derivative< 1 > & d (const VariableBase &arg) const
Derivative< 2 > & d2 (const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
 Wrapper for assigning second partial derivative.
const Derivative< 2 > & d2 (const VariableBase &arg1, const VariableBase &arg2) const
const DerivContainer & derivatives () const
 Partial derivatives.
DerivContainer & derivatives ()
const SecDerivContainer & second_derivatives () const
 Partial second derivatives.
SecDerivContainer & second_derivatives ()
void clear_derivatives ()
 Clear only the derivatives.
const Model & owner () const
Model & owner ()
bool is_state () const
bool is_old_state () const
bool is_force () const
bool is_old_force () const
bool is_residual () const
bool is_parameter () const
bool is_solve_dependent () const
bool is_dependent () const
 Check if the derivative with respect to this variable should be evaluated.
Size dim () const
Size batch_dim () const
Size base_dim () const
Size dynamic_dim () const
Size static_dim () const
Size intmd_dim () const
TensorShapeRef sizes () const
TraceableTensorShape batch_sizes () const
TensorShapeRef base_sizes () const
TensorShapeRef static_sizes () const
TensorShapeRef intmd_sizes () const
Size size (Size i) const
TraceableSize batch_size (Size i) const
Size base_size (Size i) const
const TraceableSize & dynamic_size (Size i) const
Size static_size (Size i) const
Size intmd_size (Size i) const
void request_AD (const VariableBase &u)
void request_AD (const std::vector< const VariableBase * > &us)
void request_AD (const VariableBase &u1, const VariableBase &u2)
void request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
bool is_leaf (const DependencyResolver< Model, VariableName > &) const
const VariableBaseprovider (const DependencyResolver< Model, VariableName > &) const
 Get the provider in the dependency graph.
const DerivContainer & total_derivatives (const DependencyResolver< Model, VariableName > &) const
 Get total derivatives with respect to leaf variables.
const SecDerivContainer & total_second_derivatives (const DependencyResolver< Model, VariableName > &) const
 Get total second derivatives with respect to leaf variables.
void clear_chain_rule_cache (const DependencyResolver< Model, VariableName > &) const
 Clear chain rule cache.

Protected Attributes

Variable< T > * _ref
 The variable referenced by this (nullptr if this is a storing variable).
T _value
 Variable value (undefined if this is a referencing variable).
Protected Attributes inherited from VariableBase
const VariableName _name = {}
 Name of the variable.
Model *const _owner = nullptr
 The model which declared this variable.
TensorShape _cached_intmd_sizes = {}
const TensorShape _base_sizes = {}
 Base shape of the variable.
bool _mutable = false
 When referenced by another variable, whether to allow the referencing variable to mutate my value.

Additional Inherited Members

Public Types inherited from VariableBase
using DerivTuple = std::tuple<Derivative<1>, const VariableBase *>
using DerivContainer = std::vector<DerivTuple>
using SecDerivTuple = std::tuple<Derivative<2>, const VariableBase *, const VariableBase *>
using SecDerivContainer = std::vector<SecDerivTuple>

Constructor & Destructor Documentation

◆ Variable()

template<typename T>
Variable ( VariableName name_in,
Model * owner )
inline

Member Function Documentation

◆ assign()

template<typename T>
void assign ( const Tensor & val,
std::optional< TracerPrivilege > key = std::nullopt )
overridevirtual

Assignment operator (with TracerPrivilege).

Implements VariableBase.

◆ clear()

template<typename T>
void clear ( )
overridevirtual

Clear the variable value and derivatives.

Reimplemented from VariableBase.

◆ clone()

template<typename T>
std::unique_ptr< VariableBase > clone ( const VariableName & name = {},
Model * owner = nullptr ) const
overridevirtual

Clone this variable.

Implements VariableBase.

◆ defined()

template<typename T>
bool defined ( ) const
overridevirtual

Check whether the variable value is defined.

Implements VariableBase.

◆ device()

template<typename T>
Device device ( ) const
overridevirtual

Device.

Implements VariableBase.

◆ direct_ref() [1/2]

template<typename T>
const VariableBase * direct_ref ( ) const
inlineoverridevirtual

Get the direct referencing variable (returns nullptr if this is a storing variable).

Implements VariableBase.

◆ direct_ref() [2/2]

template<typename T>
VariableBase * direct_ref ( )
inlineoverridevirtual

Implements VariableBase.

◆ dynamic_sizes()

template<typename T>
const TraceableTensorShape & dynamic_sizes ( ) const
overridevirtual

Implements VariableBase.

◆ operator()()

template<typename T>
const T & operator() ( ) const
inline

Variable value.

◆ operator-()

template<typename T>
T operator- ( ) const
inline

Negation.

◆ operator=()

template<typename T>
void operator= ( const Tensor & val)
overridevirtual

Assignment operator.

Implements VariableBase.

◆ options()

template<typename T>
TensorOptions options ( ) const
overridevirtual

Tensor options.

Implements VariableBase.

◆ owning()

template<typename T>
bool owning ( ) const
inlineoverridevirtual

Check if this is an owning variable.

Implements VariableBase.

◆ ref() [1/3]

template<typename T>
const VariableBase * ref ( ) const
inlineoverridevirtual

Get the referencing variable (returns this if this is a storing variable).

Implements VariableBase.

◆ ref() [2/3]

template<typename T>
VariableBase * ref ( )
inlineoverridevirtual

Implements VariableBase.

◆ ref() [3/3]

template<typename T>
void ref ( VariableBase & other)
overridevirtual

Reference another variable.

Implements VariableBase.

◆ requires_grad_()

template<typename T>
void requires_grad_ ( bool req = true)
overridevirtual

Mark this variable as a leaf variable in tracing function graph for AD.

Implements VariableBase.

◆ scalar_type()

template<typename T>
Dtype scalar_type ( ) const
overridevirtual

Scalar type.

Implements VariableBase.

◆ tensor()

template<typename T>
Tensor tensor ( ) const
overridevirtual

Get the variable value cast to Tensor.

Implements VariableBase.

◆ type()

template<typename T>
TensorType type ( ) const
overridevirtual

Variable tensor type.

Implements VariableBase.

◆ zero()

template<typename T>
void zero ( const TensorOptions & options)
overridevirtual

Set the variable value to zero.

Implements VariableBase.

Member Data Documentation

◆ _ref

template<typename T>
Variable<T>* _ref
protected

The variable referenced by this (nullptr if this is a storing variable).

◆ _value

template<typename T>
T _value
protected

Variable value (undefined if this is a referencing variable).