LCOV - code coverage report
Current view: top level - drivers - ModelDriver.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 62.1 % 58 36
Test Date: 2025-10-02 16:03:03 Functions: 50.0 % 8 4

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
#include "neml2/drivers/ModelDriver.h"

#include <mutex>
#include <thread>

#include "neml2/misc/assertions.h"
#include "neml2/models/Model.h"

#ifdef NEML2_HAS_DISPATCHER
#include "neml2/dispatchers/valuemap_helpers.h"
#endif
      32              : 
      33              : namespace neml2
      34              : {
      35              : OptionSet
      36            8 : ModelDriver::expected_options()
      37              : {
      38            8 :   OptionSet options = Driver::expected_options();
      39              : 
      40           16 :   options.set<std::string>("model");
      41           16 :   options.set("model").doc() = "The material model to be updated by the driver";
      42              : 
      43           16 :   options.set<std::string>("device") = "cpu";
      44           24 :   options.set("device").doc() =
      45              :       "Device on which to evaluate the material model. The string supplied must follow the "
      46              :       "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, "
      47              :       "and :<device-index> optionally specifies a device index. For example, device='cpu' sets the "
      48              :       "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
      49            8 :       "CUDA with device ID 1.";
      50              : 
      51           16 :   options.set<bool>("show_parameters") = false;
      52           16 :   options.set("show_parameters").doc() = "Whether to show model parameters at the beginning";
      53           16 :   options.set<bool>("show_input_axis") = false;
      54           16 :   options.set("show_input_axis").doc() = "Whether to show model input axis at the beginning";
      55           16 :   options.set<bool>("show_output_axis") = false;
      56           16 :   options.set("show_output_axis").doc() = "Whether to show model output axis at the beginning";
      57              : 
      58              : #ifdef NEML2_HAS_DISPATCHER
      59           16 :   options.set<std::string>("scheduler");
      60           16 :   options.set("scheduler").doc() = "The work scheduler to use";
      61           16 :   options.set<bool>("async_dispatch") = true;
      62            8 :   options.set("async_dispatch").doc() = "Whether to dispatch work asynchronously";
      63              : #endif
      64              : 
      65            8 :   return options;
      66            0 : }
      67              : 
      68            7 : ModelDriver::ModelDriver(const OptionSet & options)
      69              :   : Driver(options),
      70            7 :     _model(get_model("model")),
      71           21 :     _device(options.get<std::string>("device")),
      72           14 :     _show_params(options.get<bool>("show_parameters")),
      73           14 :     _show_input(options.get<bool>("show_input_axis")),
      74           14 :     _show_output(options.get<bool>("show_output_axis"))
      75              : #ifdef NEML2_HAS_DISPATCHER
      76              :     ,
      77            7 :     _scheduler(options.get("scheduler").user_specified() ? get_scheduler("scheduler") : nullptr),
      78           28 :     _async_dispatch(options.get<bool>("async_dispatch"))
      79              : #endif
      80              : {
      81            7 : }
      82              : 
      83              : void
      84            7 : ModelDriver::setup()
      85              : {
      86            7 :   Driver::setup();
      87            7 :   _model->to(_device);
      88              : 
      89              : #ifdef NEML2_HAS_DISPATCHER
      90            7 :   if (_scheduler)
      91              :   {
      92            0 :     auto red = [](std::vector<ValueMap> && results) -> ValueMap
      93            0 :     { return valuemap_cat_reduce(std::move(results), 0); };
      94              : 
      95            0 :     auto post = [this](ValueMap && x) -> ValueMap
      96            0 :     { return valuemap_move_device(std::move(x), _device); };
      97              : 
      98            0 :     auto thread_init = [this](Device device) -> void
      99              :     {
     100            0 :       auto new_factory = Factory(_model->factory()->input_file());
     101            0 :       auto new_model = new_factory.get_model(_model->name());
     102            0 :       new_model->to(device);
     103            0 :       _models[std::this_thread::get_id()] = std::move(new_model);
     104            0 :     };
     105              : 
     106            0 :     _dispatcher = std::make_unique<DispatcherType>(
     107            0 :         *_scheduler,
     108            0 :         _async_dispatch,
     109            0 :         [&](ValueMap && x, Device device) -> ValueMap
     110              :         {
     111            0 :           auto & model = _async_dispatch ? _models[std::this_thread::get_id()] : _model;
     112              : 
     113              :           // If this is not an async dispatch, we need to move the model to the target device
     114              :           // _every_ time before evaluation
     115            0 :           if (!_async_dispatch)
     116            0 :             model->to(device);
     117              : 
     118            0 :           neml_assert_dbg(model->variable_options().device() == device);
     119            0 :           return model->value(std::move(x));
     120              :         },
     121              :         red,
     122            0 :         &valuemap_move_device,
     123              :         post,
     124            0 :         _async_dispatch ? thread_init : std::function<void(Device)>());
     125              :   }
     126              : #endif
     127              : 
     128              :   // LCOV_EXCL_START
     129              :   if (_show_input)
     130              :     std::cout << _model->name() << "'s input axis:\n" << _model->input_axis() << std::endl;
     131              : 
     132              :   if (_show_output)
     133              :     std::cout << _model->name() << "'s output axis:\n" << _model->output_axis() << std::endl;
     134              : 
     135              :   if (_show_params)
     136              :   {
     137              :     std::cout << _model->name() << "'s parameters:" << std::endl;
     138              :     for (auto && [pname, pval] : _model->named_parameters())
     139              :       std::cout << "  " << pname << std::endl;
     140              :   }
     141              :   // LCOV_EXCL_STOP
     142            7 : }
     143              : 
     144              : void
     145            7 : ModelDriver::diagnose() const
     146              : {
     147            7 :   Driver::diagnose();
     148            7 :   neml2::diagnose(*_model);
     149            7 : }
     150              : } // namespace neml2
        

Generated by: LCOV version 2.0-1