LCOV - code coverage report
Current view: top level - user_tensors - FromTorchScript.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 31.0 % 29 9
Test Date: 2025-10-02 16:03:03 Functions: 33.3 % 3 1

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/user_tensors/FromTorchScript.h"
      26              : #include "neml2/tensors/assertions.h"
      27              : 
      28              : #include <torch/script.h>
      29              : #include <torch/serialize.h>
      30              : #include <filesystem>
      31              : 
      32              : namespace fs = std::filesystem;
      33              : 
      34              : namespace neml2
      35              : {
      36              : OptionSet
      37           46 : FromTorchScript::expected_options()
      38              : {
      39           46 :   OptionSet options = UserTensorBase::expected_options();
      40           46 :   options.doc() = "Get the tensor from torch script. The torch scrip should have the "
      41              :                   "named_buffers and the associated tensor. Refer to "
      42           46 :                   "tests/regression/liquid_infiltration/gold/generate_load_file.py for an example";
      43              : 
      44           92 :   options.set<std::string>("pytorch_pt_file");
      45           92 :   options.set("pytorch_pt_file").doc() = "Name of the torch script file.";
      46              : 
      47           92 :   options.set<std::string>("tensor_name");
      48           46 :   options.set("tensor_name").doc() = "Associated named_buffers to extract the tensor from.";
      49           46 :   return options;
      50            0 : }
      51              : 
      52            0 : FromTorchScript::FromTorchScript(const OptionSet & options)
      53            0 :   : UserTensorBase(options)
      54              : {
      55            0 : }
      56              : 
      57              : torch::Tensor
      58            0 : FromTorchScript::load_torch_tensor(const OptionSet & options) const
      59              : {
      60            0 :   const auto fname = fs::path(options.get<std::string>("pytorch_pt_file"));
      61            0 :   const auto tensor_name = options.get<std::string>("tensor_name");
      62            0 :   const auto data = torch::jit::load(fname);
      63              : 
      64            0 :   torch::Tensor t;
      65            0 :   bool found = false;
      66            0 :   for (auto item : data.named_buffers())
      67              :   {
      68            0 :     if (item.name == tensor_name)
      69              :     {
      70            0 :       t = item.value;
      71            0 :       found = true;
      72            0 :       break;
      73              :     }
      74            0 :   }
      75              : 
      76            0 :   neml_assert(found, "No buffer named '", tensor_name, "' in file ", fname);
      77            0 :   t = t.to(torch::kFloat64);
      78            0 :   return t;
      79            0 : }
      80              : } // namespace neml2
        

Generated by: LCOV version 2.0-1