Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/tensors/R2Base.h"
26 : #include "neml2/tensors/R2.h"
27 : #include "neml2/tensors/Scalar.h"
28 : #include "neml2/tensors/Vec.h"
29 : #include "neml2/tensors/SR2.h"
30 : #include "neml2/tensors/R3.h"
31 : #include "neml2/tensors/R4.h"
32 : #include "neml2/tensors/Rot.h"
33 : #include "neml2/tensors/WR2.h"
34 : #include "neml2/tensors/assertions.h"
35 : #include "neml2/tensors/functions/stack.h"
36 : #include "neml2/tensors/functions/sum.h"
37 :
38 : namespace neml2
39 : {
40 : template <class Derived>
41 : Derived
42 26 : R2Base<Derived>::fill(const CScalar & a, const TensorOptions & options)
43 : {
44 26 : return R2Base<Derived>::fill(Scalar(a, options));
45 : }
46 :
47 : template <class Derived>
48 : Derived
49 49 : R2Base<Derived>::fill(const Scalar & a)
50 : {
51 49 : auto zero = Scalar::zeros_like(a);
52 294 : return Derived(base_stack({base_stack({a, zero, zero}, -1),
53 196 : base_stack({zero, a, zero}, -1),
54 196 : base_stack({zero, zero, a}, -1)},
55 294 : -2));
56 735 : }
57 :
58 : template <class Derived>
59 : Derived
60 2 : R2Base<Derived>::fill(const CScalar & a11,
61 : const CScalar & a22,
62 : const CScalar & a33,
63 : const TensorOptions & options)
64 : {
65 2 : return R2Base<Derived>::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
66 : }
67 :
68 : template <class Derived>
69 : Derived
70 5 : R2Base<Derived>::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
71 : {
72 5 : auto zero = Scalar::zeros_like(a11);
73 30 : return Derived(base_stack({base_stack({a11, zero, zero}, -1),
74 20 : base_stack({zero, a22, zero}, -1),
75 20 : base_stack({zero, zero, a33}, -1)},
76 30 : -2));
77 75 : }
78 :
79 : template <class Derived>
80 : Derived
81 2 : R2Base<Derived>::fill(const CScalar & a11,
82 : const CScalar & a22,
83 : const CScalar & a33,
84 : const CScalar & a23,
85 : const CScalar & a13,
86 : const CScalar & a12,
87 : const TensorOptions & options)
88 : {
89 4 : return R2Base<Derived>::fill(Scalar(a11, options),
90 4 : Scalar(a22, options),
91 4 : Scalar(a33, options),
92 4 : Scalar(a23, options),
93 4 : Scalar(a13, options),
94 6 : Scalar(a12, options));
95 : }
96 :
97 : template <class Derived>
98 : Derived
99 5 : R2Base<Derived>::fill(const Scalar & a11,
100 : const Scalar & a22,
101 : const Scalar & a33,
102 : const Scalar & a23,
103 : const Scalar & a13,
104 : const Scalar & a12)
105 : {
106 30 : return Derived(base_stack({base_stack({a11, a12, a13}, -1),
107 20 : base_stack({a12, a22, a23}, -1),
108 20 : base_stack({a13, a23, a33}, -1)},
109 30 : -2));
110 75 : }
111 :
112 : template <class Derived>
113 : Derived
114 34 : R2Base<Derived>::fill(const CScalar & a11,
115 : const CScalar & a12,
116 : const CScalar & a13,
117 : const CScalar & a21,
118 : const CScalar & a22,
119 : const CScalar & a23,
120 : const CScalar & a31,
121 : const CScalar & a32,
122 : const CScalar & a33,
123 : const TensorOptions & options)
124 : {
125 68 : return R2Base<Derived>::fill(Scalar(a11, options),
126 68 : Scalar(a12, options),
127 68 : Scalar(a13, options),
128 68 : Scalar(a21, options),
129 68 : Scalar(a22, options),
130 68 : Scalar(a23, options),
131 68 : Scalar(a31, options),
132 68 : Scalar(a32, options),
133 102 : Scalar(a33, options));
134 : }
135 :
136 : template <class Derived>
137 : Derived
138 140 : R2Base<Derived>::fill(const Scalar & a11,
139 : const Scalar & a12,
140 : const Scalar & a13,
141 : const Scalar & a21,
142 : const Scalar & a22,
143 : const Scalar & a23,
144 : const Scalar & a31,
145 : const Scalar & a32,
146 : const Scalar & a33)
147 : {
148 840 : return Derived(base_stack({base_stack({a11, a12, a13}, -1),
149 560 : base_stack({a21, a22, a23}, -1),
150 560 : base_stack({a31, a32, a33}, -1)},
151 840 : -2));
152 2100 : }
153 :
154 : template <class Derived>
155 : Derived
156 305 : R2Base<Derived>::skew(const Vec & v)
157 : {
158 305 : const auto z = Scalar::zeros_like(v(0));
159 1830 : return Derived(base_stack({base_stack({z, -v(2), v(1)}, -1),
160 1220 : base_stack({v(2), z, -v(0)}, -1),
161 1220 : base_stack({-v(1), v(0), z}, -1)},
162 1830 : -2));
163 4575 : }
164 :
165 : template <class Derived>
166 : Derived
167 381 : R2Base<Derived>::identity(const TensorOptions & options)
168 : {
169 381 : return Derived(at::eye(3, options), 0);
170 : }
171 :
172 : template <class Derived>
173 : Derived
174 18 : R2Base<Derived>::rotate(const Rot & r) const
175 : {
176 18 : return rotate(r.euler_rodrigues());
177 : }
178 :
179 : template <class Derived>
180 : Derived
181 43 : R2Base<Derived>::rotate(const R2 & R) const
182 : {
183 43 : return R * R2(*this) * R.transpose();
184 : }
185 :
186 : template <class Derived>
187 : R3
188 11 : R2Base<Derived>::drotate(const Rot & r) const
189 : {
190 11 : auto R = r.euler_rodrigues();
191 11 : auto F = r.deuler_rodrigues();
192 :
193 44 : return R3(at::einsum("...itl,...tm,...jm", {F, *this, R}) +
194 110 : at::einsum("...ik,...kt,...jtl", {R, *this, F}));
195 33 : }
196 :
197 : template <class Derived>
198 : R4
199 9 : R2Base<Derived>::drotate(const R2 & R) const
200 : {
201 9 : auto I = R2::identity(R.options());
202 27 : return R4(at::einsum("...ik,...jl", {I, R * this->transpose()}) +
203 90 : at::einsum("...jk,...il", {I, R * R2(*this)}));
204 45 : }
205 :
206 : template <class Derived>
207 : Scalar
208 81 : R2Base<Derived>::operator()(Size i, Size j) const
209 : {
210 243 : return PrimitiveTensor<Derived, 3, 3>::base_index({i, j});
211 81 : }
212 :
213 : template <class Derived>
214 : Scalar
215 4 : R2Base<Derived>::det() const
216 : {
217 4 : const auto comps = at::split(this->base_flatten(), 1, -1);
218 4 : const auto & a = comps[0];
219 4 : const auto & b = comps[1];
220 4 : const auto & c = comps[2];
221 4 : const auto & d = comps[3];
222 4 : const auto & e = comps[4];
223 4 : const auto & f = comps[5];
224 4 : const auto & g = comps[6];
225 4 : const auto & h = comps[7];
226 4 : const auto & i = comps[8];
227 4 : const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
228 8 : return Scalar(det.reshape(this->batch_sizes().concrete()), this->batch_sizes());
229 4 : }
230 :
231 : template <class Derived>
232 : Scalar
233 2 : R2Base<Derived>::inner(const R2 & other) const
234 : {
235 2 : return base_sum(this->base_flatten() * other.base_flatten());
236 : }
237 :
238 : template <class Derived>
239 : Derived
240 97 : R2Base<Derived>::inverse() const
241 : {
242 97 : const auto comps = at::split(this->base_flatten(), 1, -1);
243 97 : const auto & a = comps[0];
244 97 : const auto & b = comps[1];
245 97 : const auto & c = comps[2];
246 97 : const auto & d = comps[3];
247 97 : const auto & e = comps[4];
248 97 : const auto & f = comps[5];
249 97 : const auto & g = comps[6];
250 97 : const auto & h = comps[7];
251 97 : const auto & i = comps[8];
252 97 : const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
253 97 : const auto cof00 = e * i - h * f;
254 97 : const auto cof01 = -(d * i - g * f);
255 97 : const auto cof02 = d * h - g * e;
256 97 : const auto cof10 = -(b * i - h * c);
257 97 : const auto cof11 = a * i - g * c;
258 97 : const auto cof12 = -(a * h - g * b);
259 97 : const auto cof20 = b * f - e * c;
260 97 : const auto cof21 = -(a * f - d * c);
261 97 : const auto cof22 = a * e - d * b;
262 388 : const auto coft0 = at::cat({cof00, cof10, cof20}, -1);
263 388 : const auto coft1 = at::cat({cof01, cof11, cof21}, -1);
264 388 : const auto coft2 = at::cat({cof02, cof12, cof22}, -1);
265 388 : const auto coft = at::stack({coft0, coft1, coft2}, -2);
266 97 : const auto inv = coft / det.unsqueeze(-1);
267 194 : return Derived(inv, this->batch_sizes());
268 485 : }
269 :
270 : template <class Derived>
271 : Derived
272 198 : R2Base<Derived>::transpose() const
273 : {
274 198 : return TensorBase<Derived>::base_transpose(0, 1);
275 : }
276 :
277 : template <class Derived1, class Derived2, typename, typename>
278 : Vec
279 177 : operator*(const Derived1 & A, const Derived2 & b)
280 : {
281 177 : neml_assert_batch_broadcastable_dbg(A, b);
282 708 : return Vec(at::einsum("...ik,...k", {A, b}));
283 177 : }
284 :
285 : template <class Derived1, class Derived2, typename, typename>
286 : R2
287 1479 : operator*(const Derived1 & A, const Derived2 & B)
288 : {
289 1479 : neml_assert_broadcastable_dbg(A, B);
290 5916 : return R2(at::einsum("...ik,...kj", {A, B}));
291 1479 : }
292 :
293 : // template instantiation
294 :
295 : // derived classes
296 : template class R2Base<R2>;
297 :
298 : // products
299 : template Vec operator*(const R2 & A, const Vec & b);
300 : template R2 operator*(const R2 & A, const R2 & B);
301 : } // namespace neml2
|