Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/tensors/SR2.h"
26 : #include "neml2/tensors/Scalar.h"
27 : #include "neml2/tensors/R2.h"
28 : #include "neml2/tensors/R3.h"
29 : #include "neml2/tensors/SFR3.h"
30 : #include "neml2/tensors/SSR4.h"
31 : #include "neml2/tensors/SFFR4.h"
32 : #include "neml2/tensors/Rot.h"
33 : #include "neml2/tensors/SWR4.h"
34 : #include "neml2/tensors/WR2.h"
35 : #include "neml2/tensors/R4.h"
36 : #include "neml2/tensors/assertions.h"
37 :
38 : #include "neml2/tensors/mandel_notation.h"
39 : #include "neml2/tensors/functions/sqrt.h"
40 : #include "neml2/tensors/functions/sum.h"
41 : #include "neml2/tensors/functions/stack.h"
42 : #include "neml2/tensors/functions/linalg/vecdot.h"
43 :
44 : namespace neml2
45 : {
46 :
47 153 : SR2::SR2(const R2 & T)
48 153 : : SR2(full_to_mandel((T + T.transpose()) / 2.0))
49 : {
50 153 : }
51 :
52 : SR2
53 3 : SR2::fill(const CScalar & a, const TensorOptions & options)
54 : {
55 3 : return SR2::fill(Scalar(a, options));
56 : }
57 :
58 : SR2
59 143 : SR2::fill(const Scalar & a)
60 : {
61 143 : auto zero = Scalar::zeros_like(a);
62 1287 : return SR2(base_stack({a, a, a, zero, zero, zero}, -1));
63 286 : }
64 :
65 : SR2
66 18 : SR2::fill(const CScalar & a11,
67 : const CScalar & a22,
68 : const CScalar & a33,
69 : const TensorOptions & options)
70 : {
71 18 : return SR2::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
72 : }
73 :
74 : SR2
75 49 : SR2::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
76 : {
77 49 : auto zero = Scalar::zeros_like(a11);
78 441 : return SR2(base_stack({a11, a22, a33, zero, zero, zero}, -1));
79 98 : }
80 :
81 : SR2
82 24 : SR2::fill(const CScalar & a11,
83 : const CScalar & a22,
84 : const CScalar & a33,
85 : const CScalar & a23,
86 : const CScalar & a13,
87 : const CScalar & a12,
88 : const TensorOptions & options)
89 : {
90 48 : return SR2::fill(Scalar(a11, options),
91 48 : Scalar(a22, options),
92 48 : Scalar(a33, options),
93 48 : Scalar(a23, options),
94 48 : Scalar(a13, options),
95 72 : Scalar(a12, options));
96 : }
97 :
98 : SR2
99 124 : SR2::fill(const Scalar & a11,
100 : const Scalar & a22,
101 : const Scalar & a33,
102 : const Scalar & a23,
103 : const Scalar & a13,
104 : const Scalar & a12)
105 : {
106 1116 : return SR2(base_stack(
107 620 : {a11, a22, a33, mandel_factor(3) * a23, mandel_factor(4) * a13, mandel_factor(5) * a12}, -1));
108 496 : }
109 :
110 : SR2
111 59 : SR2::identity(const TensorOptions & options)
112 : {
113 413 : return SR2::create({1, 1, 1, 0, 0, 0}, options);
114 59 : }
115 :
116 : SSR4
117 59 : SR2::identity_map(const TensorOptions & options)
118 : {
119 59 : return SSR4::identity_sym(options);
120 : }
121 :
122 : SR2
123 8 : SR2::rotate(const Rot & r) const
124 : {
125 8 : return R2(*this).rotate(r);
126 : }
127 :
128 : SR2
129 22 : SR2::rotate(const R2 & R) const
130 : {
131 22 : return R2(*this).rotate(R);
132 : }
133 :
134 : SFR3
135 4 : SR2::drotate(const Rot & r) const
136 : {
137 4 : auto dR = R2(*this).drotate(r);
138 8 : return full_to_mandel(dR);
139 4 : }
140 :
141 : SFFR4
142 8 : SR2::drotate(const R2 & R) const
143 : {
144 8 : auto dR = R2(*this).drotate(R);
145 16 : return full_to_mandel(dR);
146 8 : }
147 :
148 : Scalar
149 9 : SR2::operator()(Size i, Size j) const
150 : {
151 9 : Size a = mandel_reverse_index[i][j];
152 27 : return base_index({a}) / mandel_factor(a);
153 9 : }
154 :
155 : Scalar
156 190 : SR2::tr() const
157 : {
158 950 : return Scalar(base_sum(base_index({indexing::Slice(0, 3)}), -1), batch_sizes());
159 380 : }
160 :
161 : SR2
162 125 : SR2::vol() const
163 : {
164 250 : return SR2::fill(tr()) / 3;
165 : }
166 :
167 : SR2
168 97 : SR2::dev() const
169 : {
170 97 : return *this - vol();
171 : }
172 :
173 : Scalar
174 85 : SR2::inner(const SR2 & other) const
175 : {
176 85 : return linalg::vecdot(*this, other);
177 : }
178 :
179 : Scalar
180 61 : SR2::norm_sq() const
181 : {
182 61 : return inner(*this);
183 : }
184 :
185 : Scalar
186 59 : SR2::norm(const CScalar & eps) const
187 : {
188 59 : return neml2::sqrt(norm_sq() + eps);
189 : }
190 :
191 : SSR4
192 38 : SR2::outer(const SR2 & other) const
193 : {
194 38 : neml_assert_broadcastable_dbg(*this, other);
195 190 : return SSR4(at::einsum("...i,...j", {*this, other}), utils::broadcast_batch_dim(*this, other));
196 38 : }
197 :
198 : Scalar
199 2 : SR2::det() const
200 : {
201 2 : const auto comps = at::split(*this, 1, -1);
202 2 : const auto & a = comps[0];
203 2 : const auto & e = comps[1];
204 2 : const auto & i = comps[2];
205 2 : const auto f = comps[3] / mandel_factor(3);
206 2 : const auto c = comps[4] / mandel_factor(4);
207 2 : const auto b = comps[5] / mandel_factor(5);
208 2 : const auto det = a * (e * i - f * f) - b * (b * i - c * f) + c * (b * f - e * c);
209 6 : return Scalar(det.reshape(batch_sizes().concrete()), batch_sizes());
210 2 : }
211 :
212 : SR2
213 2 : SR2::inverse() const
214 : {
215 2 : const auto comps = at::split(*this, 1, -1);
216 2 : const auto & a = comps[0];
217 2 : const auto & e = comps[1];
218 2 : const auto & i = comps[2];
219 2 : const auto f = comps[3] / mandel_factor(3);
220 2 : const auto c = comps[4] / mandel_factor(4);
221 2 : const auto b = comps[5] / mandel_factor(5);
222 2 : const auto det = a * (e * i - f * f) - b * (b * i - c * f) + c * (b * f - e * c);
223 2 : const auto cof00 = e * i - f * f;
224 2 : const auto cof01 = -(b * i - c * f);
225 2 : const auto cof02 = b * f - c * e;
226 2 : const auto cof11 = a * i - c * c;
227 2 : const auto cof12 = -(a * f - c * b);
228 2 : const auto cof22 = a * e - b * b;
229 14 : const auto cof = at::cat({cof00,
230 : cof11,
231 : cof22,
232 : mandel_factor(3) * cof12,
233 : mandel_factor(4) * cof02,
234 : mandel_factor(5) * cof01},
235 16 : -1);
236 2 : const auto inv = cof / det;
237 4 : return SR2(inv, batch_sizes());
238 6 : }
239 :
240 : SR2
241 2 : SR2::transpose() const
242 : {
243 2 : return SR2(*this, batch_sizes());
244 : }
245 :
246 : } // namespace neml2
|