NEML2 2.0.0
Loading...
Searching...
No Matches
Variable< T > Class Template Reference

Concrete definition of a variable. More...

Detailed Description

template<typename T>
class neml2::Variable< T >

Concrete definition of a variable.

#include <Variable.h>

Inheritance diagram for Variable< T >:

Public Member Functions

 Variable (VariableName name_in, Model *owner, TensorShapeRef dep_intmd_dims={})
 
TensorType type () const override
 Variable tensor type.
 
const TraceableTensorShape & dynamic_sizes () const override
 
std::unique_ptr< VariableBase > clone (const VariableName &name={}, Model *owner=nullptr) const override
 Clone this variable.
 
void ref (const VariableBase &var, bool ref_is_mutable=false) override
 Reference another variable.
 
const VariableBase * ref () const override
 Get the referenced variable (returns this if this is a storing variable)
 
bool owning () const override
 Check if this is an owning variable.
 
void zero (const TensorOptions &options) override
 Set the variable value to zero.
 
void set (const Tensor &val, std::optional< TracerPrivilege > key) override
 Set the variable value from a Tensor in assembly format.
 
Tensor get () const override
 Get the variable value in assembly format.
 
Tensor tensor () const override
 Get the variable value cast to Tensor.
 
void requires_grad_ (bool req=true) override
 Mark this variable as a leaf variable in the tracing function graph for AD.
 
void operator= (const Tensor &val) override
 Assignment operator.
 
const T & operator() () const
 Variable value.
 
T operator- () const
 Negation.
 
void clear () override
 Clear the variable value and derivatives.
 
Tensor information
bool defined () const override
 
TensorOptions options () const override
 Tensor options.
 
Dtype scalar_type () const override
 Scalar type.
 
Device device () const override
 Device.
 
- Public Member Functions inherited from VariableBase
 VariableBase ()=default
 
 VariableBase (const VariableBase &)=delete
 
 VariableBase (VariableBase &&)=delete
 
VariableBase & operator= (const VariableBase &)=delete
 
VariableBase & operator= (VariableBase &&)=delete
 
virtual ~VariableBase ()=default
 
 VariableBase (VariableName name_in, Model *owner, TensorShapeRef base_shape, TensorShapeRef dep_intmd_dims)
 
const VariableName & name () const
 Name of this variable.
 
void set_intmd_sizes (TensorShapeRef shape)
 Set the intermediate shape.
 
ArrayRef< Size > dep_intmd_dims () const
 Get dependent intermediate dimensions for derivative calculation.
 
Tensor zeros (const TensorOptions &options) const
 Make zeros tensor with the shape of this variable.
 
bool requires_grad () const
 Check if this variable is part of the AD function graph.
 
bool has_derivative (const VariableName &vname) const
 Whether the variable has a non-zero derivative with respect to another variable.
 
bool has_derivative (const VariableName &v1name, const VariableName &v2name) const
 Whether the variable has a non-zero second derivative with respect to a pair of variables.
 
Derivative< 1 > & d (const VariableBase &var, ArrayRef< Size > dep_dims={})
 Wrapper for assigning a partial derivative.
 
const Derivative< 1 > & d (const VariableBase &var) const
 
Derivative< 2 > & d2 (const VariableBase &var1, const VariableBase &var2, ArrayRef< Size > dep_dims={})
 Wrapper for assigning a second partial derivative.
 
const Derivative< 2 > & d2 (const VariableBase &var1, const VariableBase &var2) const
 
const std::vector< Derivative< 1 > > & derivatives () const
 Partial derivatives.
 
std::vector< Derivative< 1 > > & derivatives ()
 
const std::vector< Derivative< 2 > > & second_derivatives () const
 Partial second derivatives.
 
std::vector< Derivative< 2 > > & second_derivatives ()
 
void clear_derivatives ()
 Clear only the derivatives.
 
void apply_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply first order chain rule.
 
void apply_second_order_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply second order chain rule.
 
const Model & owner () const
 
Model & owner ()
 
bool is_state () const
 
bool is_old_state () const
 
bool is_force () const
 
bool is_old_force () const
 
bool is_residual () const
 
bool is_parameter () const
 
bool is_solve_dependent () const
 
bool is_dependent () const
 Check if the derivative with respect to this variable should be evaluated.
 
Size dim () const
 
Size batch_dim () const
 
Size base_dim () const
 
Size dynamic_dim () const
 
Size static_dim () const
 
Size intmd_dim () const
 
TensorShapeRef sizes () const
 
TraceableTensorShape batch_sizes () const
 
TensorShapeRef base_sizes () const
 
TensorShapeRef static_sizes () const
 
TensorShapeRef intmd_sizes () const
 
Size size (Size i) const
 
TraceableSize batch_size (Size i) const
 
Size base_size (Size i) const
 
const TraceableSize & dynamic_size (Size i) const
 
Size static_size (Size i) const
 
Size intmd_size (Size i) const
 
void request_AD (const VariableBase &u)
 
void request_AD (const std::vector< const VariableBase * > &us)
 
void request_AD (const VariableBase &u1, const VariableBase &u2)
 
void request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
 

Protected Attributes

const Variable< T > * _ref
 The variable referenced by this (nullptr if this is a storing variable)
 
bool _ref_is_mutable = false
 Whether mutating the referenced variable is allowed.
 
T _value
 Variable value (undefined if this is a referencing variable)
 

Additional Inherited Members

- Public Attributes inherited from VariableBase
const VariableName _name = {}
 Name of the variable.
 
Model *const _owner = nullptr
 The model which declared this variable.
 
TensorShape _cached_intmd_sizes = {}
 
const TensorShape _base_sizes = {}
 Base shape of the variable.
 
const TensorShape _dep_intmd_dims = {}
 Dependent intermediate dimensions for derivative calculation.
 

Constructor & Destructor Documentation

◆ Variable()

template<typename T >
Variable ( VariableName name_in,
Model * owner,
TensorShapeRef dep_intmd_dims = {} )
inline

Member Function Documentation

◆ clear()

template<typename T >
void clear ( )
override virtual

Clear the variable value and derivatives.

Reimplemented from VariableBase.

◆ clone()

template<typename T >
std::unique_ptr< VariableBase > clone ( const VariableName & name = {},
Model * owner = nullptr ) const
override virtual

Clone this variable.

Implements VariableBase.

◆ defined()

template<typename T >
bool defined ( ) const
override virtual

Check if the variable value is defined.

Implements VariableBase.

◆ device()

template<typename T >
Device device ( ) const
override virtual

Device.

Implements VariableBase.

◆ dynamic_sizes()

template<typename T >
const TraceableTensorShape & dynamic_sizes ( ) const
override virtual

Implements VariableBase.

◆ get()

template<typename T >
Tensor get ( ) const
override virtual

Get the variable value in assembly format.

Implements VariableBase.

◆ operator()()

template<typename T >
const T & operator() ( ) const
inline

Variable value.

◆ operator-()

template<typename T >
T operator- ( ) const
inline

Negation.

◆ operator=()

template<typename T >
void operator= ( const Tensor & val)
override virtual

Assignment operator.

Implements VariableBase.

◆ options()

template<typename T >
TensorOptions options ( ) const
override virtual

Tensor options.

Implements VariableBase.

◆ owning()

template<typename T >
bool owning ( ) const
inline override virtual

Check if this is an owning variable.

Implements VariableBase.

◆ ref() [1/2]

template<typename T >
const VariableBase * ref ( ) const
inline override virtual

Get the referenced variable (returns this if this is a storing variable)

Implements VariableBase.

◆ ref() [2/2]

template<typename T >
void ref ( const VariableBase & other,
bool ref_is_mutable = false )
override virtual

Reference another variable.

Implements VariableBase.

◆ requires_grad_()

template<typename T >
void requires_grad_ ( bool req = true)
override virtual

Mark this variable as a leaf variable in the tracing function graph for AD.

Implements VariableBase.

◆ scalar_type()

template<typename T >
Dtype scalar_type ( ) const
override virtual

Scalar type.

Implements VariableBase.

◆ set()

template<typename T >
void set ( const Tensor & val,
std::optional< TracerPrivilege > key )
override virtual

Set the variable value from a Tensor in assembly format.

Implements VariableBase.

◆ tensor()

template<typename T >
Tensor tensor ( ) const
override virtual

Get the variable value cast to Tensor.

Implements VariableBase.

◆ type()

template<typename T >
TensorType type ( ) const
override virtual

Variable tensor type.

Implements VariableBase.

◆ zero()

template<typename T >
void zero ( const TensorOptions & options)
override virtual

Set the variable value to zero.

Implements VariableBase.

Member Data Documentation

◆ _ref

template<typename T >
const Variable<T>* _ref
protected

The variable referenced by this (nullptr if this is a storing variable)

◆ _ref_is_mutable

template<typename T >
bool _ref_is_mutable = false
protected

Whether mutating the referenced variable is allowed.

◆ _value

template<typename T >
T _value
protected

Variable value (undefined if this is a referencing variable)