NEML2 2.0.0

Wraps a NEML2 model into a `nonlinear.NonlinearRecursiveFunction`.

Args:
    model (NEML2 model): the model to wrap

Keyword Args:
    exclude_parameters (list of str): parameters to exclude from being wrapped as PyTorch parameters

Additional args and kwargs are forwarded verbatim to NonlinearRecursiveFunction (and hence to torch.nn.Module).

Public Member Functions
    __init__(self, neml2.core.Model model, *args, list[str] exclude_parameters=list(), **kwargs)
    forward(self, state, forces)
    int nforce(self)
    int nstate(self)

Static Public Attributes
    str FORCES = 'forces'
    str OLD_FORCES = 'old_forces'
    str OLD_STATE = 'old_state'
    str PARAMETERS = 'parameters'
    str RESIDUAL = 'residual'
    str STATE = 'state'
    list subaxis_names = ['state', 'old_state', 'forces', 'old_forces', 'residual', 'parameters']

Protected Member Functions
    _adapt_for_pyzag(self, r, J, J_old)
    _check_model(self)
    _disassemble_input(self, state, forces)
    _setup_assemblers(self)
    _setup_parameters(self, exclude_parameters)
    _update_parameter_values(self)

__init__(self, neml2.core.Model model, *args, list[str] exclude_parameters = list(), **kwargs)

_adapt_for_pyzag(self, r, J, J_old) [protected]

Adapt the residual and Jacobians for pyzag.

pyzag has additional requirements on the residual and Jacobians:
1. The residual and Jacobians should have the same batch shape.
2. The Jacobians should be square.

Args:
    r (neml2.Tensor): residual
    J (neml2.Tensor): Jacobian
    J_old (neml2.Tensor): Jacobian with respect to the old state

Returns:
    tuple of torch.Tensor: residual, Jacobian
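The two pyzag requirements above can be illustrated with plain shape arithmetic. This is a minimal sketch, not NEML2's implementation: the helpers `broadcast_batch_shape` and `adapt_shapes` are hypothetical, and the sketch assumes NumPy/PyTorch-style broadcasting, a residual of shape `(*batch, n)`, and Jacobians of shape `(*batch, n, n)`. It computes the common batch shape that the residual and both Jacobians would be expanded to, and rejects non-square Jacobians.

```python
def broadcast_batch_shape(*shapes):
    """Broadcast batch shapes with NumPy/PyTorch rules (dimensions aligned from the right)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(1, ndim + 1):
        # Sizes of dimension -i in each shape; a missing leading dimension acts as size 1
        sizes = {s[-i] for s in shapes if len(s) >= i}
        sizes.discard(1)
        if len(sizes) > 1:
            raise ValueError(f"incompatible batch shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def adapt_shapes(r_shape, J_shape, J_old_shape):
    """Return target shapes for the residual and the two Jacobians.

    Hypothetical helper: assumes r has shape (*batch, n) and each
    Jacobian has shape (*batch, n, n).
    """
    for name, shape in (("J", J_shape), ("J_old", J_old_shape)):
        if shape[-1] != shape[-2]:
            raise ValueError(f"{name} is not square: {shape}")
    n = r_shape[-1]
    if J_shape[-1] != n or J_old_shape[-1] != n:
        raise ValueError("Jacobian size does not match residual size")
    # All three tensors share one broadcast batch shape
    batch = broadcast_batch_shape(r_shape[:-1], J_shape[:-2], J_old_shape[:-2])
    return batch + (n,), batch + (n, n), batch + (n, n)
```

For example, `adapt_shapes((10, 6), (1, 6, 6), (10, 6, 6))` yields `((10, 6), (10, 6, 6), (10, 6, 6))`: the Jacobian's singleton batch dimension is broadcast to match the residual.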

_check_model(self) [protected]

Simple consistency checks. These could be debug-only checks, but this method is only called once.

_disassemble_input(self, state, forces) [protected]

Disassemble the model input forces, old forces, and old state from the flat tensors.

Args:
    state (torch.Tensor): tensor containing the model state
    forces (torch.Tensor): tensor containing the model forces
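The sketch below illustrates splitting flat inputs into named variables, written with plain Python lists so it runs without NEML2 or PyTorch. It rests on hypothetical assumptions: the flat state and forces are time series laid out as `[step][component]`, the "old" values are simply those of the previous time step, and each variable occupies a contiguous slice described by an ordered list of `(name, size)` pairs. The helpers `split_flat` and `disassemble` and the variable names used below are illustrative only; NEML2's actual axis layout and assemblers may differ. The dictionary keys reuse the `subaxis_names` entries listed in the class attributes.

```python
def split_flat(vector, layout):
    """Split one flat vector into named chunks following an ordered (name, size) layout."""
    out, start = {}, 0
    for name, size in layout:
        out[name] = vector[start:start + size]
        start += size
    if start != len(vector):
        raise ValueError(f"layout covers {start} entries, vector has {len(vector)}")
    return out


def disassemble(state, forces, state_layout, force_layout):
    """Pair each time step with the previous one and split both into variables.

    Hypothetical layout: state and forces are lists of flat vectors, one per step.
    Returns one dict per step n >= 1 keyed by 'state', 'old_state', 'forces', 'old_forces'.
    """
    if len(state) != len(forces):
        raise ValueError("state and forces must have the same number of steps")
    steps = []
    for n in range(1, len(state)):
        steps.append({
            "state": split_flat(state[n], state_layout),
            "old_state": split_flat(state[n - 1], state_layout),
            "forces": split_flat(forces[n], force_layout),
            "old_forces": split_flat(forces[n - 1], force_layout),
        })
    return steps
```

With two steps of a three-component state laid out as `[("stress", 2), ("ep", 1)]` and a scalar force `[("t", 1)]`, `disassemble` returns a single entry whose `old_state` comes from step 0 and whose `state` comes from step 1.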

_setup_assemblers(self) [protected]

Setup the assemblers for the state and forces.

_setup_parameters(self, exclude_parameters) [protected]

Mirror the parameters of the NEML2 model as torch.nn.Parameter objects.

Args:
    exclude_parameters (list of str): NEML2 parameters to exclude
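One detail matters when mirroring parameters: `torch.nn.Module.register_parameter` rejects names that contain a `.`, so NEML2 parameter names with dots (for example the hypothetical `elasticity.E`) cannot be registered verbatim. The sketch below shows one possible way to build the name mapping while honoring `exclude_parameters`. The helper `mirror_parameter_names`, the dotted names, and the underscore replacement are all assumptions that may not match what NEML2 actually does; the sketch uses only the standard library, so it stops short of creating the `torch.nn.Parameter` objects.

```python
def mirror_parameter_names(neml2_names, exclude_parameters=()):
    """Map torch-safe parameter names to NEML2 parameter names.

    Hypothetical helper: dots are replaced with underscores because
    torch.nn.Module.register_parameter refuses names containing '.'.
    Excluded NEML2 parameters are skipped, so they are not exposed to torch.
    """
    excluded = set(exclude_parameters)
    mapping = {}
    for name in neml2_names:
        if name in excluded:
            continue
        torch_name = name.replace(".", "_")
        # Two different NEML2 names could collapse to the same torch name
        if torch_name in mapping:
            raise ValueError(f"name clash after sanitizing: {name!r}")
        mapping[torch_name] = name
    return mapping
```

Keeping the reverse mapping (torch name to NEML2 name) is a design choice that makes it straightforward to copy updated values back into the model later.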

_update_parameter_values(self) [protected]

Copy over the new parameter values.

forward(self, state, forces)

Call the NEML2 model and return the residual and Jacobian.

Args:
    state (torch.Tensor): tensor with the flattened state
    forces (torch.Tensor): tensor with the flattened forces
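To make the residual/Jacobian pair concrete, here is a toy recursive function unrelated to any NEML2 model: one backward Euler step for the scalar ODE dx/dt = -k x, with time as the force. For step n the residual is r = x_n - x_{n-1} + (t_n - t_{n-1}) k x_n, so dr/dx_n = 1 + (t_n - t_{n-1}) k and dr/dx_{n-1} = -1; the latter plays the role of `J_old`. The function names, the default `k`, and the return convention are assumptions for illustration, not the interface of `forward`. A Newton iteration is included as one typical consumer of the residual and Jacobian.

```python
def backward_euler_residual(x, x_old, t, t_old, k=2.0):
    """Residual and Jacobians of one backward Euler step for dx/dt = -k x.

    Returns (r, J, J_old) with J = dr/dx and J_old = dr/dx_old.
    """
    dt = t - t_old
    r = x - x_old + dt * k * x
    J = 1.0 + dt * k
    J_old = -1.0
    return r, J, J_old


def solve_step(x_old, t, t_old, k=2.0, tol=1e-12, max_iter=20):
    """Drive the residual to zero with Newton's method, starting from x_old."""
    x = x_old
    for _ in range(max_iter):
        r, J, _ = backward_euler_residual(x, x_old, t, t_old, k)
        if abs(r) < tol:
            break
        x -= r / J  # Newton update uses the Jacobian with respect to the current state
    return x
```

Because this toy residual is linear in x, `solve_step(1.0, 0.5, 0.0)` converges in one Newton update to x = 1 / (1 + 0.5 * 2.0) = 0.5.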

int nforce(self)

int nstate(self)