Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/user_tensors/FromTorchScript.h"
26 : #include "neml2/tensors/assertions.h"
27 :
28 : #include <torch/script.h>
29 : #include <torch/serialize.h>
30 : #include <filesystem>
31 :
32 : namespace fs = std::filesystem;
33 :
34 : namespace neml2
35 : {
36 : OptionSet
37 46 : FromTorchScript::expected_options()
38 : {
39 46 : OptionSet options = UserTensorBase::expected_options();
40 46 : options.doc() = "Get the tensor from torch script. The torch scrip should have the "
41 : "named_buffers and the associated tensor. Refer to "
42 46 : "tests/regression/liquid_infiltration/gold/generate_load_file.py for an example";
43 :
44 92 : options.set<std::string>("pytorch_pt_file");
45 92 : options.set("pytorch_pt_file").doc() = "Name of the torch script file.";
46 :
47 92 : options.set<std::string>("tensor_name");
48 46 : options.set("tensor_name").doc() = "Associated named_buffers to extract the tensor from.";
49 46 : return options;
50 0 : }
51 :
52 0 : FromTorchScript::FromTorchScript(const OptionSet & options)
53 0 : : UserTensorBase(options)
54 : {
55 0 : }
56 :
57 : torch::Tensor
58 0 : FromTorchScript::load_torch_tensor(const OptionSet & options) const
59 : {
60 0 : const auto fname = fs::path(options.get<std::string>("pytorch_pt_file"));
61 0 : const auto tensor_name = options.get<std::string>("tensor_name");
62 0 : const auto data = torch::jit::load(fname);
63 :
64 0 : torch::Tensor t;
65 0 : bool found = false;
66 0 : for (auto item : data.named_buffers())
67 : {
68 0 : if (item.name == tensor_name)
69 : {
70 0 : t = item.value;
71 0 : found = true;
72 0 : break;
73 : }
74 0 : }
75 :
76 0 : neml_assert(found, "No buffer named '", tensor_name, "' in file ", fname);
77 0 : t = t.to(torch::kFloat64);
78 0 : return t;
79 0 : }
80 : } // namespace neml2
|