LCOV - code coverage report
Current view: top level - tensors - Rot.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 82.0 % 150 123
Test Date: 2025-06-29 01:25:44 Functions: 81.0 % 21 17

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/tensors/Rot.h"
      26              : #include "neml2/tensors/Scalar.h"
      27              : #include "neml2/tensors/Vec.h"
      28              : #include "neml2/tensors/R2.h"
      29              : #include "neml2/tensors/SR2.h"
      30              : #include "neml2/tensors/R3.h"
      31              : #include "neml2/tensors/R4.h"
      32              : #include "neml2/tensors/SSR4.h"
      33              : #include "neml2/tensors/WR2.h"
      34              : #include "neml2/tensors/Quaternion.h"
      35              : #include "neml2/misc/assertions.h"
      36              : #include "neml2/tensors/functions/sqrt.h"
      37              : #include "neml2/tensors/functions/pow.h"
      38              : #include "neml2/tensors/functions/sin.h"
      39              : #include "neml2/tensors/functions/cos.h"
      40              : #include "neml2/tensors/functions/tan.h"
      41              : #include "neml2/tensors/functions/asin.h"
      42              : #include "neml2/tensors/functions/acos.h"
      43              : #include "neml2/tensors/functions/minimum.h"
      44              : #include "neml2/tensors/functions/deg2rad.h"
      45              : #include "neml2/tensors/functions/fmod.h"
      46              : #include "neml2/tensors/functions/stack.h"
      47              : #include "neml2/tensors/functions/diag_embed.h"
      48              : #include "neml2/tensors/functions/clip.h"
      49              : 
      50              : namespace neml2
      51              : {
      52           38 : Rot::Rot(const Vec & v)
      53           38 :   : Rot(Tensor(v))
      54              : {
      55           38 : }
      56              : 
      57              : Rot
      58            3 : Rot::identity(const TensorOptions & options)
      59              : {
      60            3 :   return Rot::zeros(options);
      61              : }
      62              : 
      63              : Rot
      64            6 : Rot::fill_euler_angles(const Vec & v,
      65              :                        const std::string & angle_convention,
      66              :                        const std::string & angle_type)
      67              : {
      68            6 :   auto m = v;
      69              : 
      70            6 :   if (angle_type == "degrees")
      71            3 :     m = neml2::deg2rad(v);
      72              :   else
      73            3 :     neml_assert(angle_type == "radians", "Rot angle_type must be either 'degrees' or 'radians'");
      74              : 
      75            6 :   if (angle_convention == "bunge")
      76              :   {
      77           10 :     m.base_index_put_({0}, neml2::fmod(m.base_index({0}) - M_PI / 2.0, 2.0 * M_PI));
      78           10 :     m.base_index_put_({1}, neml2::fmod(m.base_index({1}), M_PI));
      79           10 :     m.base_index_put_({2}, neml2::fmod(M_PI / 2.0 - m.base_index({2}), 2.0 * M_PI));
      80              :   }
      81            4 :   else if (angle_convention == "roe")
      82              :   {
      83            8 :     m.base_index_put_({2}, M_PI - m.base_index({2}));
      84              :   }
      85              :   else
      86            2 :     neml_assert(angle_convention == "kocks", "Unknown Rot angle_convention " + angle_convention);
      87              : 
      88              :   // Make a rotation matrix
      89            6 :   auto M = R2(neml2::base_diag_embed(m));
      90           12 :   auto a = m.base_index({0});
      91           12 :   auto b = m.base_index({1});
      92           12 :   auto c = m.base_index({2});
      93           18 :   M.base_index_put_({0, 0},
      94           12 :                     -neml2::sin(c) * neml2::sin(a) - neml2::cos(c) * neml2::cos(a) * neml2::cos(b));
      95           18 :   M.base_index_put_({0, 1},
      96           12 :                     neml2::sin(c) * neml2::cos(a) - neml2::cos(c) * neml2::sin(a) * neml2::cos(b));
      97           24 :   M.base_index_put_({0, 2}, neml2::cos(c) * neml2::sin(b));
      98           18 :   M.base_index_put_({1, 0},
      99           12 :                     neml2::cos(c) * neml2::sin(a) - neml2::sin(c) * neml2::cos(a) * neml2::cos(b));
     100           18 :   M.base_index_put_({1, 1},
     101           12 :                     -neml2::cos(c) * neml2::cos(a) - neml2::sin(c) * neml2::sin(a) * neml2::cos(b));
     102           24 :   M.base_index_put_({1, 2}, neml2::sin(c) * neml2::sin(b));
     103           24 :   M.base_index_put_({2, 0}, neml2::cos(a) * neml2::sin(b));
     104           24 :   M.base_index_put_({2, 1}, neml2::sin(a) * neml2::sin(b));
     105           24 :   M.base_index_put_({2, 2}, neml2::cos(b));
     106              : 
     107              :   // Convert from matrix to vector
     108           12 :   return fill_matrix(M);
     109           94 : }
     110              : 
     111              : Rot
     112            6 : Rot::fill_matrix(const R2 & M)
     113              : {
     114              :   // Get the angle
     115            6 :   auto trace = M(0, 0) + M(1, 1) + M(2, 2);
     116            6 :   auto theta = neml2::acos((trace - 1.0) / 2.0);
     117              : 
     118              :   // Get the standard Rod. parameters
     119            6 :   auto scale = neml2::tan(theta / 2.0) / (2.0 * neml2::sin(theta));
     120           24 :   scale.index_put_({theta == 0}, 0.0);
     121            6 :   auto rx = (M(2, 1) - M(1, 2)) * scale;
     122            6 :   auto ry = (M(0, 2) - M(2, 0)) * scale;
     123            6 :   auto rz = (M(1, 0) - M(0, 1)) * scale;
     124              : 
     125           12 :   return fill_rodrigues(rx, ry, rz);
     126           18 : }
     127              : 
     128              : Rot
     129            6 : Rot::fill_rodrigues(const Scalar & rx, const Scalar & ry, const Scalar & rz)
     130              : {
     131              :   // Get the modified Rod. parameters
     132            6 :   auto ns = rx * rx + ry * ry + rz * rz;
     133            6 :   auto f = neml2::sqrt(ns + 1) + 1;
     134              : 
     135              :   // Stack and return
     136           36 :   return Rot(base_stack({rx / f, ry / f, rz / f}));
     137           12 : }
     138              : 
     139              : Rot
     140            0 : Rot::fill_random(unsigned int n)
     141              : {
     142            0 :   auto u0 = Scalar(at::rand({n}, default_tensor_options()));
     143            0 :   auto u1 = Scalar(at::rand({n}, default_tensor_options()));
     144            0 :   auto u2 = Scalar(at::rand({n}, default_tensor_options()));
     145              : 
     146            0 :   auto w = neml2::sqrt(1.0 - u0) * neml2::sin(2.0 * M_PI * u1);
     147            0 :   auto x = neml2::sqrt(1.0 - u0) * neml2::cos(2.0 * M_PI * u1);
     148            0 :   auto y = neml2::sqrt(u0) * neml2::sin(2.0 * M_PI * u2);
     149            0 :   auto z = neml2::sqrt(u0) * neml2::cos(2.0 * M_PI * u2);
     150              : 
     151            0 :   auto quats = Quaternion(base_stack({w, x, y, z}));
     152              : 
     153            0 :   return fill_matrix(quats.to_R2());
     154            0 : }
     155              : 
     156              : Rot
     157            1 : Rot::rotation_from_to(const Vec & v1, const Vec & v2)
     158              : {
     159            1 :   auto n = v1.cross(v2);
     160            1 :   auto c = v1.dot(v2);
     161              : 
     162            1 :   auto srp = n / (1.0 + c);
     163            1 :   auto nsrp = srp.norm_sq();
     164              : 
     165            2 :   return srp / (neml2::sqrt(1.0 + nsrp) + 1.0);
     166            1 : }
     167              : 
     168              : Rot
     169            3 : Rot::from_axis_angle(const Vec & n, const Scalar & theta)
     170              : {
     171            3 :   auto nn = n / n.norm();
     172            3 :   auto t = neml2::tan(theta / 4.0);
     173              : 
     174            6 :   return nn * t;
     175            3 : }
     176              : 
     177              : Rot
     178            1 : Rot::from_axis_angle_standard(const Vec & n, const Scalar & theta)
     179              : {
     180            1 :   auto vn = from_axis_angle(n, theta * 2);
     181            2 :   return vn / (neml2::sqrt(1.0 + vn.norm_sq()) + 1.0);
     182            1 : }
     183              : 
     184              : Rot
     185            2 : Rot::inverse() const
     186              : {
     187            2 :   return -(*this);
     188              : }
     189              : 
     190              : R2
     191          252 : Rot::euler_rodrigues() const
     192              : {
     193          252 :   auto rr = norm_sq();
     194          252 :   auto E = R3::levi_civita(options());
     195          252 :   auto W = R2::skew(*this);
     196              : 
     197          504 :   return 1.0 / neml2::pow(1 + rr, 2.0) *
     198         1008 :          (neml2::pow(1 + rr, 2.0) * R2::identity(options()) + 4 * (1.0 - rr) * W + 8.0 * W * W);
     199          252 : }
     200              : 
     201              : R3
     202           37 : Rot::deuler_rodrigues() const
     203              : {
     204           37 :   auto rr = norm_sq();
     205           37 :   auto I = R2::identity(options());
     206           37 :   auto E = R3::levi_civita(options());
     207           37 :   auto W = R2::skew(*this);
     208              : 
     209          185 :   return 8.0 * (rr - 3.0) / neml2::pow(1.0 + rr, 3.0) * R3(at::einsum("...ij,...k", {W, *this})) -
     210          259 :          32.0 / neml2::pow(1 + rr, 3.0) * R3(at::einsum("...ij,...k", {(W * W), *this})) -
     211          148 :          4.0 * (1 - rr) / neml2::pow(1.0 + rr, 2.0) * R3(at::einsum("...kij->...ijk", {E})) -
     212           74 :          8.0 / neml2::pow(1.0 + rr, 2.0) *
     213          370 :              R3(at::einsum("...kim,...mj", {E, W}) + at::einsum("...im,...kmj", {W, E}));
     214          222 : }
     215              : 
     216              : Rot
     217           30 : Rot::rotate(const Rot & r) const
     218              : {
     219           30 :   return r * (*this);
     220              : }
     221              : 
     222              : R2
     223           11 : Rot::drotate(const Rot & r) const
     224              : {
     225           11 :   auto r1 = *this;
     226              : 
     227           11 :   auto rr1 = r1.norm_sq();
     228           11 :   auto rr2 = r.norm_sq();
     229           11 :   auto d = 1.0 + rr1 * rr2 - 2 * Vec(r1).dot(r);
     230           11 :   auto r3 = rotate(r);
     231           11 :   auto I = R2::identity(options());
     232              : 
     233           22 :   return 1.0 / d *
     234           22 :          (-Vec(r3).outer(2 * rr1 * Vec(r) - 2.0 * Vec(r1)) - 2 * Vec(r1).outer(Vec(r)) +
     235           66 :           (1 - rr1) * I - 2 * R2::skew(r1));
     236           11 : }
     237              : 
     238              : R2
     239            3 : Rot::drotate_self(const Rot & r) const
     240              : {
     241            3 :   auto r2 = *this;
     242              : 
     243            3 :   auto rr1 = r.norm_sq();
     244            3 :   auto rr2 = r2.norm_sq();
     245            3 :   auto d = 1.0 + rr1 * rr2 - 2 * Vec(r).dot(r2);
     246            3 :   auto r3 = rotate(r);
     247            3 :   auto I = R2::identity(options());
     248              : 
     249            6 :   return 1.0 / d *
     250            6 :          (-Vec(r3).outer(2 * rr1 * Vec(r2) - 2.0 * Vec(r)) - 2 * Vec(r).outer(Vec(r2)) +
     251           18 :           (1 - rr1) * I + 2 * R2::skew(r));
     252            3 : }
     253              : 
     254              : Rot
     255            7 : Rot::shadow() const
     256              : {
     257            7 :   return -*this / norm_sq();
     258              : }
     259              : 
     260              : R2
     261            2 : Rot::dshadow() const
     262              : {
     263            2 :   auto ns = norm_sq();
     264            4 :   return (2.0 / ns * Vec(*this).outer(*this) - R2::identity(options())) / ns;
     265            2 : }
     266              : 
     267              : Scalar
     268            0 : Rot::dist(const Rot & r2) const
     269              : {
     270            0 :   const auto r1s = this->shadow();
     271            0 :   const auto r2s = r2.shadow();
     272            0 :   const auto d_r1_r2 = this->gdist(r2);
     273            0 :   const auto d_r1_r2s = this->gdist(r2s);
     274            0 :   const auto d_r2s_r2 = r2s.gdist(r2);
     275            0 :   const auto d_r2s_r1s = r2s.gdist(r1s);
     276              : 
     277            0 :   return neml2::minimum(neml2::minimum(neml2::minimum(d_r1_r2, d_r1_r2s), d_r2s_r2), d_r2s_r1s);
     278            0 : }
     279              : 
     280              : Scalar
     281            0 : Rot::gdist(const Rot & r) const
     282              : {
     283            0 :   return 4.0 * neml2::asin(neml2::clip((*this - r).norm() /
     284            0 :                                            neml2::sqrt((1.0 + norm_sq()) * (1.0 + r.norm_sq())),
     285            0 :                                        -Scalar(1.0, r.options()),
     286            0 :                                        Scalar(1.0, r.options())));
     287              : }
     288              : 
     289              : Scalar
     290            0 : Rot::dV() const
     291              : {
     292            0 :   return 8.0 / M_PI * neml2::pow(1.0 + norm_sq(), -3.0);
     293              : }
     294              : 
     295              : Rot
     296           34 : operator*(const Rot & r1, const Rot & r2)
     297              : {
     298           34 :   auto rr1 = r1.norm_sq();
     299           34 :   auto rr2 = r2.norm_sq();
     300              : 
     301           68 :   return Rot((1 - rr2) * Vec(r1) + (1.0 - rr1) * Vec(r2) - 2.0 * Vec(r2).cross(r1)) /
     302          136 :          (1.0 + rr1 * rr2 - 2 * Vec(r1).dot(r2));
     303           34 : }
     304              : 
} // namespace neml2
        

Generated by: LCOV version 2.0-1