LCOV - code coverage report
Current view: top level - drivers - TransientDriver.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 91.5 % 199 182
Test Date: 2025-06-29 01:25:44 Functions: 100.0 % 62 62

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include <torch/nn/modules/container/moduledict.h>
      26              : #include <torch/nn/modules/container/modulelist.h>
      27              : #include <torch/serialize.h>
      28              : 
      29              : #include "neml2/drivers/TransientDriver.h"
      30              : #include "neml2/misc/assertions.h"
      31              : #include "neml2/models/Model.h"
      32              : 
      33              : #ifdef NEML2_HAS_DISPATCHER
      34              : #include "neml2/dispatchers/ValueMapLoader.h"
      35              : #endif
      36              : 
      37              : namespace fs = std::filesystem;
      38              : 
      39              : namespace neml2
      40              : {
// Register TransientDriver through NEML2's object-registration macro
// (presumably so the Factory can construct it by type name from input files).
register_NEML2_object(TransientDriver);
      42              : 
      43              : template <typename T>
      44              : void
      45          161 : set_ic(ValueMap & storage,
      46              :        const OptionSet & options,
      47              :        const std::string & name_opt,
      48              :        const std::string & value_opt,
      49              :        const Device & device)
      50              : {
      51          161 :   const auto & names = options.get<std::vector<VariableName>>(name_opt);
      52          161 :   const auto & vals = options.get<std::vector<TensorName<T>>>(value_opt);
      53          161 :   neml_assert(names.size() == vals.size(),
      54              :               "Number of initial condition names ",
      55              :               name_opt,
      56              :               " and number of initial condition values ",
      57              :               value_opt,
      58              :               " should be the same but instead have ",
      59          161 :               names.size(),
      60              :               " and ",
      61          322 :               vals.size(),
      62              :               " respectively.");
      63          161 :   auto * factory = options.get<Factory *>("_factory");
      64          161 :   neml_assert(factory, "Internal error: factory != nullptr");
      65          162 :   for (std::size_t i = 0; i < names.size(); i++)
      66              :   {
      67            1 :     neml_assert(names[i].is_state(),
      68              :                 "Initial condition names should start with 'state' but instead got ",
      69              :                 names[i]);
      70            1 :     storage[names[i]] = vals[i].resolve(factory).to(device);
      71              :   }
      72          161 : }
      73              : 
      74              : template <typename T>
      75              : void
      76          161 : get_force(std::vector<VariableName> & names,
      77              :           std::vector<Tensor> & values,
      78              :           const OptionSet & options,
      79              :           const std::string & name_opt,
      80              :           const std::string & value_opt,
      81              :           const Device & device)
      82              : {
      83          161 :   const auto & force_names = options.get<std::vector<VariableName>>(name_opt);
      84          161 :   const auto & vals = options.get<std::vector<TensorName<T>>>(value_opt);
      85          161 :   neml_assert(force_names.size() == vals.size(),
      86              :               "Number of driving force names ",
      87              :               name_opt,
      88              :               " and number of driving force values ",
      89              :               value_opt,
      90              :               " should be the same but instead have ",
      91          161 :               force_names.size(),
      92              :               " and ",
      93          322 :               vals.size(),
      94              :               " respectively.");
      95          161 :   auto * factory = options.get<Factory *>("_factory");
      96          161 :   neml_assert(factory, "Internal error: factory != nullptr");
      97          161 :   for (std::size_t i = 0; i < force_names.size(); i++)
      98              :   {
      99            0 :     neml_assert(force_names[i].is_force(),
     100              :                 "Driving force names should start with 'forces' but instead got ",
     101              :                 force_names[i]);
     102            0 :     names.push_back(force_names[i]);
     103            0 :     values.push_back(vals[i].resolve(factory).to(device));
     104              :   }
     105          161 : }
     106              : 
     107              : OptionSet
     108            8 : TransientDriver::expected_options()
     109              : {
     110            8 :   OptionSet options = ModelDriver::expected_options();
     111            8 :   options.doc() = "Driver for simulating the transient response of an autonomous system.";
     112              : 
     113           24 :   options.set<VariableName>("time") = VariableName(FORCES, "t");
     114           16 :   options.set("time").doc() = "Time";
     115           16 :   options.set<TensorName<Scalar>>("prescribed_time");
     116           24 :   options.set("prescribed_time").doc() =
     117              :       "Time steps to perform the material update. The times tensor must "
     118            8 :       "have at least one batch dimension representing time steps";
     119              : 
     120           32 :   EnumSelection predictor_selection({"PREVIOUS_STATE", "LINEAR_EXTRAPOLATION"}, "PREVIOUS_STATE");
     121            8 :   options.set<EnumSelection>("predictor") = predictor_selection;
     122           16 :   options.set("predictor").doc() =
     123           16 :       "Predictor used to set the initial guess for each time step. Options are " +
     124           24 :       predictor_selection.candidates_str();
     125              : 
     126              : #define OPTION_IC_(T)                                                                              \
     127              :   options.set<std::vector<VariableName>>("ic_" #T "_names");                                       \
     128              :   options.set("ic_" #T "_names").doc() = "Apply initial conditions to these " #T " variables";     \
     129              :   options.set<std::vector<TensorName<T>>>("ic_" #T "_values");                                     \
     130              :   options.set("ic_" #T "_values").doc() = "Initial condition values for the " #T " variables"
     131         1472 :   FOR_ALL_TENSORBASE(OPTION_IC_);
     132              : 
     133              : #define OPTION_FORCE_(T)                                                                           \
     134              :   options.set<std::vector<VariableName>>("force_" #T "_names");                                    \
     135              :   options.set("force_" #T "_names").doc() = "Prescribed driving force of tensor type " #T;         \
     136              :   options.set<std::vector<TensorName<T>>>("force_" #T "_values");                                  \
     137              :   options.set("force_" #T "_values").doc() = "Prescribed driving force values of tensor type " #T
     138         1472 :   FOR_ALL_TENSORBASE(OPTION_FORCE_);
     139              : 
     140           16 :   options.set<std::string>("save_as");
     141           16 :   options.set("save_as").doc() =
     142            8 :       "File path (absolute or relative to the working directory) to store the results";
     143              : 
     144           16 :   return options;
     145            8 : }
     146              : 
     147            7 : TransientDriver::TransientDriver(const OptionSet & options)
     148              :   : ModelDriver(options),
     149            7 :     _time_name(options.get<VariableName>("time")),
     150           14 :     _time(resolve_tensor<Scalar>("prescribed_time")),
     151            7 :     _nsteps(_time.batch_size(0).concrete()),
     152            7 :     _predictor(options.get<EnumSelection>("predictor")),
     153           14 :     _result_in(_nsteps),
     154           14 :     _result_out(_nsteps),
     155           35 :     _save_as(options.get<std::string>("save_as"))
     156              : {
     157            7 :   _time = _time.to(_device);
     158              : 
     159              : #define SET_IC_(T) set_ic<T>(_ics, options, "ic_" #T "_names", "ic_" #T "_values", _device)
     160          644 :   FOR_ALL_TENSORBASE(SET_IC_);
     161              : 
     162              : #define GET_FORCE_(T)                                                                              \
     163              :   get_force<T>(_driving_force_names,                                                               \
     164              :                _driving_forces,                                                                    \
     165              :                options,                                                                            \
     166              :                "force_" #T "_names",                                                               \
     167              :                "force_" #T "_values",                                                              \
     168              :                _device)
     169          644 :   FOR_ALL_TENSORBASE(GET_FORCE_);
     170            7 : }
     171              : 
     172              : void
     173            7 : TransientDriver::setup()
     174              : {
     175            7 :   ModelDriver::setup();
     176            7 : }
     177              : 
     178              : void
     179            7 : TransientDriver::diagnose() const
     180              : {
     181            7 :   ModelDriver::diagnose();
     182              : 
     183            7 :   diagnostic_assert(
     184            7 :       _time.batch_dim() >= 1,
     185              :       "Input time should have at least one batch dimension but instead has batch dimension ",
     186            7 :       _time.batch_dim());
     187              : 
     188            7 :   for (std::size_t i = 0; i < _driving_forces.size(); i++)
     189              :   {
     190            0 :     diagnostic_assert(_driving_forces[i].batch_dim() >= 1,
     191              :                       "Input driving force ",
     192            0 :                       _driving_force_names[i],
     193              :                       " should have at least one batch dimension but instead has batch dimension ",
     194            0 :                       _driving_forces[i].batch_dim());
     195            0 :     diagnostic_assert(_driving_forces[i].batch_size(0) == _time.batch_size(0),
     196              :                       "Prescribed driving force ",
     197            0 :                       _driving_force_names[i],
     198              :                       " should have the same number of steps "
     199              :                       "as time, but instead has ",
     200            0 :                       _driving_forces[i].batch_size(0),
     201              :                       " steps");
     202              :   }
     203              : 
     204              :   // Check for statefulness
     205            7 :   const auto & input_old_state = _model->input_axis().subaxis(OLD_STATE);
     206            7 :   const auto & output_state = _model->output_axis().subaxis(STATE);
     207            7 :   if (_model->input_axis().has_old_state())
     208           20 :     for (const auto & var : input_old_state.variable_names())
     209           13 :       diagnostic_assert(output_state.has_variable(var),
     210              :                         "Input axis has old state variable ",
     211              :                         var,
     212              :                         ", but the corresponding output state variable doesn't exist.");
     213            7 : }
     214              : 
     215              : bool
     216            7 : TransientDriver::run()
     217              : {
     218            7 :   auto status = solve();
     219              : 
     220            7 :   if (!save_as_path().empty())
     221            5 :     output();
     222              : 
     223            7 :   return status;
     224              : }
     225              : 
     226              : bool
     227            7 : TransientDriver::solve()
     228              : {
     229           77 :   for (_step_count = 0; _step_count < _nsteps; _step_count++)
     230              :   {
     231           70 :     if (_verbose)
     232              :       // LCOV_EXCL_START
     233              :       std::cout << "Step " << _step_count << std::endl;
     234              :     // LCOV_EXCL_STOP
     235              : 
     236           70 :     if (_step_count > 0)
     237           63 :       advance_step();
     238           70 :     update_forces();
     239           70 :     if (_step_count == 0)
     240              :     {
     241            7 :       store_input();
     242            7 :       apply_ic();
     243              :     }
     244              :     else
     245              :     {
     246           63 :       apply_predictor();
     247           63 :       store_input();
     248           63 :       solve_step();
     249              :     }
     250              : 
     251           70 :     if (_verbose)
     252              :       // LCOV_EXCL_START
     253              :       std::cout << std::endl;
     254              :     // LCOV_EXCL_STOP
     255              :   }
     256              : 
     257            7 :   return true;
     258              : }
     259              : 
     260              : void
     261           63 : TransientDriver::advance_step()
     262              : {
     263              :   // State from the previous time step becomes the old state in the current time step
     264           63 :   if (_model->input_axis().has_old_state())
     265              :   {
     266           63 :     const auto input_old_state = _model->input_axis().subaxis(OLD_STATE);
     267          180 :     for (const auto & var : input_old_state.variable_names())
     268          117 :       _in[var.prepend(OLD_STATE)] = _result_out[_step_count - 1][var.prepend(STATE)];
     269           63 :   }
     270              : 
     271              :   // Forces from the previous time step become the old forces in the current time step
     272           63 :   if (_model->input_axis().has_old_forces())
     273              :   {
     274           63 :     const auto input_old_forces = _model->input_axis().subaxis(OLD_FORCES);
     275          171 :     for (const auto & var : input_old_forces.variable_names())
     276          108 :       _in[var.prepend(OLD_FORCES)] = _result_in[_step_count - 1][var.prepend(FORCES)];
     277           63 :   }
     278           63 : }
     279              : 
     280              : void
     281           70 : TransientDriver::update_forces()
     282              : {
     283           70 :   if (_model->input_axis().has_variable(_time_name))
     284          140 :     _in[_time_name] = _time.batch_index({_step_count});
     285              : 
     286           70 :   for (std::size_t i = 0; i < _driving_force_names.size(); i++)
     287            0 :     _in[_driving_force_names[i]] = _driving_forces[i].batch_index({_step_count});
     288          140 : }
     289              : 
     290              : void
     291            7 : TransientDriver::apply_ic()
     292              : {
     293            7 :   _result_out[0] = _ics;
     294              : 
     295              :   // Figure out what the batch size for our default zero ICs should be
     296            7 :   std::vector<Tensor> defined;
     297           22 :   for (const auto & var : _model->output_axis().variable_names())
     298           15 :     if (_result_out[0].count(var))
     299            1 :       defined.push_back(_result_out[0][var]);
     300           24 :   for (const auto & [key, value] : _in)
     301           17 :     defined.push_back(value);
     302            7 :   const auto batch_shape = utils::broadcast_batch_sizes(defined);
     303              : 
     304              :   // Variables without a user-defined IC are initialized to zeros
     305           22 :   for (auto && [name, var] : _model->output_variables())
     306           15 :     if (!_result_out[0].count(name))
     307              :     {
     308           14 :       if (batch_shape.size() > 0)
     309           14 :         _result_out[0][name] =
     310           28 :             Tensor::zeros(utils::add_shapes(var->list_sizes(), var->base_sizes()))
     311           28 :                 .to(_device)
     312           28 :                 .batch_unsqueeze(0)
     313           42 :                 .batch_expand(batch_shape);
     314              :       else
     315            0 :         _result_out[0][name] =
     316            0 :             Tensor::zeros(utils::add_shapes(var->list_sizes(), var->base_sizes())).to(_device);
     317              :     }
     318            7 : }
     319              : 
     320              : void
     321           63 : TransientDriver::apply_predictor()
     322              : {
     323           63 :   if (!_model->input_axis().has_state())
     324           27 :     return;
     325              : 
     326           36 :   const auto input_state = _model->input_axis().subaxis(STATE);
     327          117 :   for (const auto & var : input_state.variable_names())
     328           81 :     if (_model->output_axis().has_variable(var.prepend(STATE)))
     329              :     {
     330          162 :       if (_predictor == "PREVIOUS_STATE")
     331           72 :         _in[var.prepend(STATE)] = _result_out[_step_count - 1][var.prepend(STATE)];
     332           18 :       else if (_predictor == "LINEAR_EXTRAPOLATION")
     333              :       {
     334              :         // Fall back to PREVIOUS_STATE predictor at the 1st time step
     335            9 :         if (_step_count == 1)
     336            1 :           _in[var.prepend(STATE)] = _result_out[_step_count - 1][var.prepend(STATE)];
     337              :         // Otherwise linearly extrapolate in time
     338              :         else
     339              :         {
     340            8 :           const auto t = Scalar(_in[_time_name]);
     341            8 :           const auto t_n = Scalar(_result_in[_step_count - 1][_time_name]);
     342            8 :           const auto t_nm1 = Scalar(_result_in[_step_count - 2][_time_name]);
     343            8 :           const auto dt = t - t_n;
     344            8 :           const auto dt_n = t_n - t_nm1;
     345              : 
     346            8 :           const auto s_n = _result_out[_step_count - 1][var.prepend(STATE)];
     347            8 :           const auto s_nm1 = _result_out[_step_count - 2][var.prepend(STATE)];
     348            8 :           _in[var.prepend(STATE)] = s_n + (s_n - s_nm1) / dt_n * dt;
     349            8 :         }
     350              :       }
     351              :       else
     352            0 :         throw NEMLException("Unrecognized predictor type: " + std::string(_predictor));
     353              :     }
     354           36 : }
     355              : 
     356              : void
     357           63 : TransientDriver::solve_step()
     358              : {
     359              : #ifdef NEML2_HAS_DISPATCHER
     360           63 :   if (_dispatcher)
     361              :   {
     362            0 :     ValueMapLoader loader(_in, 0);
     363            0 :     _result_out[_step_count] = _dispatcher->run(loader);
     364            0 :     return;
     365            0 :   }
     366              : #endif
     367              : 
     368           63 :   _result_out[_step_count] = _model->value((_in));
     369              : }
     370              : 
     371              : void
     372           70 : TransientDriver::store_input()
     373              : {
     374           70 :   _result_in[_step_count] = _in;
     375           70 : }
     376              : 
     377              : std::string
     378           12 : TransientDriver::save_as_path() const
     379              : {
     380           12 :   return _save_as;
     381              : }
     382              : 
     383              : torch::nn::ModuleDict
     384            5 : TransientDriver::result() const
     385              : {
     386              :   // Dump input variables into a ModuleList
     387            5 :   torch::nn::ModuleList res_in;
     388           55 :   for (const auto & in : _result_in)
     389              :   {
     390              :     // Dump input variables at each step into a ModuleDict
     391           50 :     torch::nn::ModuleDict res_in_step;
     392          414 :     for (auto && [name, val] : in)
     393          364 :       res_in_step->register_buffer(utils::stringify(name), val);
     394           50 :     res_in->push_back(res_in_step);
     395           50 :   }
     396              : 
     397              :   // Dump output variables into a ModuleList
     398            5 :   torch::nn::ModuleList res_out;
     399           55 :   for (const auto & out : _result_out)
     400              :   {
     401              :     // Dump output variables at each step into a ModuleDict
     402           50 :     torch::nn::ModuleDict res_out_step;
     403          180 :     for (auto && [name, val] : out)
     404          130 :       res_out_step->register_buffer(utils::stringify(name), val);
     405           50 :     res_out->push_back(res_out_step);
     406           50 :   }
     407              : 
     408              :   // Combine input and output
     409            5 :   torch::nn::ModuleDict res;
     410           25 :   res->update({{"input", res_in.ptr()}, {"output", res_out.ptr()}});
     411           10 :   return res;
     412           10 : }
     413              : 
     414              : void
     415            5 : TransientDriver::output() const
     416              : {
     417            5 :   if (_verbose)
     418              :     // LCOV_EXCL_START
     419              :     std::cout << "Saving results..." << std::endl;
     420              :   // LCOV_EXCL_STOP
     421              : 
     422            5 :   auto cwd = fs::current_path();
     423            5 :   auto out = cwd / save_as_path();
     424              : 
     425            5 :   if (out.extension() == ".pt")
     426            5 :     output_pt(out);
     427              :   else
     428              :     // LCOV_EXCL_START
     429              :     neml_assert(false, "Unsupported output format: ", out.extension());
     430              :   // LCOV_EXCL_STOP
     431              : 
     432            5 :   if (_verbose)
     433              :     // LCOV_EXCL_START
     434              :     std::cout << "Results saved to " << save_as_path() << std::endl;
     435              :   // LCOV_EXCL_STOP
     436            5 : }
     437              : 
     438              : void
     439            5 : TransientDriver::output_pt(const std::filesystem::path & out) const
     440              : {
     441            5 :   torch::save(result(), out);
     442            5 : }
     443              : } // namespace neml2
        

Generated by: LCOV version 2.0-1