Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include <torch/autograd.h>
26 :
27 : #include "neml2/tensors/functions/jacrev.h"
28 : #include "neml2/tensors/assertions.h"
29 : #include "neml2/tensors/Tensor.h"
30 : #include "neml2/tensors/Scalar.h"
31 :
32 : namespace neml2
33 : {
34 : std::vector<Tensor>
35 150 : jacrev(const Tensor & y,
36 : const std::vector<Tensor> & xs,
37 : bool retain_graph,
38 : bool create_graph,
39 : bool allow_unused)
40 : {
41 150 : std::vector<Tensor> dy_dxs(xs.size());
42 :
43 : // Return undefined Tensor if y does not contain any gradient graph
44 150 : if (!y.requires_grad())
45 17 : return dy_dxs;
46 :
47 : // Check batch shapes
48 276 : for (std::size_t i = 0; i < xs.size(); i++)
49 143 : neml_assert_dbg(y.batch_sizes() == xs[i].batch_sizes(),
50 : "In jacrev, the output variable batch shape ",
51 : y.batch_sizes(),
52 : " is different from the batch shape of x[",
53 : i,
54 : "] ",
55 143 : xs[i].batch_sizes(),
56 : ".");
57 :
58 133 : const auto opt = y.options().requires_grad(false);
59 :
60 : // Flatten y to handle arbitrarily shaped output
61 133 : const auto yf = y.base_flatten();
62 133 : const auto G = Scalar::full(1.0, opt).batch_expand(yf.batch_sizes());
63 :
64 : // Initialize derivatives to zero
65 266 : std::vector<ATensor> xts(xs.begin(), xs.end());
66 133 : std::vector<Tensor> dyf_dxs(xs.size());
67 276 : for (std::size_t i = 0; i < xs.size(); i++)
68 286 : dyf_dxs[i] = Tensor::zeros(
69 429 : yf.batch_sizes(), utils::add_shapes(yf.base_size(0), xs[i].base_sizes()), opt);
70 :
71 : // Use autograd to calculate the derivatives
72 402 : for (Size i = 0; i < yf.base_size(0); i++)
73 : {
74 1345 : const auto dyfi_dxs = torch::autograd::grad({yf.base_index({i})},
75 : {xts},
76 : {G},
77 : /*retain_graph=*/retain_graph,
78 : /*create_graph=*/create_graph,
79 2152 : /*allow_unused=*/allow_unused);
80 269 : neml_assert_dbg(dyfi_dxs.size() == xs.size(),
81 : "In jacrev, the number of derivatives is ",
82 269 : dyfi_dxs.size(),
83 : " but the number of input tensors is ",
84 269 : xs.size(),
85 : ".");
86 563 : for (std::size_t j = 0; j < xs.size(); j++)
87 294 : if (dyfi_dxs[j].defined())
88 873 : dyf_dxs[j].base_index_put_({Size(i)}, dyfi_dxs[j]);
89 269 : }
90 :
91 : // Reshape the derivative back to the correct shape
92 276 : for (std::size_t i = 0; i < xs.size(); i++)
93 143 : dy_dxs[i] = dyf_dxs[i].base_reshape(utils::add_shapes(y.base_sizes(), xs[i].base_sizes()));
94 :
95 133 : return dy_dxs;
96 1231 : }
97 :
98 : Tensor
99 146 : jacrev(const Tensor & y, const Tensor & x, bool retain_graph, bool create_graph, bool allow_unused)
100 : {
101 438 : return jacrev(y, std::vector<Tensor>{x}, retain_graph, create_graph, allow_unused)[0];
102 146 : }
103 : } // namespace neml2
|