NEML2 2.0.0
Interface for an object which can store parameters.
#include <ParameterStore.h>
Public Member Functions
ParameterStore (OptionSet options, NEML2Object *object)
ParameterStore (const ParameterStore &) = delete
ParameterStore (ParameterStore &&) = delete
ParameterStore & operator= (const ParameterStore &) = delete
ParameterStore & operator= (ParameterStore &&) = delete
virtual ~ParameterStore () = default
std::map< std::string, const VariableBase * > _nl_params (protected)
Map from nonlinear parameter names to their corresponding variable views.
std::map< std::string, Model * > _nl_param_models (protected)
Map from nonlinear parameter names to models which evaluate them.
const Storage< std::string, TensorValueBase > & named_parameters () const
Storage< std::string, TensorValueBase > & named_parameters ()
void set_parameter (const std::string &, const Tensor &)
Set the value for a parameter.
void set_parameters (const std::map< std::string, Tensor > &)
Set values for parameters.
TensorValueBase & get_parameter (const std::string &name)
Get a writable reference to a parameter.
const TensorValueBase & get_parameter (const std::string &name) const
Get a read-only reference to a parameter.
bool has_nl_param () const
Whether this parameter store has any nonlinear parameter.
const VariableBase * nl_param (const std::string &) const
Query the existence of a nonlinear parameter.
virtual std::map< std::string, const VariableBase * > named_nonlinear_parameters (bool recursive=false) const
Get all nonlinear parameters.
virtual std::map< std::string, Model * > named_nonlinear_parameter_models (bool recursive=false) const
Get the models of all nonlinear parameters.
virtual void send_parameters_to (const torch::TensorOptions &options)
Send parameters to the given tensor options.
template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const T &rawval)
Declare a parameter from a raw value.
template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const CrossRef< T > &crossref, bool allow_nonlinear) (protected)
Declare a parameter from a cross-referenced value.
template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const std::string &input_option_name, bool allow_nonlinear=false) (protected)
Declare a parameter from an input option.
Constructor & Destructor Documentation
ParameterStore (OptionSet options, NEML2Object * object)
ParameterStore (const ParameterStore &) = delete
ParameterStore (ParameterStore &&) = delete
virtual ~ParameterStore () = default
Member Function Documentation
template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string & name, const CrossRef< T > & crossref, bool allow_nonlinear) [protected]
Declare a parameter.
Similar to the overload that takes a raw value (rawval), but additionally handles the resolution of cross-referenced parameters. Two attempts are made in sequence: first, the method tries to resolve CrossRef<T> into T directly; if that fails, it tries to resolve CrossRef<T> into a nonlinear parameter, treating the raw string stored in the cross-ref as the name of the model that defines the nonlinear parameter.
Template Parameters
T | Parameter type. See Statically shaped tensors for supported types.
Parameters
name | Name of the model parameter.
crossref | The cross-referenced "string" that defines the value of the model parameter.
allow_nonlinear | Whether to allow coupling with a nonlinear parameter.
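The two-stage resolution described above can be illustrated with a small, self-contained C++ sketch. It is not NEML2's implementation: ResolvedParameter, parse_value and resolve_crossref are hypothetical names, a parameter value is a plain double instead of a tensor, a set of names stands in for the models known to the host, and gating the second attempt on allow_nonlinear is an assumption.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>

// Hypothetical, simplified stand-ins: a parameter value is a plain double and a
// cross-ref is the raw string from the input file. These are not NEML2 types.
struct ResolvedParameter
{
  bool nonlinear;    // true if another model defines this parameter
  double value;      // used when nonlinear == false
  std::string model; // used when nonlinear == true
};

// Attempt 1: interpret the whole raw string as a value.
std::optional<double> parse_value(const std::string & raw)
{
  try
  {
    std::size_t pos = 0;
    const double v = std::stod(raw, &pos);
    if (pos == raw.size()) // reject partial parses such as "3.5abc"
      return v;
  }
  catch (const std::logic_error &) // std::invalid_argument or std::out_of_range
  {
  }
  return std::nullopt;
}

ResolvedParameter resolve_crossref(const std::string & raw,
                                   const std::set<std::string> & model_names,
                                   bool allow_nonlinear)
{
  if (const auto v = parse_value(raw))
    return {false, *v, ""};
  // Attempt 2 (assumed to be gated by allow_nonlinear): treat the raw string
  // as the name of the model that defines a nonlinear parameter.
  if (allow_nonlinear && model_names.count(raw) > 0)
    return {true, 0.0, raw};
  throw std::runtime_error("cannot resolve cross-ref '" + raw + "'");
}
```

For example, with a hypothetical model named elasticity registered, the raw string "3.5" resolves to a plain value, while "elasticity" resolves to a nonlinear parameter only when allow_nonlinear is true; otherwise the sketch throws.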
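template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string & name, const std::string & input_option_name, bool allow_nonlinear = false) [protected]
Declare a parameter.
Similar to the other overloads, but this one handles the high-level logic of constructing a (possibly nonlinear) parameter directly from an input option.
Template Parameters
T | Parameter type. See Statically shaped tensors for supported types.
Parameters
name | Name of the model parameter.
input_option_name | Name of the input option that defines the value of the model parameter.
allow_nonlinear | Whether to allow coupling with a nonlinear parameter.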
template<typename T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string & name, const T & rawval)
Declare a parameter.
Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times in the host, so the same parameter may be declared multiple times. This is allowed, but only the first call to declare_parameter constructs the parameter value; subsequent calls only return a reference to the existing parameter.
Template Parameters
T | Parameter type. See Statically shaped tensors for supported types.
Parameters
name | Name of the model parameter.
rawval | Value of the model parameter.
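The declare-once behavior described above can be sketched with a std::map keyed by parameter name. This is a hypothetical illustration, not NEML2's storage: HostParameters is an invented class, and values are plain doubles rather than TensorValueBase objects.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical host-side parameter storage; values are plain doubles here,
// whereas NEML2 stores TensorValueBase objects.
class HostParameters
{
public:
  // The first declaration of `name` constructs the value from `rawval`.
  // Later declarations (e.g. from a sub-object reused in the same host) leave
  // the stored value untouched and return a reference to it; their `rawval`
  // is ignored.
  const double & declare_parameter(const std::string & name, double rawval)
  {
    // try_emplace inserts only if `name` is absent; references into a
    // std::map stay valid across later insertions.
    return _params.try_emplace(name, rawval).first->second;
  }

  std::size_t size() const { return _params.size(); }

private:
  std::map<std::string, double> _params;
};
```

Because std::map::try_emplace leaves an existing entry untouched, a second declaration with a different raw value returns a reference to the originally stored value.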
TensorValueBase & get_parameter (const std::string & name)
Get a writable reference to a parameter.
const TensorValueBase & get_parameter (const std::string & name) const
Get a read-only reference to a parameter.
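The two get_parameter overloads follow the usual C++ pattern of pairing a non-const accessor with a const one; the compiler picks the overload from the constness of the object it is called on. A minimal sketch, assuming a hypothetical Store class that holds doubles instead of TensorValueBase objects:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical store illustrating a writable/read-only overload pair.
// Values are plain doubles here; NEML2 returns TensorValueBase references.
class Store
{
public:
  explicit Store(std::map<std::string, double> params) : _params(std::move(params)) {}

  // Selected when called on a non-const Store: the caller may modify the value.
  double & get_parameter(const std::string & name) { return _params.at(name); }

  // Selected when called on a const Store: the value can only be read.
  const double & get_parameter(const std::string & name) const { return _params.at(name); }

private:
  // In this sketch, std::map::at throws std::out_of_range for an unknown name.
  std::map<std::string, double> _params;
};

// Compile-time check that a const Store only hands out read-only references.
static_assert(std::is_same_v<decltype(std::declval<const Store &>().get_parameter("")),
                             const double &>);
```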
bool has_nl_param () const [inline]
Whether this parameter store has any nonlinear parameter.
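virtual std::map< std::string, Model * > named_nonlinear_parameter_models (bool recursive = false) const
Get the models of all nonlinear parameters.
Reimplemented in ComposedModel.
virtual std::map< std::string, const VariableBase * > named_nonlinear_parameters (bool recursive = false) const
Get all nonlinear parameters.
Reimplemented in ComposedModel.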
Storage< std::string, TensorValueBase > & named_parameters ()
const Storage< std::string, TensorValueBase > & named_parameters () const [inline]
const VariableBase * nl_param (const std::string & name) const
Query the existence of a nonlinear parameter.
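One possible way to back has_nl_param and nl_param with the _nl_params map is sketched below. It is an assumption for illustration only: NonlinearParams and Variable are hypothetical stand-ins, and returning nullptr for a name that is not a nonlinear parameter is assumed rather than documented in this reference.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for NEML2's VariableBase.
struct Variable
{
  std::string name;
};

class NonlinearParams
{
public:
  // Register `param` as nonlinear, viewed through variable `var`.
  void add(const std::string & param, const Variable * var) { _nl_params[param] = var; }

  // True if at least one nonlinear parameter has been registered.
  bool has_nl_param() const { return !_nl_params.empty(); }

  // Pointer to the variable view if `param` is nonlinear, nullptr otherwise.
  const Variable * nl_param(const std::string & param) const
  {
    auto it = _nl_params.find(param);
    return it == _nl_params.end() ? nullptr : it->second;
  }

private:
  std::map<std::string, const Variable *> _nl_params;
};
```

Returning a pointer lets a single call both test for existence and give access to the variable view.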
ParameterStore & operator= (const ParameterStore &) = delete
ParameterStore & operator= (ParameterStore &&) = delete
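virtual void send_parameters_to (const torch::TensorOptions & options)
Send parameters to the given tensor options.
Parameters
options | The target tensor options.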
void set_parameter (const std::string &, const Tensor &)
Set the value for a parameter.
void set_parameters (const std::map< std::string, Tensor > &)
Set values for parameters.
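set_parameters can be thought of as repeated calls to set_parameter. The sketch below is hypothetical: ParamTable is an invented class, values are doubles rather than Tensors, and throwing std::out_of_range for an undeclared parameter is an assumption not stated in this reference.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical parameter table; values are doubles instead of Tensors.
class ParamTable
{
public:
  void declare(const std::string & name, double v) { _values.emplace(name, v); }

  // Set the value for a single, previously declared parameter.
  void set_parameter(const std::string & name, double v)
  {
    auto it = _values.find(name);
    if (it == _values.end())
      throw std::out_of_range("no parameter named '" + name + "'");
    it->second = v;
  }

  // Set values for several parameters by delegating to set_parameter.
  void set_parameters(const std::map<std::string, double> & values)
  {
    for (const auto & [name, v] : values)
      set_parameter(name, v);
  }

  double get(const std::string & name) const { return _values.at(name); }

private:
  std::map<std::string, double> _values;
};
```

In this sketch the updates are not transactional: if one name is unknown, parameters processed before it (in the map's key order) have already been updated.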
Member Data Documentation
std::map< std::string, Model * > _nl_param_models [protected]
Map from nonlinear parameter names to models which evaluate them.
std::map< std::string, const VariableBase * > _nl_params [protected]
Map from nonlinear parameter names to their corresponding variable views.