Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/tensors/R2Base.h"
26 : #include "neml2/tensors/R2.h"
27 : #include "neml2/tensors/Scalar.h"
28 : #include "neml2/tensors/Vec.h"
29 : #include "neml2/tensors/SR2.h"
30 : #include "neml2/tensors/R3.h"
31 : #include "neml2/tensors/R4.h"
32 : #include "neml2/tensors/Rot.h"
33 : #include "neml2/tensors/WR2.h"
34 : #include "neml2/tensors/assertions.h"
35 : #include "neml2/tensors/functions/stack.h"
36 : #include "neml2/tensors/functions/sum.h"
37 :
38 : namespace neml2
39 : {
40 : template <class Derived>
41 : Derived
42 26 : R2Base<Derived>::fill(const CScalar & a, const TensorOptions & options)
43 : {
44 26 : return R2Base<Derived>::fill(Scalar(a, options));
45 : }
46 :
47 : template <class Derived>
48 : Derived
49 51 : R2Base<Derived>::fill(const Scalar & a)
50 : {
51 51 : auto zero = Scalar::zeros_like(a);
52 306 : return Derived(base_stack({base_stack({a, zero, zero}, -1),
53 204 : base_stack({zero, a, zero}, -1),
54 204 : base_stack({zero, zero, a}, -1)},
55 306 : -2));
56 765 : }
57 :
58 : template <class Derived>
59 : Derived
60 2 : R2Base<Derived>::fill(const CScalar & a11,
61 : const CScalar & a22,
62 : const CScalar & a33,
63 : const TensorOptions & options)
64 : {
65 2 : return R2Base<Derived>::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
66 : }
67 :
68 : template <class Derived>
69 : Derived
70 5 : R2Base<Derived>::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
71 : {
72 5 : auto zero = Scalar::zeros_like(a11);
73 30 : return Derived(base_stack({base_stack({a11, zero, zero}, -1),
74 20 : base_stack({zero, a22, zero}, -1),
75 20 : base_stack({zero, zero, a33}, -1)},
76 30 : -2));
77 75 : }
78 :
79 : template <class Derived>
80 : Derived
81 2 : R2Base<Derived>::fill(const CScalar & a11,
82 : const CScalar & a22,
83 : const CScalar & a33,
84 : const CScalar & a23,
85 : const CScalar & a13,
86 : const CScalar & a12,
87 : const TensorOptions & options)
88 : {
89 4 : return R2Base<Derived>::fill(Scalar(a11, options),
90 4 : Scalar(a22, options),
91 4 : Scalar(a33, options),
92 4 : Scalar(a23, options),
93 4 : Scalar(a13, options),
94 6 : Scalar(a12, options));
95 : }
96 :
97 : template <class Derived>
98 : Derived
99 5 : R2Base<Derived>::fill(const Scalar & a11,
100 : const Scalar & a22,
101 : const Scalar & a33,
102 : const Scalar & a23,
103 : const Scalar & a13,
104 : const Scalar & a12)
105 : {
106 30 : return Derived(base_stack({base_stack({a11, a12, a13}, -1),
107 20 : base_stack({a12, a22, a23}, -1),
108 20 : base_stack({a13, a23, a33}, -1)},
109 30 : -2));
110 75 : }
111 :
112 : template <class Derived>
113 : Derived
114 35 : R2Base<Derived>::fill(const CScalar & a11,
115 : const CScalar & a12,
116 : const CScalar & a13,
117 : const CScalar & a21,
118 : const CScalar & a22,
119 : const CScalar & a23,
120 : const CScalar & a31,
121 : const CScalar & a32,
122 : const CScalar & a33,
123 : const TensorOptions & options)
124 : {
125 70 : return R2Base<Derived>::fill(Scalar(a11, options),
126 70 : Scalar(a12, options),
127 70 : Scalar(a13, options),
128 70 : Scalar(a21, options),
129 70 : Scalar(a22, options),
130 70 : Scalar(a23, options),
131 70 : Scalar(a31, options),
132 70 : Scalar(a32, options),
133 105 : Scalar(a33, options));
134 : }
135 :
136 : template <class Derived>
137 : Derived
138 143 : R2Base<Derived>::fill(const Scalar & a11,
139 : const Scalar & a12,
140 : const Scalar & a13,
141 : const Scalar & a21,
142 : const Scalar & a22,
143 : const Scalar & a23,
144 : const Scalar & a31,
145 : const Scalar & a32,
146 : const Scalar & a33)
147 : {
148 858 : return Derived(base_stack({base_stack({a11, a12, a13}, -1),
149 572 : base_stack({a21, a22, a23}, -1),
150 572 : base_stack({a31, a32, a33}, -1)},
151 858 : -2));
152 2145 : }
153 :
154 : template <class Derived>
155 : Derived
156 305 : R2Base<Derived>::skew(const Vec & v)
157 : {
158 305 : const auto z = Scalar::zeros_like(v(0));
159 1830 : return Derived(base_stack({base_stack({z, -v(2), v(1)}, -1),
160 1220 : base_stack({v(2), z, -v(0)}, -1),
161 1220 : base_stack({-v(1), v(0), z}, -1)},
162 1830 : -2));
163 4575 : }
164 :
165 : template <class Derived>
166 : Derived
167 381 : R2Base<Derived>::identity(const TensorOptions & options)
168 : {
169 381 : return Derived(at::eye(3, options), 0);
170 : }
171 :
172 : template <class Derived>
173 : Derived
174 18 : R2Base<Derived>::rotate(const Rot & r) const
175 : {
176 18 : return rotate(r.euler_rodrigues());
177 : }
178 :
179 : template <class Derived>
180 : Derived
181 43 : R2Base<Derived>::rotate(const R2 & R) const
182 : {
183 43 : return R * R2(*this) * R.transpose();
184 : }
185 :
186 : template <class Derived>
187 : R3
188 11 : R2Base<Derived>::drotate(const Rot & r) const
189 : {
190 11 : auto R = r.euler_rodrigues();
191 11 : auto F = r.deuler_rodrigues();
192 :
193 44 : return R3(at::einsum("...itl,...tm,...jm", {F, *this, R}) +
194 110 : at::einsum("...ik,...kt,...jtl", {R, *this, F}));
195 33 : }
196 :
197 : template <class Derived>
198 : R4
199 9 : R2Base<Derived>::drotate(const R2 & R) const
200 : {
201 9 : auto I = R2::identity(R.options());
202 27 : return R4(at::einsum("...ik,...jl", {I, R * this->transpose()}) +
203 90 : at::einsum("...jk,...il", {I, R * R2(*this)}));
204 45 : }
205 :
206 : template <class Derived>
207 : Scalar
208 81 : R2Base<Derived>::operator()(Size i, Size j) const
209 : {
210 243 : return this->base_index({i, j});
211 81 : }
212 :
213 : template <class Derived>
214 : Vec
215 0 : R2Base<Derived>::row(Size i) const
216 : {
217 0 : return Vec(this->base_index({i, indexing::Slice()}), this->batch_sizes());
218 0 : }
219 :
220 : template <class Derived>
221 : Vec
222 60 : R2Base<Derived>::col(Size i) const
223 : {
224 300 : return Vec(this->base_index({indexing::Slice(), i}), this->batch_sizes());
225 120 : }
226 :
227 : template <class Derived>
228 : Scalar
229 4 : R2Base<Derived>::det() const
230 : {
231 4 : const auto comps = at::split(this->base_flatten(), 1, -1);
232 4 : const auto & a = comps[0];
233 4 : const auto & b = comps[1];
234 4 : const auto & c = comps[2];
235 4 : const auto & d = comps[3];
236 4 : const auto & e = comps[4];
237 4 : const auto & f = comps[5];
238 4 : const auto & g = comps[6];
239 4 : const auto & h = comps[7];
240 4 : const auto & i = comps[8];
241 4 : const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
242 8 : return Scalar(det.reshape(this->batch_sizes().concrete()), this->batch_sizes());
243 4 : }
244 :
245 : template <class Derived>
246 : Scalar
247 4 : R2Base<Derived>::inner(const R2 & other) const
248 : {
249 4 : return base_sum(this->base_flatten() * other.base_flatten());
250 : }
251 :
252 : template <class Derived>
253 : R4
254 15 : R2Base<Derived>::outer(const R2 & other) const
255 : {
256 15 : return this->base_unsqueeze(-1).base_unsqueeze(-1) * other.base_unsqueeze(0).base_unsqueeze(0);
257 : }
258 :
259 : template <class Derived>
260 : Derived
261 97 : R2Base<Derived>::inverse() const
262 : {
263 97 : const auto comps = at::split(this->base_flatten(), 1, -1);
264 97 : const auto & a = comps[0];
265 97 : const auto & b = comps[1];
266 97 : const auto & c = comps[2];
267 97 : const auto & d = comps[3];
268 97 : const auto & e = comps[4];
269 97 : const auto & f = comps[5];
270 97 : const auto & g = comps[6];
271 97 : const auto & h = comps[7];
272 97 : const auto & i = comps[8];
273 97 : const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
274 97 : const auto cof00 = e * i - h * f;
275 97 : const auto cof01 = -(d * i - g * f);
276 97 : const auto cof02 = d * h - g * e;
277 97 : const auto cof10 = -(b * i - h * c);
278 97 : const auto cof11 = a * i - g * c;
279 97 : const auto cof12 = -(a * h - g * b);
280 97 : const auto cof20 = b * f - e * c;
281 97 : const auto cof21 = -(a * f - d * c);
282 97 : const auto cof22 = a * e - d * b;
283 388 : const auto coft0 = at::cat({cof00, cof10, cof20}, -1);
284 388 : const auto coft1 = at::cat({cof01, cof11, cof21}, -1);
285 388 : const auto coft2 = at::cat({cof02, cof12, cof22}, -1);
286 388 : const auto coft = at::stack({coft0, coft1, coft2}, -2);
287 97 : const auto inv = coft / det.unsqueeze(-1);
288 194 : return Derived(inv, this->batch_sizes());
289 485 : }
290 :
291 : template <class Derived>
292 : Derived
293 262 : R2Base<Derived>::transpose() const
294 : {
295 262 : return TensorBase<Derived>::base_transpose(0, 1);
296 : }
297 :
298 : template <class Derived1, class Derived2, typename, typename>
299 : Vec
300 185 : operator*(const Derived1 & A, const Derived2 & b)
301 : {
302 185 : neml_assert_batch_broadcastable_dbg(A, b);
303 740 : return Vec(at::einsum("...ik,...k", {A, b}));
304 185 : }
305 :
306 : template <class Derived1, class Derived2, typename, typename>
307 : R2
308 1479 : operator*(const Derived1 & A, const Derived2 & B)
309 : {
310 1479 : neml_assert_broadcastable_dbg(A, B);
311 5916 : return R2(at::einsum("...ik,...kj", {A, B}));
312 1479 : }
313 :
314 : // template instantiation
315 :
316 : // derived classes
317 : template class R2Base<R2>;
318 :
319 : // products
320 : template Vec operator*(const R2 & A, const Vec & b);
321 : template R2 operator*(const R2 & A, const R2 & B);
322 : } // namespace neml2
|