NEML2 2.0.0
Loading...
Searching...
No Matches
Variable< T > Class Template Reference

Concrete definition of a variable. More...

Detailed Description

template<typename T>
class neml2::Variable< T >

Concrete definition of a variable.

#include <Variable.h>

Inheritance diagram for Variable< T >:

Public Member Functions

template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
 Variable (VariableName name_in, Model *owner, TensorShapeRef list_shape)
 
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
 Variable (VariableName name_in, Model *owner, TensorShapeRef list_shape, TensorShapeRef base_shape)
 
TensorType type () const override
 Variable tensor type.
 
TensorShapeRef base_sizes () const override
 Return the base shape.
 
std::unique_ptr< VariableBase > clone (const VariableName &name={}, Model *owner=nullptr) const override
 Clone this variable.
 
void ref (const VariableBase &var, bool ref_is_mutable=false) override
 Reference another variable.
 
const VariableBase * ref () const override
 Get the referenced variable (returns this if this is a storing variable).
 
bool owning () const override
 Check if this is an owning variable.
 
void zero (const torch::TensorOptions &options) override
 Set the variable value to zero.
 
void set (const Tensor &val) override
 Set the variable value.
 
Tensor get () const override
 Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
 
Tensor tensor () const override
 Get the variable value.
 
void requires_grad_ (bool req=true) override
 Mark this variable as a leaf variable in tracing function graph for AD.
 
void operator= (const Tensor &val) override
 Assignment operator.
 
const T & value () const
 Variable value.
 
T operator- () const
 Negation.
 
 operator T () const
 Convert to the underlying tensor type.
 
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, Tensor>>>
 operator Tensor () const
 Convert to Tensor.
 
void clear () override
 Clear the variable value and derivatives.
 
- Public Member Functions inherited from VariableBase
 VariableBase ()=default
 
 VariableBase (const VariableBase &)=delete
 
 VariableBase (VariableBase &&)=delete
 
VariableBase & operator= (const VariableBase &)=delete
 
VariableBase & operator= (VariableBase &&)=delete
 
virtual ~VariableBase ()=default
 
 VariableBase (VariableName name_in, Model *owner, TensorShapeRef list_shape)
 
const VariableName & name () const
 Name of this variable.
 
bool requires_grad () const
 Check if this variable is part of the AD function graph.
 
Derivative d (const VariableBase &var)
 Wrapper for assigning partial derivative.
 
Derivative d (const VariableBase &var1, const VariableBase &var2)
 Wrapper for assigning second partial derivative.
 
const ValueMap & derivatives () const
 Partial derivatives.
 
ValueMap & derivatives ()
 
const DerivMap & second_derivatives () const
 Partial second derivatives.
 
DerivMap & second_derivatives ()
 
void apply_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply first order chain rule.
 
void apply_second_order_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply second order chain rule.
 
const Model & owner () const
 
Model & owner ()
 
bool is_state () const
 
bool is_old_state () const
 
bool is_force () const
 
bool is_old_force () const
 
bool is_residual () const
 
bool is_parameter () const
 
bool is_solve_dependent () const
 
bool is_dependent () const
 Check if the derivative with respect to this variable should be evaluated.
 
torch::TensorOptions options () const
 
torch::Dtype scalar_type () const
 Scalar type.
 
torch::Device device () const
 Device.
 
Size dim () const
 Number of tensor dimensions.
 
TensorShapeRef sizes () const
 Tensor shape.
 
Size size (Size dim) const
 Size of a dimension.
 
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
Size list_dim () const
 Return the number of list dimensions.
 
Size base_dim () const
 Return the number of base dimensions.
 
TraceableTensorShape batch_sizes () const
 Return the batch shape.
 
TensorShapeRef list_sizes () const
 Return the list shape.
 
TraceableSize batch_size (Size dim) const
 Return the size of a batch axis.
 
Size base_size (Size dim) const
 Return the size of a base axis.
 
Size list_size (Size dim) const
 Return the size of a list axis.
 
Size base_storage () const
 Base storage of the variable.
 
Size assembly_storage () const
 Assembly storage of the variable.
 
void request_AD (const VariableBase &u)
 
void request_AD (const std::vector< const VariableBase * > &us)
 
void request_AD (const VariableBase &u1, const VariableBase &u2)
 
void request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
 

Protected Attributes

const TensorShape _base_sizes
 Base shape of the variable.
 
const Variable< T > * _ref
 The variable referenced by this (nullptr if this is a storing variable)
 
bool _ref_is_mutable
 Whether mutating the referenced variable is allowed.
 
T _value
 Variable value (undefined if this is a referencing variable)
 

Additional Inherited Members

- Public Attributes inherited from VariableBase
const VariableName _name = {}
 Name of the variable.
 
Model *const _owner = nullptr
 The model which declared this variable.
 

Constructor & Destructor Documentation

◆ Variable() [1/2]

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
Variable ( VariableName name_in,
Model * owner,
TensorShapeRef list_shape )
inline

◆ Variable() [2/2]

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
Variable ( VariableName name_in,
Model * owner,
TensorShapeRef list_shape,
TensorShapeRef base_shape )
inline

Member Function Documentation

◆ base_sizes()

template<typename T >
TensorShapeRef base_sizes ( ) const
inline override virtual

Return the base shape.

Implements VariableBase.

◆ clear()

template<typename T >
void clear ( )
override virtual

Clear the variable value and derivatives.

Reimplemented from VariableBase.

◆ clone()

template<typename T >
std::unique_ptr< VariableBase > clone ( const VariableName & name = {},
Model * owner = nullptr ) const
override virtual

Clone this variable.

Implements VariableBase.

◆ get()

template<typename T >
Tensor get ( ) const
inline override virtual

Get the variable value (with flattened base dimensions, i.e., for assembly purposes)

Implements VariableBase.

◆ operator T()

template<typename T >
operator T ( ) const
inline

Convert to the underlying tensor type.

◆ operator Tensor()

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, Tensor>>>
operator Tensor ( ) const
inline

Convert to Tensor.

◆ operator-()

template<typename T >
T operator- ( ) const
inline

Negation.

◆ operator=()

template<typename T >
void operator= ( const Tensor & val)
override virtual

Assignment operator.

Implements VariableBase.

◆ owning()

template<typename T >
bool owning ( ) const
inline override virtual

Check if this is an owning variable.

Implements VariableBase.

◆ ref() [1/2]

template<typename T >
const VariableBase * ref ( ) const
inline override virtual

Get the referenced variable (returns this if this is a storing variable).

Implements VariableBase.

◆ ref() [2/2]

template<typename T >
void ref ( const VariableBase & other,
bool ref_is_mutable = false )
override virtual

Reference another variable.

Implements VariableBase.

◆ requires_grad_()

template<typename T >
void requires_grad_ ( bool req = true)
override virtual

Mark this variable as a leaf variable in tracing function graph for AD.

Implements VariableBase.

◆ set()

template<typename T >
void set ( const Tensor & val)
override virtual

Set the variable value.

Implements VariableBase.

◆ tensor()

template<typename T >
Tensor tensor ( ) const
override virtual

Get the variable value.

Implements VariableBase.

◆ type()

template<typename T >
TensorType type ( ) const
override virtual

Variable tensor type.

Implements VariableBase.

◆ value()

template<typename T >
const T & value ( ) const
inline

Variable value.

◆ zero()

template<typename T >
void zero ( const torch::TensorOptions & options)
override virtual

Set the variable value to zero.

Implements VariableBase.

Member Data Documentation

◆ _base_sizes

template<typename T >
const TensorShape _base_sizes
protected

Base shape of the variable.

◆ _ref

template<typename T >
const Variable<T>* _ref
protected

The variable referenced by this (nullptr if this is a storing variable)

◆ _ref_is_mutable

template<typename T >
bool _ref_is_mutable
protected

Whether mutating the referenced variable is allowed.

◆ _value

template<typename T >
T _value
protected

Variable value (undefined if this is a referencing variable)