NEML2 2.0.0
Wraps a NEML2 model into a `nonlinear.NonlinearRecursiveFunction`.

Args:
    model (NEML2 model): the model to wrap

Keyword Args:
    exclude_parameters (list of str): exclude these parameters from being wrapped as a pytorch parameter

Additional args and kwargs are forwarded to NonlinearRecursiveFunction (and hence torch.nn.Module) verbatim.
Public Member Functions | |
__init__ (self, model, *args, exclude_parameters=list(), **kwargs) | |
forward (self, state, forces) | |
nforce (self) | |
nstate (self) | |
Static Public Attributes | |
lookback = ... | |
str | FORCES = 'forces' |
str | OLD_FORCES = 'old_forces' |
str | OLD_STATE = 'old_state' |
str | PARAMETERS = 'parameters' |
str | RESIDUAL = 'residual' |
str | STATE = 'state' |
list | subaxis_names = ['state', 'old_state', 'forces', 'old_forces', 'residual', 'parameters'] |
Protected Member Functions | |
_adapt_for_pyzag (self, r, J, J_old) | |
_check_model (self) | |
_disassemble_input (self, state, forces) | |
_setup_assemblers (self) | |
_setup_parameters (self, exclude_parameters) | |
_update_parameter_values (self) | |
__init__ (self, model, *args, exclude_parameters=list(), **kwargs)
_adapt_for_pyzag (self, r, J, J_old)
protected
Adapt the residual and Jacobians for pyzag.

pyzag has additional requirements on the residual and Jacobians:
1. The residual and Jacobians should have the same batch shape
2. The Jacobians should be square

Args:
    r (neml2.Tensor): residual
    J (neml2.Tensor): Jacobian
    J_old (neml2.Tensor): Jacobian for the old state

Returns:
    tuple of torch.Tensor: residual, Jacobian
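The two requirements listed for `_adapt_for_pyzag` can be illustrated on shapes alone. The sketch below is not NEML2 or pyzag code: `common_batch_shape` and `check_pyzag_shapes` are hypothetical helpers. It assumes each shape consists of batch dimensions followed by base dimensions (one for the residual, two for each Jacobian), that batch shapes combine by the usual NumPy/PyTorch broadcasting rule, and that "square" means the Jacobian base shape matches the residual size.

```python
from itertools import zip_longest


def common_batch_shape(*shapes):
    """Broadcast batch shapes right-aligned, as NumPy/PyTorch do."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible batch shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def check_pyzag_shapes(r_shape, J_shape, J_old_shape):
    """Hypothetical check: return the (residual, Jacobian) shapes after adaptation."""
    r_batch, nstate = r_shape[:-1], r_shape[-1]
    J_batch, J_base = J_shape[:-2], J_shape[-2:]
    Jo_batch, Jo_base = J_old_shape[:-2], J_old_shape[-2:]

    # Requirement 2: the Jacobians should be square (assumed: nstate x nstate)
    for name, base in (("J", J_base), ("J_old", Jo_base)):
        if base != (nstate, nstate):
            raise ValueError(f"{name} has base shape {base}, expected {(nstate, nstate)}")

    # Requirement 1: residual and Jacobians share one batch shape
    batch = common_batch_shape(r_batch, J_batch, Jo_batch)
    return batch + (nstate,), batch + (nstate, nstate)


# A residual batched over (10, 5), a Jacobian batched over (5,), an unbatched J_old
r_out, J_out = check_pyzag_shapes((10, 5, 6), (5, 6, 6), (6, 6))
```

With real tensors, the same result would typically be obtained by expanding each tensor to the common batch shape before handing it to the solver.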
_check_model (self)
protected
Simple consistency checks. These could be restricted to debug mode, but they run only once, so they are always performed.
_disassemble_input (self, state, forces)
protected
Assemble the model input from the flat tensors.

Args:
    state (torch.Tensor): tensor containing the model state
    forces (torch.Tensor): tensor containing the model forces
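To illustrate what working from a flat tensor involves, the sketch below splits a flat list of values into named variables given a layout of names and sizes. It is only an analogy: `split_flat`, the layout, and the variable names are hypothetical, and NEML2 uses its own assemblers on batched tensors, which are not reproduced here.

```python
def split_flat(flat, layout):
    """Split a flat sequence into named blocks.

    layout: list of (name, size) pairs in storage order (hypothetical).
    """
    expected = sum(size for _, size in layout)
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    blocks, start = {}, 0
    for name, size in layout:
        blocks[name] = flat[start:start + size]
        start += size
    return blocks


# Hypothetical state layout: a symmetric tensor (6 components) and a scalar
state_layout = [("stress", 6), ("equivalent_plastic_strain", 1)]
state = [float(i) for i in range(7)]
variables = split_flat(state, state_layout)
```

The size check up front mirrors the kind of consistency check that catches a state vector whose length does not match the model's layout.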
_setup_assemblers (self)
protected
Set up the assemblers for the state and forces.
_setup_parameters (self, exclude_parameters)
protected
Mirror parameters of the NEML2 model with torch.nn.Parameter.

Args:
    exclude_parameters (list of str): NEML2 parameters to exclude
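The effect of `exclude_parameters` can be sketched with plain dictionaries. This is not the NEML2 implementation: `select_parameters` and the parameter names are hypothetical, and rejecting unknown names is a design choice of the sketch rather than documented behavior.

```python
def select_parameters(named_parameters, exclude_parameters=()):
    """Return the parameters to mirror, skipping the excluded names."""
    unknown = set(exclude_parameters) - set(named_parameters)
    if unknown:
        raise KeyError(f"cannot exclude unknown parameters: {sorted(unknown)}")
    return {
        name: value
        for name, value in named_parameters.items()
        if name not in exclude_parameters
    }


# Hypothetical parameter names and values
model_parameters = {"elasticity_E": 1.0e5, "elasticity_nu": 0.3, "yield_sy": 50.0}
mirrored = select_parameters(model_parameters, exclude_parameters=["elasticity_nu"])
```

In the wrapped model, only the selected entries would become trainable `torch.nn.Parameter` objects; excluded parameters keep the values stored in the NEML2 model.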
_update_parameter_values (self)
protected
Copy over new parameter values.
forward (self, state, forces)
Call the NEML2 model and return the residual and Jacobian.

Args:
    state (torch.Tensor): tensor with the flattened state
    forces (torch.Tensor): tensor with the flattened forces
nforce (self)
nstate (self)