NEML2 2.0.0
NEML2PyzagModel Class Reference

Detailed Description

Wraps a NEML2 model into a `nonlinear.NonlinearRecursiveFunction`

    Args:
        model (NEML2 model): the model to wrap

    Keyword Args:
        exclude_parameters (list of str): names of NEML2 parameters that should not be wrapped as PyTorch parameters

    Additional positional and keyword arguments are forwarded verbatim to NonlinearRecursiveFunction (and hence to torch.nn.Module).

Public Member Functions

 __init__ (self, model, *args, exclude_parameters=list(), **kwargs)
 
 forward (self, state, forces)
 
 nforce (self)
 
 nstate (self)
 

Static Public Attributes

 lookback = ...
 
str FORCES = 'forces'
 
str OLD_FORCES = 'old_forces'
 
str OLD_STATE = 'old_state'
 
str PARAMETERS = 'parameters'
 
str RESIDUAL = 'residual'
 
str STATE = 'state'
 
list subaxis_names = ['state', 'old_state', 'forces', 'old_forces', 'residual', 'parameters']
 

Protected Member Functions

 _adapt_for_pyzag (self, r, J, J_old)
 
 _check_model (self)
 
 _disassemble_input (self, state, forces)
 
 _setup_assemblers (self)
 
 _setup_parameters (self, exclude_parameters)
 
 _update_parameter_values (self)
 

Constructor & Destructor Documentation

◆ __init__()

__init__ (self, model, *args, exclude_parameters=list(), **kwargs)

Member Function Documentation

◆ _adapt_for_pyzag()

_adapt_for_pyzag (self, r, J, J_old)
protected
Adapt the residual and Jacobians for pyzag

        pyzag imposes two additional requirements on the residual and Jacobians:
          1. The residual and Jacobians must have the same batch shape.
          2. The Jacobians must be square.

        Args:
            r (neml2.Tensor): residual
            J (neml2.Tensor): Jacobian
            J_old (neml2.Tensor): Jacobian for the old state

        Returns:
            tuple of torch.Tensor: residual, Jacobian
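
Requirement 1 is ordinary broadcasting applied to the batch dimensions only (NEML2 tensors keep batch dimensions separate from base dimensions), and requirement 2 is a check on the base dimensions. The following pure-Python sketch works on shape tuples rather than on neml2.Tensor objects; the helper names broadcast_batch_shapes and is_square_jacobian are hypothetical and are not part of NEML2 or pyzag.

```python
from itertools import zip_longest


def broadcast_batch_shapes(*shapes):
    """Return the common batch shape under NumPy/PyTorch broadcasting rules.

    Shapes are aligned from the right; a missing leading dimension or a
    dimension of size 1 is stretched to match the other operands.
    """
    result = []
    # Iterate over dimensions from the innermost (rightmost) outward.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible batch shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def is_square_jacobian(shape, batch_ndim):
    """Check that the base (non-batch) part of a Jacobian shape is square."""
    base = shape[batch_ndim:]
    return len(base) == 2 and base[0] == base[1]


# Hypothetical shapes: residual batch (5, 1), Jacobian batch (3,).
common = broadcast_batch_shapes((5, 1), (3,))
print(common)  # (5, 3)
print(is_square_jacobian((*common, 4, 4), len(common)))  # True
```

In the real method the broadcast would be applied to the tensors themselves (for example by expanding their batch dimensions), not just to their shapes.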

◆ _check_model()

_check_model (self)
protected
Run simple consistency checks on the wrapped model. These could be debug-only checks, but the method is called only once.

◆ _disassemble_input()

_disassemble_input (self, state, forces)
protected
Build the model input by disassembling the flat state and forces tensors

        Args:
            state (torch.Tensor): tensor containing the model state
            forces (torch.Tensor): tensor containing the model forces
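
Disassembling a flat tensor amounts to slicing its trailing dimension at offsets determined by the sizes of the model's variables, taken in a fixed order (the role the assemblers created by _setup_assemblers appear to play). The sketch below does this on plain Python lists; the function name disassemble, the layout format, and the variable names are hypothetical illustrations, not the NEML2 assembler API.

```python
def disassemble(flat, layout):
    """Split a flat sequence into named chunks.

    layout is an ordered list of (variable_name, size) pairs whose sizes
    must add up to len(flat).
    """
    expected = sum(size for _, size in layout)
    if expected != len(flat):
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    chunks, offset = {}, 0
    for name, size in layout:
        # Each variable owns a contiguous slice starting at the running offset.
        chunks[name] = flat[offset:offset + size]
        offset += size
    return chunks


# Hypothetical layout: a 6-component stress-like variable and a scalar.
layout = [("state/S", 6), ("state/ep", 1)]
print(disassemble([0.0] * 6 + [0.01], layout))
```

Keeping the layout fixed is what makes the inverse operation (assembling named variables back into one flat tensor) a simple concatenation in the same order.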

◆ _setup_assemblers()

_setup_assemblers (self)
protected
Set up the assemblers for the state and forces

◆ _setup_parameters()

_setup_parameters (self, exclude_parameters)
protected
Mirror parameters of the NEML2 model with torch.nn.Parameter

        Args:
            exclude_parameters (list of str): NEML2 parameters to exclude
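
The exclusion logic can be sketched as a filter over the model's named parameters. This is a minimal pure-Python illustration assuming the parameters are available as a name-to-value mapping; select_parameters and the parameter names are hypothetical, and rejecting unknown names is a design choice of this sketch rather than documented NEML2 behavior. In the real method, each retained value would then be mirrored with a torch.nn.Parameter.

```python
def select_parameters(model_params, exclude_parameters=()):
    """Return the model parameters that should be mirrored as torch parameters.

    model_params maps parameter names to values; names listed in
    exclude_parameters are skipped. Unknown exclusions raise KeyError so
    that a misspelled name does not silently leave a parameter trainable.
    """
    excluded = set(exclude_parameters)
    unknown = excluded - set(model_params)
    if unknown:
        raise KeyError(f"unknown parameters: {sorted(unknown)}")
    return {
        name: value for name, value in model_params.items() if name not in excluded
    }


# Hypothetical parameter names and values.
params = {"E": 200e3, "nu": 0.3, "sy": 100.0}
print(select_parameters(params, exclude_parameters=["nu"]))
```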

◆ _update_parameter_values()

_update_parameter_values (self)
protected
Copy over new parameter values

◆ forward()

forward (self, state, forces)
Call the NEML2 model and return the residual and Jacobian

        Args:
            state (torch.Tensor): tensor with the flattened state
            forces (torch.Tensor): tensor with the flattened forces
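
To make the residual-and-Jacobian structure concrete, here is a hypothetical scalar model (backward Euler for dx/dt = -k x) in plain Python. The residual r(x_n, x_{n-1}; t_n, t_{n-1}) vanishes at the solution; J is its derivative with respect to the current state (cf. J in _adapt_for_pyzag) and J_old its derivative with respect to the previous state (cf. J_old). The sketch assumes a lookback of one step, and the function names and signatures are illustrative only: the real forward takes flattened torch tensors and evaluates the NEML2 model, and the sequential Newton loop merely stands in for the solve that pyzag performs.

```python
def toy_forward(x_new, x_old, t_new, t_old, k=2.0):
    """Residual and Jacobians of one backward-Euler step of dx/dt = -k * x."""
    dt = t_new - t_old
    r = x_new - x_old + dt * k * x_new
    J = 1.0 + dt * k  # dr/dx_new: Jacobian with respect to the current state
    J_old = -1.0      # dr/dx_old: Jacobian with respect to the previous state
    return r, J, J_old


def solve_step(x_old, t_new, t_old, k=2.0, tol=1e-12, max_iter=20):
    """Newton iteration on the residual, starting from the previous state."""
    x = x_old
    for _ in range(max_iter):
        r, J, _ = toy_forward(x, x_old, t_new, t_old, k)
        if abs(r) < tol:
            break
        x -= r / J
    return x


def integrate(x0, times, k=2.0):
    """March through the time history one step at a time (lookback of one)."""
    history = [x0]
    for t_old, t_new in zip(times[:-1], times[1:]):
        history.append(solve_step(history[-1], t_new, t_old, k))
    return history


print(integrate(1.0, [0.0, 0.5, 1.0]))  # [1.0, 0.5, 0.25]
```

Because the toy residual is linear in x_n, Newton converges after a single update; a nonlinear material model would generally need several iterations per step.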

◆ nforce()

nforce (self)

◆ nstate()

nstate (self)

Member Data Documentation

◆ FORCES

str FORCES = 'forces'
static

◆ lookback

lookback = ...
static

◆ OLD_FORCES

str OLD_FORCES = 'old_forces'
static

◆ OLD_STATE

str OLD_STATE = 'old_state'
static

◆ PARAMETERS

str PARAMETERS = 'parameters'
static

◆ RESIDUAL

str RESIDUAL = 'residual'
static

◆ STATE

str STATE = 'state'
static

◆ subaxis_names

list subaxis_names = ['state', 'old_state', 'forces', 'old_forces', 'residual', 'parameters']
static