NEML2 2.0.0
math.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#pragma once

#include "neml2/tensors/Tensor.h"

namespace neml2
{
class SR2;
class WR2;
class SWR4;
class SSR4;
class WSR4;
namespace math
{
constexpr Real eps = std::numeric_limits<at::scalar_value_type<Real>::type>::epsilon();

constexpr Real sqrt2 = 1.4142135623730951;
constexpr Real invsqrt2 = 0.7071067811865475;

constexpr std::array<std::array<Size, 3>, 3> mandel_reverse_index{
    std::array<Size, 3>{0, 5, 4}, std::array<Size, 3>{5, 1, 3}, std::array<Size, 3>{4, 3, 2}};
constexpr std::array<std::array<Size, 2>, 6> mandel_index{std::array<Size, 2>{0, 0},
                                                          std::array<Size, 2>{1, 1},
                                                          std::array<Size, 2>{2, 2},
                                                          std::array<Size, 2>{1, 2},
                                                          std::array<Size, 2>{0, 2},
                                                          std::array<Size, 2>{0, 1}};
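// Note (editorial, derived from the tables above): a symmetric 3x3 tensor is stored in Mandel
// notation as a 6-vector ordered (00, 11, 22, 12, 02, 01). `mandel_index` maps a Mandel slot
// to its (i, j) pair, and `mandel_reverse_index` maps (i, j) back to the slot, e.g.
//
//   mandel_index[3]            is {1, 2}  // slot 3 holds the (1, 2) component
//   mandel_reverse_index[1][2] is 3       // and the (1, 2) component lives in slot 3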

constexpr std::array<std::array<Size, 3>, 3> skew_reverse_index{
    std::array<Size, 3>{0, 2, 1}, std::array<Size, 3>{2, 0, 0}, std::array<Size, 3>{1, 0, 0}};
constexpr std::array<std::array<Real, 3>, 3> skew_factor{std::array<Real, 3>{0.0, -1.0, 1.0},
                                                         std::array<Real, 3>{1.0, 0.0, -1.0},
                                                         std::array<Real, 3>{-1.0, 1.0, 0.0}};

inline constexpr Real
mandel_factor(Size i)
{
  return i < 3 ? 1.0 : sqrt2;
}

/// A helper class to hold static data of type torch::Tensor
struct ConstantTensors
{
  ConstantTensors();

  // Get the global constants
  static ConstantTensors & get();

  static const torch::Tensor & full_to_mandel_map();
  static const torch::Tensor & mandel_to_full_map();
  static const torch::Tensor & full_to_mandel_factor();
  static const torch::Tensor & mandel_to_full_factor();
  static const torch::Tensor & full_to_skew_map();
  static const torch::Tensor & skew_to_full_map();
  static const torch::Tensor & full_to_skew_factor();
  static const torch::Tensor & skew_to_full_factor();

private:
  torch::Tensor _full_to_mandel_map;
  torch::Tensor _mandel_to_full_map;
  torch::Tensor _full_to_mandel_factor;
  torch::Tensor _mandel_to_full_factor;
  torch::Tensor _full_to_skew_map;
  torch::Tensor _skew_to_full_map;
  torch::Tensor _full_to_skew_factor;
  torch::Tensor _skew_to_full_factor;
};
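// Usage sketch (illustrative, not part of the original header): the cached tensors are
// reached through the static accessors, e.g.
//
//   const auto & map    = ConstantTensors::full_to_mandel_map();
//   const auto & factor = ConstantTensors::full_to_mandel_factor();
//
// presumably the (rmap, rfactors) pair consumed by full_to_reduced() declared below.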

/// Generic function to reduce two axes to one with some map
Tensor full_to_reduced(const Tensor & full,
                       const torch::Tensor & rmap,
                       const torch::Tensor & rfactors,
                       Size dim = 0);

/// Convert a Tensor from reduced notation to full notation
Tensor reduced_to_full(const Tensor & reduced,
                       const torch::Tensor & rmap,
                       const torch::Tensor & rfactors,
                       Size dim = 0);

/// Convert a Tensor from full notation to Mandel notation
Tensor full_to_mandel(const Tensor & full, Size dim = 0);

/// Convert a Tensor from Mandel notation to full notation
Tensor mandel_to_full(const Tensor & mandel, Size dim = 0);

/// Convert a Tensor from full notation to skew vector notation
Tensor full_to_skew(const Tensor & full, Size dim = 0);

/// Convert a Tensor from skew vector notation to full notation
Tensor skew_to_full(const Tensor & skew, Size dim = 0);

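// Usage sketch (illustrative, not from the original header): for a symmetric base (3, 3)
// tensor the two Mandel conversions are inverses of each other, with the off-diagonal
// components of the 6-vector scaled by sqrt(2) (see mandel_factor above), e.g.
//
//   auto full   = Tensor(torch::rand({5, 3, 3}), 1);  // batch (5,), base (3, 3); symmetrize in practice
//   auto mandel = full_to_mandel(full);               // batch (5,), base (6,)
//   auto back   = mandel_to_full(mandel);             // batch (5,), base (3, 3)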
/// Use automatic differentiation (AD) to calculate the derivative of y w.r.t. x
Tensor jacrev(const Tensor & y,
              const Tensor & x,
              bool retain_graph = false,
              bool create_graph = false,
              bool allow_unused = false);

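// Usage sketch (illustrative, not from the original header): jacrev relies on torch autograd,
// so x must carry gradient information and y must have been computed from it, e.g.
//
//   auto x = Tensor(torch::rand({5, 3}, torch::TensorOptions().requires_grad(true)), 1);
//   auto y = pow(x, 2.0);
//   auto dy_dx = jacrev(y, x);  // derivative of y w.r.t. x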
Tensor base_diag_embed(const Tensor & a, Size offset = 0, Size d1 = -2, Size d2 = -1);

/// Product w_ik e_kj - e_ik w_kj with e SR2 and w WR2
SR2 skew_and_sym_to_sym(const SR2 & e, const WR2 & w);

/// Derivative of w_ik e_kj - e_ik w_kj wrt. e
SSR4 d_skew_and_sym_to_sym_d_sym(const WR2 & w);

/// Derivative of w_ik e_kj - e_ik w_kj wrt. w
SWR4 d_skew_and_sym_to_sym_d_skew(const SR2 & e);

/// Shortcut product a_ik b_kj - b_ik a_kj with both SR2
WR2 multiply_and_make_skew(const SR2 & a, const SR2 & b);

/// Derivative of a_ik b_kj - b_ik a_kj wrt a
WSR4 d_multiply_and_make_skew_d_first(const SR2 & b);

/// Derivative of a_ik b_kj - b_ik a_kj wrt b
WSR4 d_multiply_and_make_skew_d_second(const SR2 & a);

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
batch_cat(const std::vector<T> & tensors, Size d = 0)
{
  neml_assert_dbg(!tensors.empty(), "batch_cat must be given at least one tensor");
  std::vector<torch::Tensor> torch_tensors(tensors.begin(), tensors.end());
  auto d2 = d >= 0 ? d : d - tensors.begin()->base_dim();
  return T(torch::cat(torch_tensors, d2), tensors.begin()->batch_dim());
}

neml2::Tensor base_cat(const std::vector<Tensor> & tensors, Size d = 0);

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
batch_stack(const std::vector<T> & tensors, Size d = 0)
{
  neml_assert_dbg(!tensors.empty(), "batch_stack must be given at least one tensor");
  std::vector<torch::Tensor> torch_tensors(tensors.begin(), tensors.end());
  auto d2 = d >= 0 ? d : d - tensors.begin()->base_dim();
  return T(torch::stack(torch_tensors, d2), tensors.begin()->batch_dim() + 1);
}

neml2::Tensor base_stack(const std::vector<Tensor> & tensors, Size d = 0);

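// Usage sketch (illustrative, not from the original header): the batch_* variants operate on
// batch dimensions while the base_* variants operate on base dimensions (assuming base_cat
// maps a negative d to the base dimensions, mirroring batch_cat). With three tensors of batch
// shape (5,) and base shape (3, 3):
//
//   auto a = Tensor(torch::rand({5, 3, 3}), 1);  // likewise b and c
//   std::vector<Tensor> ts{a, b, c};
//   auto cat   = batch_cat(ts);                  // batch (15,),  base (3, 3)
//   auto stack = batch_stack(ts);                // batch (3, 5), base (3, 3)
//   auto wide  = base_cat(ts, -1);               // batch (5,),   base (3, 9)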
template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
batch_sum(const T & a, Size d = 0)
{
  neml_assert_dbg(a.batch_dim() > 0, "Must have a batch dimension to sum along");
  auto d2 = d >= 0 ? d : d - a.base_dim();
  return T(torch::sum(a, d2), a.batch_sizes().slice(0, -1));
}

neml2::Tensor base_sum(const Tensor & a, Size d);

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
batch_mean(const T & a, Size d = 0)
{
  neml_assert_dbg(a.batch_dim() > 0, "Must have a batch dimension to take average");
  auto d2 = d >= 0 ? d : d - a.base_dim();
  return T(torch::mean(a, d2), a.batch_sizes().slice(0, -1));
}

neml2::Tensor base_mean(const Tensor & a, Size d);

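// Usage sketch (illustrative, not from the original header): batch_sum and batch_mean reduce
// a batch dimension (they assert, in debug mode, that one exists), while base_sum / base_mean
// reduce a base dimension, e.g.
//
//   auto a = Tensor(torch::rand({5, 3, 3}), 1);  // batch (5,), base (3, 3)
//   auto s = batch_sum(a);                       // batch (),   base (3, 3)
//   auto m = batch_mean(a);                      // batch (),   base (3, 3)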
template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
pow(const T & a, const Real & n)
{
  return T(torch::pow(a, n), a.batch_sizes());
}

Tensor pow(const Real & a, const Tensor & n);

Tensor pow(const Tensor & a, const Tensor & n);

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
sign(const T & a)
{
  return T(torch::sign(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
cosh(const T & a)
{
  return T(torch::cosh(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
sinh(const T & a)
{
  return T(torch::sinh(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
tanh(const T & a)
{
  return T(torch::tanh(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
arccos(const T & a)
{
  return T(torch::arccos(a), a.batch_dim());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
arcsin(const T & a)
{
  return T(torch::arcsin(a), a.batch_dim());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
where(const torch::Tensor & condition, const T & a, const T & b)
{
  neml_assert_broadcastable_dbg(a, b);
  return T(torch::where(condition, a, b), broadcast_batch_dim(a, b));
}

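// Usage sketch (illustrative, not from the original header): where() selects elementwise
// between two tensors according to a boolean mask, broadcasting their batch shapes, e.g.
//
//   auto a = Tensor(torch::rand({5, 3, 3}), 1);    // batch (5,), base (3, 3)
//   auto b = Tensor(torch::zeros({5, 3, 3}), 1);
//   auto c = where(torch::Tensor(a) > 0.5, a, b);  // entries of a where the mask holds, else b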
template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
heaviside(const T & a)
{
  return (sign(a) + 1.0) / 2.0;
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
macaulay(const T & a)
{
  return T(torch::Tensor(a) * torch::Tensor(heaviside(a)), a.batch_dim());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
dmacaulay(const T & a)
{
  return heaviside(a);
}

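// Note (editorial): macaulay(a) is the Macaulay bracket <a> = max(a, 0) = a * H(a), with H
// the Heaviside function defined above (which evaluates to 1/2 exactly at zero), and
// dmacaulay(a) = H(a) is its derivative, e.g.
//
//   macaulay(x)   // equals x where x > 0 and 0 where x < 0
//   dmacaulay(x)  // equals 1 where x > 0 and 0 where x < 0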
template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
sqrt(const T & a)
{
  return T(torch::sqrt(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
exp(const T & a)
{
  return T(torch::exp(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
abs(const T & a)
{
  return T(torch::abs(a), a.batch_sizes());
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
diff(const T & a, Size n = 1, Size dim = -1)
{
  return T(torch::diff(a, n, dim), a.batch_sizes());
}

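// Usage sketch (illustrative, not from the original header): diff wraps torch::diff and keeps
// the batch shape, so with the defaults it takes a first-order finite difference along the
// last base dimension, e.g.
//
//   auto t  = Tensor(torch::rand({5, 10}), 1);  // batch (5,), base (10,)
//   auto dt = diff(t);                          // batch (5,), base (9,)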
template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
batch_diag_embed(const T & a, Size offset = 0, Size d1 = -2, Size d2 = -1)
{
  return T(torch::diag_embed(
               a, offset, d1 < 0 ? d1 - a.base_dim() : d1, d2 < 0 ? d2 - a.base_dim() : d2),
           a.batch_dim() + 1);
}

template <class T, typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T
log(const T & a)
{
  return T(torch::log(a), a.batch_sizes());
}

namespace linalg
{
/// Vector norm of a vector. Falls back to math::abs if v is a Scalar
Tensor vector_norm(const Tensor & v);

/// Inverse of a square matrix
Tensor inv(const Tensor & m);

/// Solve the linear system A X = B
Tensor solve(const Tensor & A, const Tensor & B);

std::tuple<Tensor, Tensor> lu_factor(const Tensor & A, bool pivot = true);

Tensor lu_solve(const Tensor & LU,
                const Tensor & pivots,
                const Tensor & B,
                bool left = true,
                bool adjoint = false);
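// Usage sketch (illustrative, not from the original header): for repeated solves against the
// same (batched) square matrix A, factor once and reuse the decomposition; A, b1, b2 below
// are hypothetical Tensors of compatible shapes:
//
//   auto [LU, pivots] = lu_factor(A);
//   auto x1 = lu_solve(LU, pivots, b1);
//   auto x2 = lu_solve(LU, pivots, b2);
//
// which should agree with the direct calls solve(A, b1) and solve(A, b2).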
} // namespace linalg
} // namespace math
} // namespace neml2