NEML2 2.1.0
Loading...
Searching...
No Matches

Base class of variables.

Detailed Description

Base class of variables.

Concrete implementations are provided by the derived class template Variable<T>. VariableBase relies on polymorphism so that variables of different tensor types can be stored in the same container.

#include <VariableBase.h>

Inheritance diagram for VariableBase:

Public Types

using DerivTuple = std::tuple<Derivative<1>, const VariableBase *>
using DerivContainer = std::vector<DerivTuple>
using SecDerivTuple = std::tuple<Derivative<2>, const VariableBase *, const VariableBase *>
using SecDerivContainer = std::vector<SecDerivTuple>

Public Member Functions

 VariableBase ()=default
 VariableBase (const VariableBase &)=delete
 VariableBase (VariableBase &&)=delete
VariableBase & operator= (const VariableBase &)=delete
VariableBase & operator= (VariableBase &&)=delete
virtual ~VariableBase ()
 VariableBase (VariableName name_in, Model *owner, TensorShapeRef base_shape)
 The canonical constructor.
const VariableName & name () const
 Name of this variable.
virtual TensorType type () const =0
 Variable tensor type.
virtual std::unique_ptr< VariableBase > clone (const VariableName &name={}, Model *owner=nullptr) const =0
 Clone this variable.
virtual void ref (VariableBase &other)=0
 Reference another variable.
virtual const VariableBase * ref () const =0
 Get the referencing variable (returns this if this is a storing variable).
virtual VariableBase * ref ()=0
virtual const VariableBase * direct_ref () const =0
 Get the direct referencing variable (returns nullptr if this is a storing variable).
virtual VariableBase * direct_ref ()=0
virtual bool owning () const =0
 Check if this is an owning variable.
bool is_mutable () const
 Whether this variable is mutable when it is referenced by another variable.
void set_mutable (bool m)
 Allow/disable mutation of this variable when it is referenced by another variable.
Tensor zeros (const TensorOptions &options) const
 Make zeros tensor with the shape of this variable.
virtual void zero (const TensorOptions &options)=0
 Set the variable value to zero.
virtual Tensor tensor () const =0
 Get the variable value cast to Tensor.
bool requires_grad () const
 Check if this variable is part of the AD function graph.
virtual void requires_grad_ (bool req=true)=0
 Mark this variable as a leaf variable in tracing function graph for AD.
virtual void assign (const Tensor &val, std::optional< TracerPrivilege > key=std::nullopt)=0
 Assignment operator (with TracerPrivilege).
virtual void operator= (const Tensor &val)=0
 Assignment operator.
bool has_derivative (const VariableName &vname) const
 Whether the variable has non-zero derivative with respect to another variable.
bool has_derivative (const VariableName &v1name, const VariableName &v2name) const
 Whether the variable has non-zero second derivative with respect to two other variables.
Derivative< 1 > & d (const VariableBase &arg, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg_intrsc_intmd_dim=0)
 Wrapper for assigning partial derivative.
const Derivative< 1 > & d (const VariableBase &arg) const
Derivative< 2 > & d2 (const VariableBase &arg1, const VariableBase &arg2, std::size_t deriv_intrsc_intmd_dim=0, std::size_t var_intrsc_intmd_dim=0, std::size_t arg1_intrsc_intmd_dim=0, std::size_t arg2_intrsc_intmd_dim=0)
 Wrapper for assigning second partial derivative.
const Derivative< 2 > & d2 (const VariableBase &arg1, const VariableBase &arg2) const
const DerivContainer & derivatives () const
 Partial derivatives.
DerivContainer & derivatives ()
const SecDerivContainer & second_derivatives () const
 Partial second derivatives.
SecDerivContainer & second_derivatives ()
virtual void clear ()
 Clear the variable value and derivatives.
void clear_derivatives ()
 Clear only the derivatives.
const Model & owner () const
Model & owner ()
Subaxis
bool is_state () const
bool is_old_state () const
bool is_force () const
bool is_old_force () const
bool is_residual () const
bool is_parameter () const
bool is_solve_dependent () const
bool is_dependent () const
 Check if the derivative with respect to this variable should be evaluated.
Tensor information
virtual bool defined () const =0
virtual TensorOptions options () const =0
 Tensor options.
virtual Dtype scalar_type () const =0
 Scalar type.
virtual Device device () const =0
 Device.
Size dim () const
Size batch_dim () const
Size base_dim () const
Size dynamic_dim () const
Size static_dim () const
Size intmd_dim () const
TensorShapeRef sizes () const
TraceableTensorShape batch_sizes () const
TensorShapeRef base_sizes () const
virtual const TraceableTensorShape & dynamic_sizes () const =0
TensorShapeRef static_sizes () const
TensorShapeRef intmd_sizes () const
Size size (Size i) const
TraceableSize batch_size (Size i) const
Size base_size (Size i) const
const TraceableSize & dynamic_size (Size i) const
Size static_size (Size i) const
Size intmd_size (Size i) const
void request_AD (const VariableBase &u)
void request_AD (const std::vector< const VariableBase * > &us)
void request_AD (const VariableBase &u1, const VariableBase &u2)
void request_AD (const std::vector< const VariableBase * > &u1s, const std::vector< const VariableBase * > &u2s)
bool is_leaf (const DependencyResolver< Model, VariableName > &) const
const VariableBase & provider (const DependencyResolver< Model, VariableName > &) const
 Get the provider in the dependency graph.
const DerivContainer & total_derivatives (const DependencyResolver< Model, VariableName > &) const
 Get total derivatives with respect to leaf variables.
const SecDerivContainer & total_second_derivatives (const DependencyResolver< Model, VariableName > &) const
 Get total second derivatives with respect to leaf variables.
void clear_chain_rule_cache (const DependencyResolver< Model, VariableName > &) const
 Clear chain rule cache.

Protected Attributes

const VariableName _name = {}
 Name of the variable.
Model *const _owner = nullptr
 The model which declared this variable.
TensorShape _cached_intmd_sizes = {}
 Cached intermediate shape that this variable last saw.
const TensorShape _base_sizes = {}
 Base shape of the variable.
bool _mutable = false
 When referenced by another variable, whether to allow the referencing variable to mutate my value.

Member Typedef Documentation

◆ DerivContainer

using DerivContainer = std::vector<DerivTuple>

◆ DerivTuple

using DerivTuple = std::tuple<Derivative<1>, const VariableBase *>

◆ SecDerivContainer

using SecDerivContainer = std::vector<SecDerivTuple>

◆ SecDerivTuple

using SecDerivTuple = std::tuple<Derivative<2>, const VariableBase *, const VariableBase *>

Constructor & Destructor Documentation

◆ VariableBase() [1/4]

VariableBase ( )
default

◆ VariableBase() [2/4]

VariableBase ( const VariableBase & )
delete

◆ VariableBase() [3/4]

VariableBase ( VariableBase && )
delete

◆ ~VariableBase()

virtual ~VariableBase ( )
virtual

◆ VariableBase() [4/4]

VariableBase ( VariableName name_in,
Model * owner,
TensorShapeRef base_shape )

The canonical constructor.

Parameters
name_in	Variable name
owner	Model that declared this variable
base_shape	Base shape of this variable

Member Function Documentation

◆ assign()

virtual void assign ( const Tensor & val,
std::optional< TracerPrivilege > key = std::nullopt )
pure virtual

◆ base_dim()

Size base_dim ( ) const

◆ base_size()

Size base_size ( Size i) const

◆ base_sizes()

TensorShapeRef base_sizes ( ) const

◆ batch_dim()

Size batch_dim ( ) const

◆ batch_size()

TraceableSize batch_size ( Size i) const

◆ batch_sizes()

TraceableTensorShape batch_sizes ( ) const

◆ clear()

virtual void clear ( )
virtual

◆ clear_chain_rule_cache()

void clear_chain_rule_cache ( const DependencyResolver< Model, VariableName > & ) const

Clear chain rule cache.

◆ clear_derivatives()

void clear_derivatives ( )

Clear only the derivatives.

◆ clone()

virtual std::unique_ptr< VariableBase > clone ( const VariableName & name = {},
Model * owner = nullptr ) const
pure virtual

◆ d() [1/2]

const Derivative< 1 > & d ( const VariableBase & arg) const

◆ d() [2/2]

Derivative< 1 > & d ( const VariableBase & arg,
std::size_t deriv_intrsc_intmd_dim = 0,
std::size_t var_intrsc_intmd_dim = 0,
std::size_t arg_intrsc_intmd_dim = 0 )

Wrapper for assigning partial derivative.

◆ d2() [1/2]

const Derivative< 2 > & d2 ( const VariableBase & arg1,
const VariableBase & arg2 ) const

◆ d2() [2/2]

Derivative< 2 > & d2 ( const VariableBase & arg1,
const VariableBase & arg2,
std::size_t deriv_intrsc_intmd_dim = 0,
std::size_t var_intrsc_intmd_dim = 0,
std::size_t arg1_intrsc_intmd_dim = 0,
std::size_t arg2_intrsc_intmd_dim = 0 )

Wrapper for assigning second partial derivative.

◆ defined()

◆ derivatives() [1/2]

DerivContainer & derivatives ( )
inline

◆ derivatives() [2/2]

const DerivContainer & derivatives ( ) const
inline

Partial derivatives.

◆ device()

◆ dim()

Size dim ( ) const
Returns
the number of dimensions

◆ direct_ref() [1/2]

virtual const VariableBase * direct_ref ( ) const
pure virtual

Get the direct referencing variable (returns nullptr if this is a storing variable).

Implemented in Variable< T >, Variable< neml2::R2 >, Variable< neml2::Rot >, Variable< neml2::Scalar >, Variable< neml2::SR2 >, Variable< neml2::SSR4 >, and Variable< neml2::WR2 >.

◆ direct_ref() [2/2]

virtual VariableBase * direct_ref ( )
pure virtual

Implemented in Variable< T >.

◆ dynamic_dim()

Size dynamic_dim ( ) const

◆ dynamic_size()

const TraceableSize & dynamic_size ( Size i) const

◆ dynamic_sizes()

◆ has_derivative() [1/2]

bool has_derivative ( const VariableName & v1name,
const VariableName & v2name ) const

Whether the variable has non-zero second derivative with respect to two other variables.

◆ has_derivative() [2/2]

bool has_derivative ( const VariableName & vname) const

Whether the variable has non-zero derivative with respect to another variable.

◆ intmd_dim()

Size intmd_dim ( ) const

◆ intmd_size()

Size intmd_size ( Size i) const

◆ intmd_sizes()

TensorShapeRef intmd_sizes ( ) const

◆ is_dependent()

bool is_dependent ( ) const

Check if the derivative with respect to this variable should be evaluated.

◆ is_force()

bool is_force ( ) const

◆ is_leaf()

bool is_leaf ( const DependencyResolver< Model, VariableName > & ) const

Whether this variable is a leaf variable in the dependency graph.

◆ is_mutable()

bool is_mutable ( ) const

Whether this variable is mutable when it is referenced by another variable.

◆ is_old_force()

bool is_old_force ( ) const

◆ is_old_state()

bool is_old_state ( ) const

◆ is_parameter()

bool is_parameter ( ) const

◆ is_residual()

bool is_residual ( ) const

◆ is_solve_dependent()

bool is_solve_dependent ( ) const

◆ is_state()

bool is_state ( ) const

◆ name()

const VariableName & name ( ) const
inline

Name of this variable.

◆ operator=() [1/3]

virtual void operator= ( const Tensor & val)
pure virtual

◆ operator=() [2/3]

VariableBase & operator= ( const VariableBase & )
delete

◆ operator=() [3/3]

VariableBase & operator= ( VariableBase && )
delete

◆ options()

◆ owner() [1/2]

Model & owner ( )

◆ owner() [2/2]

const Model & owner ( ) const

The Model that declared this variable.

◆ owning()

virtual bool owning ( ) const
pure virtual

◆ provider()

const VariableBase & provider ( const DependencyResolver< Model, VariableName > & ) const

Get the provider in the dependency graph.

◆ ref() [1/3]

virtual const VariableBase * ref ( ) const
pure virtual

Get the referencing variable (returns this if this is a storing variable).

Implemented in Variable< T >.

◆ ref() [2/3]

virtual VariableBase * ref ( )
pure virtual

Implemented in Variable< T >.

◆ ref() [3/3]

◆ request_AD() [1/4]

void request_AD ( const std::vector< const VariableBase * > & u1s,
const std::vector< const VariableBase * > & u2s )

◆ request_AD() [2/4]

void request_AD ( const std::vector< const VariableBase * > & us)

◆ request_AD() [3/4]

void request_AD ( const VariableBase & u)

Request to use AD to calculate the derivative of this variable with respect to another variable.

◆ request_AD() [4/4]

void request_AD ( const VariableBase & u1,
const VariableBase & u2 )

Request to use AD to calculate the second derivative of this variable with respect to two other variables.

◆ requires_grad()

bool requires_grad ( ) const

Check if this variable is part of the AD function graph.

◆ requires_grad_()

virtual void requires_grad_ ( bool req = true)
pure virtual

Mark this variable as a leaf variable in tracing function graph for AD.

Implemented in Variable< T >, Variable< neml2::R2 >, Variable< neml2::Rot >, Variable< neml2::Scalar >, Variable< neml2::SR2 >, Variable< neml2::SSR4 >, and Variable< neml2::WR2 >.

◆ scalar_type()

◆ second_derivatives() [1/2]

SecDerivContainer & second_derivatives ( )
inline

◆ second_derivatives() [2/2]

const SecDerivContainer & second_derivatives ( ) const
inline

Partial second derivatives.

◆ set_mutable()

void set_mutable ( bool m)

Allow/disable mutation of this variable when it is referenced by another variable.

◆ size()

Size size ( Size i) const
Returns
the size of dimension i

◆ sizes()

TensorShapeRef sizes ( ) const
Returns
the tensor shape

◆ static_dim()

Size static_dim ( ) const

◆ static_size()

Size static_size ( Size i) const

◆ static_sizes()

TensorShapeRef static_sizes ( ) const

◆ tensor()

◆ total_derivatives()

const DerivContainer & total_derivatives ( const DependencyResolver< Model, VariableName > & ) const

Get total derivatives with respect to leaf variables.

◆ total_second_derivatives()

const SecDerivContainer & total_second_derivatives ( const DependencyResolver< Model, VariableName > & ) const

Get total second derivatives with respect to leaf variables.

◆ type()

◆ zero()

virtual void zero ( const TensorOptions & options)
pure virtual

◆ zeros()

Tensor zeros ( const TensorOptions & options) const

Make zeros tensor with the shape of this variable.

Member Data Documentation

◆ _base_sizes

const TensorShape _base_sizes = {}
protected

Base shape of the variable.

◆ _cached_intmd_sizes

TensorShape _cached_intmd_sizes = {}
protected

Cached intermediate shape that this variable last saw.

Note: set() and operator=() are the only methods that cache this. clear() does not invalidate the cache.

◆ _mutable

bool _mutable = false
protected

When referenced by another variable, whether to allow the referencing variable to mutate my value.

◆ _name

const VariableName _name = {}
protected

Name of the variable.

◆ _owner

Model* const _owner = nullptr
protected

The model which declared this variable.