NEML2 2.0.0
ParameterStore Class Reference

Detailed Description

Interface for objects that can store parameters.

#include <ParameterStore.h>


Public Member Functions

ParameterStore (OptionSet options, NEML2Object *object)

ParameterStore (const ParameterStore &)=delete

ParameterStore (ParameterStore &&)=delete

ParameterStore & operator= (const ParameterStore &)=delete

ParameterStore & operator= (ParameterStore &&)=delete

virtual ~ParameterStore ()=default

const Storage< std::string, TensorValueBase > & named_parameters () const

Storage< std::string, TensorValueBase > & named_parameters ()

void set_parameter (const std::string &, const Tensor &)
 Set the value for a parameter.

void set_parameters (const std::map< std::string, Tensor > &)
 Set values for parameters.

TensorValueBase & get_parameter (const std::string &name)
 Get a writable reference of a parameter.

const TensorValueBase & get_parameter (const std::string &name) const
 Get a read-only reference of a parameter.

bool has_nl_param () const
 Whether this parameter store has any nonlinear parameter.

const VariableBase * nl_param (const std::string &) const
 Query the existence of a nonlinear parameter.

virtual std::map< std::string, const VariableBase * > named_nonlinear_parameters (bool recursive=false) const
 Get all nonlinear parameters.

virtual std::map< std::string, Model * > named_nonlinear_parameter_models (bool recursive=false) const
 Get all nonlinear parameters' models.

Protected Member Functions

virtual void send_parameters_to (const torch::TensorOptions &options)
 Send parameters to the given tensor options (for example, a different device or dtype).

template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const T &rawval)
 Declare a parameter from a raw value.

template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const CrossRef< T > &crossref, bool allow_nonlinear)
 Declare a parameter from a cross-referenced value, which may resolve to a nonlinear parameter if allow_nonlinear is true.

template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const std::string &input_option_name, bool allow_nonlinear=false)
 Declare a parameter from an input option, which may resolve to a nonlinear parameter if allow_nonlinear is true.

Protected Attributes

std::map< std::string, const VariableBase * > _nl_params
 Map from nonlinear parameter names to their corresponding variable views.

std::map< std::string, Model * > _nl_param_models
 Map from nonlinear parameter names to models which evaluate them.

Constructor & Destructor Documentation

◆ ParameterStore() [1/3]

ParameterStore ( OptionSet options,
NEML2Object * object )

◆ ParameterStore() [2/3]

ParameterStore ( const ParameterStore & )
delete

◆ ParameterStore() [3/3]

ParameterStore ( ParameterStore && )
delete

◆ ~ParameterStore()

virtual ~ParameterStore ( )
virtual, default
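All copy and move operations of ParameterStore are deleted. The sketch below uses a hypothetical stand-in class `sketch::Store`, not the NEML2 class. It shows how such deletions can be checked at compile time with `<type_traits>`.

One plausible motivation is that the store holds raw pointers that a copy would silently alias. Examples are the `NEML2Object *` passed to the constructor, which it presumably keeps, and the pointers in `_nl_params` and `_nl_param_models`. This motivation is our assumption; the page does not state it.

```cpp
#include <cassert>
#include <type_traits>

namespace sketch
{
struct Object; // hypothetical stand-in for NEML2Object

// Hypothetical stand-in mirroring the deleted special members listed above.
class Store
{
public:
  explicit Store(Object * host) : _host(host) {}
  Store(const Store &) = delete;
  Store(Store &&) = delete;
  Store & operator=(const Store &) = delete;
  Store & operator=(Store &&) = delete;
  virtual ~Store() = default;

  Object * host() const { return _host; }

private:
  Object * _host; // back-pointer that a copy would share with the original
};

// Compile-time checks: none of the copy/move operations are available.
static_assert(!std::is_copy_constructible_v<Store>);
static_assert(!std::is_move_constructible_v<Store>);
static_assert(!std::is_copy_assignable_v<Store>);
static_assert(!std::is_move_assignable_v<Store>);
} // namespace sketch
```

Because the checks are `static_assert`s, accidentally re-enabling copying would fail the build rather than a test run.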

Member Function Documentation

◆ declare_parameter() [1/3]

template<typename T , typename >
const T & declare_parameter ( const std::string & name,
const CrossRef< T > & crossref,
bool allow_nonlinear )
protected

Declare a parameter.

Similar to the raw-value overload declare_parameter(name, rawval), but additionally handles the resolution of cross-referenced parameters. Two attempts are made in order:

1. The method tries to resolve CrossRef<T> directly into T.
2. If that fails, it tries to resolve CrossRef<T> into a nonlinear parameter. The raw string stored in the cross-ref is treated as the name of the model that defines the nonlinear parameter.

Template Parameters
T: Parameter type. See Statically shaped tensors for supported types.
Parameters
name: Name of the model parameter.
crossref: The cross-referenced string that defines the value of the model parameter.
allow_nonlinear: Whether to allow coupling with a nonlinear parameter.
Returns
The value of the registered model parameter (a const reference to T).
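The two-step resolution can be illustrated with a simplified, hypothetical sketch. It uses assumptions throughout:

- A `double` stands in for T.
- `std::stod` stands in for the direct CrossRef<T> to T conversion.
- A set of names stands in for the models that could define a nonlinear parameter.

The names `sketch::resolve_crossref` and `sketch::Resolved` are illustrative and are not NEML2 API. Gating the second attempt on `allow_nonlinear` is also our reading of the parameter description.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>

namespace sketch
{
// Hypothetical outcome of resolving a cross-ref string.
struct Resolved
{
  std::optional<double> value; // set if attempt 1 (direct value) succeeded
  std::string model_name;      // set if attempt 2 (nonlinear parameter) succeeded
};

// Attempt 1: parse the whole string as a plain value.
// Attempt 2: treat the string as the name of a model that defines a nonlinear
// parameter, but only when nonlinear coupling is allowed.
inline Resolved resolve_crossref(const std::string & raw,
                                 const std::set<std::string> & model_names,
                                 bool allow_nonlinear)
{
  try
  {
    std::size_t pos = 0;
    const double v = std::stod(raw, &pos);
    if (pos == raw.size())
      return {v, ""};
  }
  catch (const std::logic_error &)
  {
    // std::stod throws invalid_argument or out_of_range; fall through to attempt 2
  }

  if (allow_nonlinear && model_names.count(raw) > 0)
    return {std::nullopt, raw};

  throw std::runtime_error("cannot resolve cross-ref '" + raw + "'");
}
} // namespace sketch
```

For example, "2.5" resolves directly to a value. "E_model" resolves to a nonlinear parameter only when `allow_nonlinear` is true and such a model exists; otherwise the sketch throws.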

◆ declare_parameter() [2/3]

template<typename T , typename >
const T & declare_parameter ( const std::string & name,
const std::string & input_option_name,
bool allow_nonlinear = false )
protected

Declare a parameter.

Similar to the other two overloads, but this one handles the high-level logic of constructing a (possibly nonlinear) parameter directly from the input option named input_option_name.

Template Parameters
T: Parameter type. See Statically shaped tensors for supported types.
Parameters
name: Name of the model parameter.
input_option_name: Name of the input option that defines the value of the model parameter.
allow_nonlinear: Whether to allow coupling with a nonlinear parameter (defaults to false).
Returns
The value of the registered model parameter (a const reference to T).
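A hypothetical sketch of the option-driven overload follows. The flow is: look up the input option by name, resolve its raw string, and register the result under the parameter name, which need not match the option name.

It rests on several assumptions:

- `options` is a plain string map standing in for OptionSet.
- `double` stands in for T.
- `declare_from_option` is an illustrative name, not NEML2 API.
- With `allow_nonlinear` left at its default of false, a non-numeric option is rejected.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

namespace sketch
{
// Hypothetical store: `options` stands in for OptionSet, each option kept as a raw string.
struct Store
{
  std::map<std::string, std::string> options;
  std::map<std::string, double> params;         // parameters resolved to plain values
  std::map<std::string, std::string> nl_params; // parameter name -> defining model name

  // Look up the option by name, then resolve its raw string into a (possibly
  // nonlinear) parameter registered under `name`.
  void declare_from_option(const std::string & name,
                           const std::string & option_name,
                           bool allow_nonlinear = false)
  {
    const std::string & raw = options.at(option_name); // throws std::out_of_range if missing
    try
    {
      std::size_t pos = 0;
      const double v = std::stod(raw, &pos);
      if (pos == raw.size())
      {
        params.emplace(name, v);
        return;
      }
    }
    catch (const std::logic_error &)
    {
      // not a plain number; consider the nonlinear route below
    }
    if (!allow_nonlinear)
      throw std::runtime_error("option '" + option_name + "' is not a plain value");
    nl_params.emplace(name, raw);
  }
};
} // namespace sketch
```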

◆ declare_parameter() [3/3]

template<typename T , typename >
const T & declare_parameter ( const std::string & name,
const T & rawval )
protected

Declare a parameter.

Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times in the host, and the same parameter may therefore be declared multiple times. This is allowed, but only the first call to declare_parameter constructs the parameter value; subsequent calls only return a reference to the existing parameter.

Template Parameters
T: Parameter type. See Statically shaped tensors for supported types.
Parameters
name: Name of the parameter.
rawval: Value of the parameter.
Returns
Reference to the parameter.
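The "first declaration wins" rule maps naturally onto `std::map::try_emplace`. That call constructs a value only when the key is absent and otherwise returns an iterator to the existing entry.

The sketch below is hypothetical: `sketch::Host` stands in for the host object, and a `double` stands in for a tensor parameter.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

namespace sketch
{
// Hypothetical host-level storage: all parameters live in one map owned by the host.
class Host
{
public:
  // Only the first declaration constructs the value; later calls return the existing entry.
  const double & declare_parameter(const std::string & name, double rawval)
  {
    auto [it, inserted] = _params.try_emplace(name, rawval);
    (void)inserted; // true only on the first declaration of `name`
    return it->second;
  }

  std::size_t size() const { return _params.size(); }

private:
  std::map<std::string, double> _params; // node-based, so references stay valid
};
} // namespace sketch
```

Because `std::map` never relocates its elements, the returned references remain valid as more parameters are declared.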

◆ get_parameter() [1/2]

TensorValueBase & get_parameter ( const std::string & name)

Get a writable reference of a parameter.

◆ get_parameter() [2/2]

const TensorValueBase & get_parameter ( const std::string & name) const

Get a read-only reference of a parameter.
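The two get_parameter overloads follow the usual const-overloading pattern. The non-const version is selected on a mutable store and permits in-place modification; the const version is selected through a const reference.

In this hypothetical sketch, `sketch::Value` stands in for TensorValueBase. Throwing on an unknown name (via `std::map::at`) is our assumption, not documented behavior.

```cpp
#include <cassert>
#include <map>
#include <string>

namespace sketch
{
struct Value
{
  double v; // hypothetical stand-in for the tensor held by TensorValueBase
};

class Store
{
public:
  Store() { _params.emplace("E", Value{1.0}); }

  // Writable access: callers may modify the stored value in place.
  Value & get_parameter(const std::string & name) { return _params.at(name); }
  // Read-only access: chosen when called on a const Store.
  const Value & get_parameter(const std::string & name) const { return _params.at(name); }

private:
  std::map<std::string, Value> _params;
};
} // namespace sketch
```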

◆ has_nl_param()

bool has_nl_param ( ) const
inline

Whether this parameter store has any nonlinear parameter.

◆ named_nonlinear_parameter_models()

std::map< std::string, Model * > named_nonlinear_parameter_models ( bool recursive = false) const
virtual

Get all nonlinear parameters' models.

Reimplemented in ComposedModel.

◆ named_nonlinear_parameters()

std::map< std::string, const VariableBase * > named_nonlinear_parameters ( bool recursive = false) const
virtual

Get all nonlinear parameters.

Reimplemented in ComposedModel.
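The page does not define what `recursive` does. A reasonable reading, which is an assumption here, is that a composed object such as ComposedModel also gathers the nonlinear parameters of its sub-objects when `recursive` is true.

The hypothetical sketch below shows that pattern with a virtual override. Strings stand in for the `const VariableBase *` values, and `sketch::Composed` and its `children` member are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

namespace sketch
{
class Store
{
public:
  virtual ~Store() = default;

  // Base behaviour: only this object's own nonlinear parameters.
  virtual std::map<std::string, std::string> named_nonlinear_parameters(bool recursive = false) const
  {
    (void)recursive;
    return _nl_params;
  }

  std::map<std::string, std::string> _nl_params; // param name -> variable name (stand-in)
};

// Hypothetical composed store: with recursive == true it also gathers its children's entries.
class Composed : public Store
{
public:
  std::map<std::string, std::string> named_nonlinear_parameters(bool recursive = false) const override
  {
    auto all = Store::named_nonlinear_parameters(recursive);
    if (recursive)
      for (const auto & child : children)
        for (const auto & [name, var] : child->named_nonlinear_parameters(true))
          all.emplace(name, var); // keeps the first entry on name clashes
    return all;
  }

  std::vector<std::shared_ptr<Store>> children;
};
} // namespace sketch
```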

◆ named_parameters() [1/2]

Storage< std::string, TensorValueBase > & named_parameters ( )

◆ named_parameters() [2/2]

const Storage< std::string, TensorValueBase > & named_parameters ( ) const
inline
Returns
the parameter storage

◆ nl_param()

const VariableBase * nl_param ( const std::string & name) const

Query the existence of a nonlinear parameter.

Returns
const VariableBase*: pointer to the VariableBase if the parameter with the given name is nonlinear; nullptr otherwise.
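The lookup contract of nl_param (a pointer if nonlinear, nullptr otherwise) is the standard "find or null" pattern over a map such as `_nl_params`. In this hypothetical sketch, has_nl_param is implemented as "the map is non-empty", which is our assumption, and `sketch::Variable` stands in for VariableBase.

```cpp
#include <cassert>
#include <map>
#include <string>

namespace sketch
{
struct Variable
{
  std::string name; // hypothetical stand-in for VariableBase
};

class Store
{
public:
  // Returns the variable view if `name` is a nonlinear parameter, nullptr otherwise.
  const Variable * nl_param(const std::string & name) const
  {
    auto it = _nl_params.find(name);
    return it == _nl_params.end() ? nullptr : it->second;
  }

  // True if at least one nonlinear parameter has been registered.
  bool has_nl_param() const { return !_nl_params.empty(); }

  std::map<std::string, const Variable *> _nl_params;
};
} // namespace sketch
```

Callers can then branch with `if (const auto * v = store.nl_param(name))` without a separate existence check.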

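◆ operator=() [1/2]

ParameterStore & operator= ( const ParameterStore & )
delete

◆ operator=() [2/2]

ParameterStore & operator= ( ParameterStore && )
delete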

◆ send_parameters_to()

void send_parameters_to ( const torch::TensorOptions & options)
protected, virtual

Send parameters to the given tensor options (for example, a different device or dtype).

Parameters
options: The target tensor options
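send_parameters_to iterates over all stored parameters and re-targets each one. With libtorch, the per-tensor step would typically be the `Tensor::to` overload that takes `TensorOptions`.

To stay self-contained, the hypothetical sketch below replaces torch types with the plain structs `sketch::Options` and `sketch::Param`. It only records the new options rather than converting any data.

```cpp
#include <cassert>
#include <map>
#include <string>

namespace sketch
{
// Hypothetical stand-ins for torch::TensorOptions and a stored tensor.
struct Options
{
  std::string device;
  std::string dtype;
};

struct Param
{
  double value;
  Options options;
};

class Store
{
public:
  virtual ~Store() = default;

  // Re-target every stored parameter to `options`; this mock only records the new options.
  virtual void send_parameters_to(const Options & options)
  {
    for (auto & entry : _params)
      entry.second.options = options;
  }

  std::map<std::string, Param> _params;
};
} // namespace sketch
```

Making the method virtual, as in the documented signature, lets derived classes extend it (for example, to also move sub-objects).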

◆ set_parameter()

void set_parameter ( const std::string & name,
const Tensor & value )

Set the value for a parameter.

◆ set_parameters()

void set_parameters ( const std::map< std::string, Tensor > & param_values)

Set values for parameters.
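set_parameters is the batch counterpart of set_parameter. A natural implementation applies set_parameter to each name/value pair of the map.

The hypothetical sketch below uses `double` in place of Tensor. Rejecting unknown names with `std::invalid_argument` is our assumption; the page does not document how unknown names are handled.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

namespace sketch
{
class Store
{
public:
  explicit Store(std::map<std::string, double> params) : _params(std::move(params)) {}

  // Overwrite the value of an existing parameter; unknown names are rejected (assumption).
  void set_parameter(const std::string & name, double value)
  {
    auto it = _params.find(name);
    if (it == _params.end())
      throw std::invalid_argument("unknown parameter: " + name);
    it->second = value;
  }

  // Batch version: apply set_parameter to every entry of the map.
  void set_parameters(const std::map<std::string, double> & values)
  {
    for (const auto & [name, value] : values)
      set_parameter(name, value);
  }

  double get(const std::string & name) const { return _params.at(name); }

private:
  std::map<std::string, double> _params;
};
} // namespace sketch
```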

Member Data Documentation

◆ _nl_param_models

std::map<std::string, Model *> _nl_param_models
protected

Map from nonlinear parameter names to models which evaluate them.

◆ _nl_params

std::map<std::string, const VariableBase *> _nl_params
protected

Map from nonlinear parameter names to their corresponding variable views.