LCOV - code coverage report
Current view: top level - drivers - ModelDriver.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 51.9 % 79 41
Test Date: 2025-06-29 01:25:44 Functions: 40.0 % 10 4

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/drivers/ModelDriver.h"
      26              : #include "neml2/misc/assertions.h"
      27              : #include "neml2/models/Model.h"
      28              : 
      29              : #ifdef NEML2_HAS_DISPATCHER
      30              : #include "neml2/dispatchers/valuemap_helpers.h"
      31              : #endif
      32              : 
      33              : namespace neml2
      34              : {
      35              : void
      36            0 : details_callback(const Model & model,
      37              :                  const std::map<VariableName, std::unique_ptr<VariableBase>> & input,
      38              :                  const std::map<VariableName, std::unique_ptr<VariableBase>> & output)
      39              : {
      40            0 :   std::cout << model.name() << std::endl;
      41            0 :   std::cout << "\tInput" << std::endl;
      42            0 :   for (const auto & pair : input)
      43              :   {
      44            0 :     std::cout << "\t\t" << pair.first.str() << " (" << pair.second->sizes() << ") -> "
      45            0 :               << at::norm(pair.second->tensor()).cpu().item<double>() << std::endl;
      46              :   }
      47            0 :   std::cout << "\tOutput" << std::endl;
      48            0 :   for (const auto & pair : output)
      49              :   {
      50            0 :     std::cout << "\t\t" << pair.first.str() << " (" << pair.second->sizes() << ") -> "
      51            0 :               << at::norm(pair.second->tensor()).cpu().item<double>() << std::endl;
      52              :   }
      53            0 : }
      54              : 
      55              : OptionSet
      56            8 : ModelDriver::expected_options()
      57              : {
      58            8 :   OptionSet options = Driver::expected_options();
      59              : 
      60           16 :   options.set<std::string>("model");
      61           16 :   options.set("model").doc() = "The material model to be updated by the driver";
      62              : 
      63           16 :   options.set<std::string>("device") = "cpu";
      64           24 :   options.set("device").doc() =
      65              :       "Device on which to evaluate the material model. The string supplied must follow the "
      66              :       "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, "
      67              :       "and :<device-index> optionally specifies a device index. For example, device='cpu' sets the "
      68              :       "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
      69            8 :       "CUDA with device ID 1.";
      70              : 
      71           16 :   options.set<bool>("show_parameters") = false;
      72           16 :   options.set("show_parameters").doc() = "Whether to show model parameters at the beginning";
      73           16 :   options.set<bool>("show_input_axis") = false;
      74           16 :   options.set("show_input_axis").doc() = "Whether to show model input axis at the beginning";
      75           16 :   options.set<bool>("show_output_axis") = false;
      76           16 :   options.set("show_output_axis").doc() = "Whether to show model output axis at the beginning";
      77              : 
      78           16 :   options.set<bool>("log_details") = false;
      79           24 :   options.set("log_details").doc() =
      80            8 :       "If true attach a callback which outputs lots of information on the model execution";
      81              : 
      82              : #ifdef NEML2_HAS_DISPATCHER
      83           16 :   options.set<std::string>("scheduler");
      84           16 :   options.set("scheduler").doc() = "The work scheduler to use";
      85           16 :   options.set<bool>("async_dispatch") = true;
      86            8 :   options.set("async_dispatch").doc() = "Whether to dispatch work asynchronously";
      87              : #endif
      88              : 
      89            8 :   return options;
      90            0 : }
      91              : 
      92            7 : ModelDriver::ModelDriver(const OptionSet & options)
      93              :   : Driver(options),
      94            7 :     _model(get_model("model")),
      95           21 :     _device(options.get<std::string>("device")),
      96           14 :     _show_params(options.get<bool>("show_parameters")),
      97           14 :     _show_input(options.get<bool>("show_input_axis")),
      98           14 :     _show_output(options.get<bool>("show_output_axis")),
      99           14 :     _log_details(options.get<bool>("log_details"))
     100              : #ifdef NEML2_HAS_DISPATCHER
     101              :     ,
     102            7 :     _scheduler(options.get("scheduler").user_specified() ? get_scheduler("scheduler") : nullptr),
     103           28 :     _async_dispatch(options.get<bool>("async_dispatch"))
     104              : #endif
     105              : {
     106            7 : }
     107              : 
     108              : void
     109            7 : ModelDriver::setup()
     110              : {
     111            7 :   Driver::setup();
     112              : 
     113              :   // Send model parameters and buffers to device
     114            7 :   _model->to(_device);
     115              : 
     116            7 :   if (_log_details)
     117            0 :     _model->register_callback_recursive(details_callback);
     118              : 
     119              : #ifdef NEML2_HAS_DISPATCHER
     120            7 :   if (_scheduler)
     121              :   {
     122            0 :     auto red = [](std::vector<ValueMap> && results) -> ValueMap
     123            0 :     { return valuemap_cat_reduce(std::move(results), 0); };
     124              : 
     125            0 :     auto post = [this](ValueMap && x) -> ValueMap
     126            0 :     { return valuemap_move_device(std::move(x), _device); };
     127              : 
     128            0 :     auto thread_init = [this](Device device) -> void
     129              :     {
     130            0 :       auto new_factory = Factory(_model->factory()->input_file());
     131            0 :       auto new_model = new_factory.get_model(_model->name());
     132            0 :       new_model->to(device);
     133            0 :       _models[std::this_thread::get_id()] = std::move(new_model);
     134            0 :     };
     135              : 
     136            0 :     _dispatcher = std::make_unique<DispatcherType>(
     137            0 :         *_scheduler,
     138            0 :         _async_dispatch,
     139            0 :         [&](ValueMap && x, Device device) -> ValueMap
     140              :         {
     141            0 :           auto & model = _async_dispatch ? _models[std::this_thread::get_id()] : _model;
     142              : 
     143              :           // If this is not an async dispatch, we need to move the model to the target device
     144              :           // _every_ time before evaluation
     145            0 :           if (!_async_dispatch)
     146            0 :             model->to(device);
     147              : 
     148            0 :           neml_assert_dbg(model->variable_options().device() == device);
     149            0 :           return model->value(std::move(x));
     150              :         },
     151              :         red,
     152            0 :         &valuemap_move_device,
     153              :         post,
     154            0 :         _async_dispatch ? thread_init : std::function<void(Device)>());
     155              :   }
     156              : #endif
     157              : 
     158              :   // LCOV_EXCL_START
     159              :   if (_show_input)
     160              :     std::cout << _model->name() << "'s input axis:\n" << _model->input_axis() << std::endl;
     161              : 
     162              :   if (_show_output)
     163              :     std::cout << _model->name() << "'s output axis:\n" << _model->output_axis() << std::endl;
     164              : 
     165              :   if (_show_params)
     166              :   {
     167              :     std::cout << _model->name() << "'s parameters:" << std::endl;
     168              :     for (auto && [pname, pval] : _model->named_parameters())
     169              :       std::cout << "  " << pname << std::endl;
     170              :   }
     171              :   // LCOV_EXCL_STOP
     172            7 : }
     173              : 
     174              : void
     175            7 : ModelDriver::diagnose() const
     176              : {
     177            7 :   Driver::diagnose();
     178            7 :   neml2::diagnose(*_model);
     179            7 : }
     180              : 
     181              : void
     182            0 : ModelDriver::to(Device dev)
     183              : {
     184            0 :   _device = dev;
     185            0 :   setup();
     186            0 : }
     187              : 
     188              : } // namespace neml2
        

Generated by: LCOV version 2.0-1