Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/models/R2Multiplication.h"
26 :
27 : namespace neml2
28 : {
// Registers R2Multiplication through the register_NEML2_object macro, which is defined
// outside this file.
// NOTE(review): presumably this is what lets the model be constructed by its type name
// at runtime -- confirm against the macro definition.
29 : register_NEML2_object(R2Multiplication);
30 :
31 : OptionSet
32 2 : R2Multiplication::expected_options()
33 : {
34 2 : OptionSet options = Model::expected_options();
35 2 : options.doc() = "Multiplication of form \\f$ A B \\f$, where \\f$ A \\f$ and \\f$ B \\f$ are "
36 2 : "second order tensors. A and B can be inverted and/or transposed per request.";
37 :
38 4 : options.set<VariableName>("A");
39 4 : options.set("A").doc() = "Variable A";
40 :
41 4 : options.set<bool>("invert_A") = false;
42 4 : options.set("invert_A").doc() = "Whether to invert A";
43 :
44 4 : options.set<bool>("transpose_A") = false;
45 4 : options.set("transpose_A").doc() = "Whether to transpose A";
46 :
47 4 : options.set<VariableName>("B");
48 4 : options.set("B").doc() = "Variable B";
49 :
50 4 : options.set<bool>("invert_B") = false;
51 4 : options.set("invert_B").doc() = "Whether to invert B";
52 :
53 4 : options.set<bool>("transpose_B") = false;
54 4 : options.set("transpose_B").doc() = "Whether to transpose B";
55 :
56 4 : options.set_output("to");
57 2 : options.set("to").doc() = "The result of the multiplication";
58 :
59 2 : return options;
60 0 : }
61 :
// Constructs the model from its options.
//
// Forwards `options` to the Model base, declares the R2 output variable "to" and the R2
// input variables "A" and "B", and caches the four boolean flags read from `options`
// ("invert_A", "invert_B", "transpose_A", "transpose_B").
//
// NOTE(review): the initializer order (_to, _A, _B, _invA, _invB, _transA, _transB)
// presumably mirrors the member declaration order in the header -- confirm before
// reordering anything here.
62 4 : R2Multiplication::R2Multiplication(const OptionSet & options)
63 : : Model(options),
64 4 : _to(declare_output_variable<R2>("to")),
65 4 : _A(declare_input_variable<R2>("A")),
66 4 : _B(declare_input_variable<R2>("B")),
67 8 : _invA(options.get<bool>("invert_A")),
68 8 : _invB(options.get<bool>("invert_B")),
69 8 : _transA(options.get<bool>("transpose_A")),
70 8 : _transB(options.get<bool>("transpose_B"))
71 : {
72 4 : }
73 :
74 : void
75 8 : R2Multiplication::set_value(bool out, bool dout_din, bool /*d2out_din2*/)
76 : {
77 8 : auto A = _invA ? R2(_A).inverse() : _A;
78 8 : if (_transA)
79 0 : A = A.transpose();
80 :
81 8 : auto B = _invB ? R2(_B).inverse() : _B;
82 8 : if (_transB)
83 0 : B = B.transpose();
84 :
85 8 : const auto AB = A * B;
86 :
87 8 : if (out)
88 4 : _to = AB;
89 :
90 8 : if (dout_din)
91 : {
92 4 : const auto I = R2::identity(_A.options());
93 :
94 4 : if (_invA)
95 : {
96 2 : if (_transA)
97 0 : _to.d(_A) = -A.base_unsqueeze(-2).base_unsqueeze(-3) *
98 0 : AB.transpose().base_unsqueeze(-3).base_unsqueeze(-1);
99 : else
100 4 : _to.d(_A) = -A.base_unsqueeze(-2).base_unsqueeze(-1) *
101 6 : AB.transpose().base_unsqueeze(-3).base_unsqueeze(-2);
102 : }
103 : else
104 : {
105 2 : if (_transA)
106 0 : _to.d(_A) = I.base_unsqueeze(-2).base_unsqueeze(-3) *
107 0 : B.transpose().base_unsqueeze(-3).base_unsqueeze(-1);
108 : else
109 4 : _to.d(_A) = I.base_unsqueeze(-2).base_unsqueeze(-1) *
110 6 : B.transpose().base_unsqueeze(-3).base_unsqueeze(-2);
111 : }
112 :
113 4 : if (_invB)
114 : {
115 2 : if (_transB)
116 0 : _to.d(_B) = -AB.base_unsqueeze(-2).base_unsqueeze(-3) *
117 0 : B.transpose().base_unsqueeze(-3).base_unsqueeze(-1);
118 : else
119 4 : _to.d(_B) = -AB.base_unsqueeze(-2).base_unsqueeze(-1) *
120 6 : B.transpose().base_unsqueeze(-3).base_unsqueeze(-2);
121 : }
122 : else
123 : {
124 2 : if (_transB)
125 0 : _to.d(_B) = A.base_unsqueeze(-2).base_unsqueeze(-3) *
126 0 : I.transpose().base_unsqueeze(-3).base_unsqueeze(-1);
127 : else
128 4 : _to.d(_B) = A.base_unsqueeze(-2).base_unsqueeze(-1) *
129 6 : I.transpose().base_unsqueeze(-3).base_unsqueeze(-2);
130 : }
131 4 : }
132 8 : }
133 : } // namespace neml2
|