Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/models/BilinearInterpolation.h"
26 : #include "neml2/tensors/Tensor.h"
27 : #include "neml2/tensors/functions/diff.h"
28 : #include "neml2/tensors/Scalar.h"
29 : #include "neml2/tensors/Vec.h"
30 : #include "neml2/tensors/SR2.h"
31 : #include "neml2/tensors/indexing.h"
32 :
33 : namespace neml2
34 : {
// Declare the options accepted by BilinearInterpolation<T>.
//
// Starts from Interpolation<T>::expected_options() and adds:
//   - "abscissa1" / "abscissa2": names of the Scalar tensors holding the grid abscissae along the
//     first and second interpolation axis;
//   - "argument1" / "argument2": the input variables at which the interpolant is queried.
// "define_second_derivatives" is set to true because set_value() is able to provide the mixed
// second derivative of the interpolant.
template <typename T>
OptionSet
BilinearInterpolation<T>::expected_options()
{
  OptionSet options = Interpolation<T>::expected_options();
  options.doc() += " This object performs a _bilinear interpolation_.";

  options.set<bool>("define_second_derivatives") = true;

  // Grid abscissae along each interpolation axis
  options.set<TensorName<Scalar>>("abscissa1");
  options.set("abscissa1").doc() =
      "Scalar defining the abscissa values of the first interpolation axis";

  options.set<TensorName<Scalar>>("abscissa2");
  options.set("abscissa2").doc() =
      "Scalar defining the abscissa values of the second interpolation axis";

  // Query points along each interpolation axis
  options.set_input("argument1");
  options.set("argument1").doc() =
      "First argument used to query the interpolant along the first axis";

  options.set_input("argument2");
  options.set("argument2").doc() =
      "Second argument used to query the interpolant along the second axis";

  return options;
}
62 :
// Construct the interpolant from its options.
//
// Declares the grid abscissae as Scalar parameters "X1"/"X2" (taken from the "abscissa1" /
// "abscissa2" options) and the query points as Scalar input variables (from the "argument1" /
// "argument2" options). The ordinate grid _Y read by set_value() is not declared here;
// presumably the Interpolation<T> base sets it up -- TODO confirm.
template <typename T>
BilinearInterpolation<T>::BilinearInterpolation(const OptionSet & options)
  : Interpolation<T>(options),
    _X1(this->template declare_parameter<Scalar>("X1", "abscissa1")),
    _X2(this->template declare_parameter<Scalar>("X2", "abscissa2")),
    _x1(this->template declare_input_variable<Scalar>("argument1")),
    _x2(this->template declare_input_variable<Scalar>("argument2"))
{
}
72 :
73 : static std::tuple<Scalar, Scalar, Scalar>
74 24 : parametric_coordinates(const Scalar & X, const Scalar & x)
75 : {
76 : using namespace indexing;
77 : const auto m =
78 120 : at::logical_and(at::gt(x.batch_unsqueeze(-1), X.index({Ellipsis, Slice(None, -1)})),
79 144 : at::le(x.batch_unsqueeze(-1), X.index({Ellipsis, Slice(1)})));
80 168 : const auto X_start = X.index({Ellipsis, Slice(None, -1)}).expand_as(m).index({m});
81 168 : const auto X_end = X.index({Ellipsis, Slice(1)}).expand_as(m).index({m});
82 24 : const auto xi = (x - X_start) / (X_end - X_start);
83 24 : const auto dxi = 1.0 / (X_end - X_start);
84 48 : return {Scalar(m), Scalar(xi), Scalar(dxi)};
85 312 : }
86 :
87 : template <typename T>
88 : static T
89 48 : apply_mask(const T & y, const Scalar & m)
90 : {
91 192 : const auto B = utils::broadcast_batch_sizes({m, y});
92 48 : const auto D = B.slice(0, -2); // excluding the interpolation grid
93 192 : return T(y.batch_expand(B).index({m.batch_expand(B)})).batch_reshape(D);
94 144 : }
95 :
96 : template <typename T>
97 : void
98 12 : BilinearInterpolation<T>::set_value(bool out, bool dout_din, bool d2out_din2)
99 : {
100 : using namespace indexing;
101 :
102 : // Get masks for the interpolating cell on the 2D grid
103 : // Also transform x onto the parametric space [0, 1] x [0, 1]
104 12 : const auto [m1, xi, dxi_dx1] = parametric_coordinates(this->_X1, Scalar(this->_x1));
105 12 : const auto [m2, eta, deta_dx2] = parametric_coordinates(this->_X2, Scalar(this->_x2));
106 12 : auto m = Scalar(at::logical_and(m1.unsqueeze(-1), m2.unsqueeze(-2)));
107 :
108 : // Get the four corner values of the interpolating cell
109 : //
110 : // Y01 ------ Y11
111 : // | |
112 : // | |
113 : // | |
114 : // Y00 ------ Y10
115 72 : auto Y00 = this->_Y.batch_index({Ellipsis, Slice(None, -1), Slice(None, -1)});
116 72 : auto Y01 = this->_Y.batch_index({Ellipsis, Slice(None, -1), Slice(1)});
117 72 : auto Y10 = this->_Y.batch_index({Ellipsis, Slice(1), Slice(None, -1)});
118 72 : auto Y11 = this->_Y.batch_index({Ellipsis, Slice(1), Slice(1)});
119 12 : Y00 = apply_mask(Y00, m);
120 12 : Y01 = apply_mask(Y01, m);
121 12 : Y10 = apply_mask(Y10, m);
122 12 : Y11 = apply_mask(Y11, m);
123 :
124 : // The interpolation formula is:
125 : // p = Y00 + c1 * xi + c2 * eta + c3 * xi * eta
126 : // where c1 = (Y10 - Y00)
127 : // c2 = (Y01 - Y00)
128 : // c3 = (Y11 - Y10 - Y01 + Y00)
129 12 : const auto c1 = Y10 - Y00;
130 12 : const auto c2 = Y01 - Y00;
131 12 : const auto c3 = Y11 - Y10 - Y01 + Y00;
132 :
133 12 : if (out)
134 4 : this->_p = Y00 + c1 * xi + c2 * eta + c3 * xi * eta;
135 :
136 12 : if (dout_din)
137 : {
138 4 : if (this->_x1.is_dependent())
139 4 : this->_p.d(this->_x1) = (c1 + c3 * eta) * dxi_dx1;
140 4 : if (this->_x2.is_dependent())
141 4 : this->_p.d(this->_x2) = (c2 + c3 * xi) * deta_dx2;
142 : }
143 :
144 12 : if (d2out_din2)
145 4 : if (this->_x1.is_dependent() && this->_x2.is_dependent())
146 : {
147 4 : this->_p.d(this->_x1, this->_x2) = c3 * dxi_dx1 * deta_dx2;
148 4 : this->_p.d(this->_x2, this->_x1) = c3 * dxi_dx1 * deta_dx2;
149 : }
150 156 : }
151 :
// Explicitly instantiate BilinearInterpolation<T> for the listed ordinate types and register
// each instantiation via register_NEML2_object under the alias <T>BilinearInterpolation
// (ScalarBilinearInterpolation, VecBilinearInterpolation, SR2BilinearInterpolation).
#define REGISTER(T)                                                                                \
  using T##BilinearInterpolation = BilinearInterpolation<T>;                                       \
  register_NEML2_object(T##BilinearInterpolation);                                                 \
  template class BilinearInterpolation<T>
REGISTER(Scalar);
REGISTER(Vec);
REGISTER(SR2);
159 : } // namespace neml2
|