NEML2 2.0.0
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
VariableBase Class Reference (abstract)

Base class of variable. More...

Detailed Description

Base class of variable.

Specific implementations are defined by the derived class Variable<T> where we rely on polymorphism so that we can store different types of variables in the same container.

#include <Variable.h>

Inheritance diagram for VariableBase:

Public Member Functions

 VariableBase ()=default
 
 VariableBase (const VariableBase &)=delete
 
 VariableBase (VariableBase &&)=delete
 
VariableBase & operator= (const VariableBase &)=delete
 
VariableBase & operator= (VariableBase &&)=delete
 
virtual ~VariableBase ()=default
 
 VariableBase (VariableName name_in, Model *owner, TensorShapeRef list_shape)
 
const VariableName & name () const
 Name of this variable.
 
virtual TensorType type () const =0
 Variable tensor type.
 
virtual std::unique_ptr< VariableBase > clone (const VariableName &name={}, Model *owner=nullptr) const =0
 Clone this variable.
 
virtual void ref (const VariableBase &other, bool ref_is_mutable=false)=0
 Reference another variable.
 
virtual const VariableBase * ref () const =0
 Get the referencing variable (returns this if this is a storing variable)
 
virtual bool owning () const =0
 Check if this is an owning variable.
 
virtual void zero (const TensorOptions &options)=0
 Set the variable value to zero.
 
virtual void set (const Tensor &val)=0
 Set the variable value.
 
virtual void set (const ATensor &val, bool force=false)=0
 
virtual Tensor get () const =0
 Get the variable value (with flattened base dimensions, i.e., for assembly purposes)
 
virtual Tensor tensor () const =0
 Get the variable value.
 
bool requires_grad () const
 Check if this variable is part of the AD function graph.
 
virtual void requires_grad_ (bool req=true)=0
 Mark this variable as a leaf variable in tracing function graph for AD.
 
virtual void operator= (const Tensor &val)=0
 Assignment operator.
 
Derivative d (const VariableBase &var)
 Wrapper for assigning partial derivative.
 
Derivative d (const VariableBase &var1, const VariableBase &var2)
 Wrapper for assigning second partial derivative.
 
const ValueMap & derivatives () const
 Partial derivatives.
 
ValueMap & derivatives ()
 
const DerivMap & second_derivatives () const
 Partial second derivatives.
 
DerivMap & second_derivatives ()
 
virtual void clear ()
 Clear the variable value and derivatives.
 
void apply_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply first order chain rule.
 
void apply_second_order_chain_rule (const DependencyResolver< Model, VariableName > &)
 Apply second order chain rule.
 
const Model & owner () const
 
Model & owner ()
 
Subaxis
bool is_state () const
 
bool is_old_state () const
 
bool is_force () const
 
bool is_old_force () const
 
bool is_residual () const
 
bool is_parameter () const
 
bool is_solve_dependent () const
 
bool is_dependent () const
 Check if the derivative with respect to this variable should be evaluated.
 
Tensor information
TensorOptions options () const
 
Dtype scalar_type () const
 Scalar type.
 
Device device () const
 Device.
 
Size dim () const
 Number of tensor dimensions.
 
TensorShapeRef sizes () const
 Tensor shape.
 
Size size (Size dim) const
 Size of a dimension.
 
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
Size list_dim () const
 Return the number of list dimensions.
 
Size base_dim () const
 Return the number of base dimensions.
 
TraceableTensorShape batch_sizes () const
 Return the batch shape.
 
TensorShapeRef list_sizes () const
 Return the list shape.
 
virtual TensorShapeRef base_sizes () const =0
 Return the base shape.
 
TraceableSize batch_size (Size dim) const
 Return the size of a batch axis.
 
Size base_size (Size dim) const
 Return the size of a base axis.
 
Size list_size (Size dim) const
 Return the size of a list axis.
 
Size base_storage () const
 Base storage of the variable.
 
Size assembly_storage () const
 Assembly storage of the variable.
 
void request_AD (const VariableBase &u)
 
void request_AD (const std::vector< const VariableBase * > &us)
 
void request_AD (const VariableBase &u1, const VariableBase &u2)
 
void request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
 

Public Attributes

const VariableName _name = {}
 Name of the variable.
 
Model *const _owner = nullptr
 The model which declared this variable.
 

Constructor & Destructor Documentation

◆ VariableBase() [1/4]

VariableBase ( )
default

◆ VariableBase() [2/4]

VariableBase ( const VariableBase & )
delete

◆ VariableBase() [3/4]

VariableBase ( VariableBase && )
delete

◆ ~VariableBase()

virtual ~VariableBase ( )
virtual, default

◆ VariableBase() [4/4]

VariableBase ( VariableName name_in,
Model * owner,
TensorShapeRef list_shape )

Member Function Documentation

◆ apply_chain_rule()

void apply_chain_rule ( const DependencyResolver< Model, VariableName > & dep)

Apply first order chain rule.

◆ apply_second_order_chain_rule()

void apply_second_order_chain_rule ( const DependencyResolver< Model, VariableName > & dep)

Apply second order chain rule.

◆ assembly_storage()

Size assembly_storage ( ) const

Assembly storage of the variable.

◆ base_dim()

Size base_dim ( ) const

Return the number of base dimensions.

◆ base_size()

Size base_size ( Size dim) const

Return the size of a base axis.

◆ base_sizes()

◆ base_storage()

Size base_storage ( ) const

Base storage of the variable.

◆ batch_dim()

Size batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_size()

TraceableSize batch_size ( Size dim) const

Return the size of a batch axis.

◆ batch_sizes()

TraceableTensorShape batch_sizes ( ) const

Return the batch shape.

◆ batched()

bool batched ( ) const

Whether the tensor is batched.

◆ clear()

◆ clone()

virtual std::unique_ptr< VariableBase > clone ( const VariableName & name = {},
Model * owner = nullptr ) const
pure virtual

◆ d() [1/2]

Derivative d ( const VariableBase & var)

Wrapper for assigning partial derivative.

◆ d() [2/2]

Derivative d ( const VariableBase & var1,
const VariableBase & var2 )

Wrapper for assigning second partial derivative.

◆ derivatives() [1/2]

ValueMap & derivatives ( )
inline

◆ derivatives() [2/2]

const ValueMap & derivatives ( ) const
inline

Partial derivatives.

◆ device()

Device device ( ) const

◆ dim()

Size dim ( ) const

Number of tensor dimensions.

◆ get()

virtual Tensor get ( ) const
pure virtual

Get the variable value (with flattened base dimensions, i.e., for assembly purposes)

Implemented in Variable< T >, Variable< neml2::R2 >, Variable< neml2::Rot >, Variable< neml2::Scalar >, Variable< neml2::SR2 >, Variable< neml2::SSR4 >, Variable< neml2::Vec >, and Variable< neml2::WR2 >.

◆ is_dependent()

bool is_dependent ( ) const

Check if the derivative with respect to this variable should be evaluated.

◆ is_force()

bool is_force ( ) const

◆ is_old_force()

bool is_old_force ( ) const

◆ is_old_state()

bool is_old_state ( ) const

◆ is_parameter()

bool is_parameter ( ) const

◆ is_residual()

bool is_residual ( ) const

◆ is_solve_dependent()

bool is_solve_dependent ( ) const

◆ is_state()

bool is_state ( ) const

◆ list_dim()

Size list_dim ( ) const

Return the number of list dimensions.

◆ list_size()

Size list_size ( Size dim) const

Return the size of a list axis.

◆ list_sizes()

TensorShapeRef list_sizes ( ) const

Return the list shape.

◆ name()

const VariableName & name ( ) const
inline

Name of this variable.

◆ operator=() [1/3]

◆ operator=() [2/3]

VariableBase & operator= ( const VariableBase & )
delete

◆ operator=() [3/3]

VariableBase & operator= ( VariableBase && )
delete

◆ options()

TensorOptions options ( ) const

Tensor options

◆ owner() [1/2]

Model & owner ( )

◆ owner() [2/2]

const Model & owner ( ) const

The Model who declared this variable

◆ owning()

virtual bool owning ( ) const
pure virtual

◆ ref() [1/2]

virtual const VariableBase * ref ( ) const
pure virtual

◆ ref() [2/2]

virtual void ref ( const VariableBase & other,
bool ref_is_mutable = false )
pure virtual

◆ request_AD() [1/4]

void request_AD ( const std::vector< const VariableBase * > & u1s,
const std::vector< const VariableBase * > & u2s )

◆ request_AD() [2/4]

void request_AD ( const std::vector< const VariableBase * > & us)

◆ request_AD() [3/4]

void request_AD ( const VariableBase & u)

Request to use AD to calculate the derivative of this variable with respect to another variable.

◆ request_AD() [4/4]

void request_AD ( const VariableBase & u1,
const VariableBase & u2 )

Request to use AD to calculate the second derivative of this variable with respect to two other variables.

◆ requires_grad()

bool requires_grad ( ) const

Check if this variable is part of the AD function graph.

◆ requires_grad_()

virtual void requires_grad_ ( bool req = true)
pure virtual

◆ scalar_type()

Dtype scalar_type ( ) const

Scalar type.

◆ second_derivatives() [1/2]

DerivMap & second_derivatives ( )
inline

◆ second_derivatives() [2/2]

const DerivMap & second_derivatives ( ) const
inline

Partial second derivatives.

◆ set() [1/2]

virtual void set ( const ATensor & val,
bool force = false )
pure virtual

Set the variable value from an ATensor (with inferred batch shape). If force is true, the value is set even if the variable is a reference.

Implemented in Variable< T >, Variable< neml2::R2 >, Variable< neml2::Rot >, Variable< neml2::Scalar >, Variable< neml2::SR2 >, Variable< neml2::SSR4 >, Variable< neml2::Vec >, and Variable< neml2::WR2 >.

◆ set() [2/2]

◆ size()

Size size ( Size dim) const

Size of a dimension.

◆ sizes()

TensorShapeRef sizes ( ) const

Tensor shape.

◆ tensor()

◆ type()

◆ zero()

Member Data Documentation

◆ _name

const VariableName _name = {}

Name of the variable.

◆ _owner

Model* const _owner = nullptr

The model which declared this variable.