27#include "neml2/drivers/Driver.h"
28#include "neml2/tensors/tensors.h"
29#include "neml2/models/map_types.h"
32#include <torch/nn/modules/container/modulelist.h>
33#include <torch/nn/modules/container/moduledict.h>
34#include <torch/serialize.h>
54 void diagnose(std::vector<Diagnosis> &)
const override;
72 virtual torch::nn::ModuleDict
result()
const;
94 virtual void output()
const;
130 void output_pt(
const std::filesystem::path &
out)
const;
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:54
The Driver drives the execution of a NEML2 Model.
Definition Driver.h:46
Selection of an enum value from a list of candidates.
Definition EnumSelection.h:41
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:58
The base class for all constitutive models.
Definition Model.h:51
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:85
Scalar.
Definition Scalar.h:38
The driver for a transient initial-value problem.
Definition TransientDriver.h:43
const VariableName _time_name
VariableName for the time.
Definition TransientDriver.h:102
std::vector< ValueMap > _result_in
Inputs from all time steps.
Definition TransientDriver.h:125
virtual bool solve()
Solve the initial value problem.
Definition TransientDriver.cxx:177
const torch::Device _device
The device on which to evaluate the model.
Definition TransientDriver.h:99
virtual void advance_step()
Advance in time: the state becomes old state, and forces become old forces.
Definition TransientDriver.cxx:211
const Model & model() const
Definition TransientDriver.h:58
const bool _show_input
Set to true to show model's input axis at the beginning.
Definition TransientDriver.h:120
Model & _model
The model which the driver uses to perform constitutive updates.
Definition TransientDriver.h:97
const bool _show_output
Set to true to show model's output axis at the beginning.
Definition TransientDriver.h:122
void diagnose(std::vector< Diagnosis > &) const override
Check for common problems.
Definition TransientDriver.cxx:130
TransientDriver(const OptionSet &options)
Construct a new TransientDriver object.
Definition TransientDriver.cxx:110
virtual void update_forces()
Update the driving forces for the current time step.
Definition TransientDriver.cxx:231
bool run() override
Let the driver run, return true upon successful completion, and return false otherwise.
Definition TransientDriver.cxx:154
const bool _show_params
Set to true to list all the model parameters at the beginning.
Definition TransientDriver.h:118
Scalar _time
The current time.
Definition TransientDriver.h:104
virtual std::string save_as_path() const
The destination file/path to save the results.
Definition TransientDriver.cxx:300
virtual void apply_predictor()
Apply the predictor to calculate the initial guess for the current time step.
Definition TransientDriver.cxx:252
virtual void store_input()
Save the input of the current time step.
Definition TransientDriver.cxx:294
const EnumSelection _predictor
The predictor used to set the initial guess.
Definition TransientDriver.h:113
const Size _nsteps
Total number of steps.
Definition TransientDriver.h:108
std::vector< ValueMap > _result_out
Outputs from all time steps.
Definition TransientDriver.h:127
virtual torch::nn::ModuleDict result() const
The results (input and output) from all time steps.
Definition TransientDriver.cxx:306
ValueMap _in
The input to the constitutive model.
Definition TransientDriver.h:110
Size _step_count
The current step count.
Definition TransientDriver.h:106
std::string _save_as
The destination file name or file path.
Definition TransientDriver.h:116
virtual void output() const
Save the results into the destination file/path.
Definition TransientDriver.cxx:337
static OptionSet expected_options()
Definition TransientDriver.cxx:61
virtual void apply_ic()
Apply the initial conditions.
Definition TransientDriver.cxx:238
virtual void solve_step()
Perform the constitutive update for the current time step.
Definition TransientDriver.cxx:288
Definition CrossRef.cxx:31
std::map< LabeledAxisAccessor, Tensor > ValueMap
Definition map_types.h:34
int64_t Size
Definition types.h:33