LCOV - code coverage report
Current view: top level - tensors - R2Base.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 128 128
Test Date: 2025-06-29 01:25:44 Functions: 100.0 % 21 21

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/tensors/R2Base.h"
      26              : #include "neml2/tensors/R2.h"
      27              : #include "neml2/tensors/Scalar.h"
      28              : #include "neml2/tensors/Vec.h"
      29              : #include "neml2/tensors/SR2.h"
      30              : #include "neml2/tensors/R3.h"
      31              : #include "neml2/tensors/R4.h"
      32              : #include "neml2/tensors/Rot.h"
      33              : #include "neml2/tensors/WR2.h"
      34              : #include "neml2/tensors/assertions.h"
      35              : #include "neml2/tensors/functions/stack.h"
      36              : #include "neml2/tensors/functions/sum.h"
      37              : 
      38              : namespace neml2
      39              : {
      40              : template <class Derived>
      41              : Derived
      42           26 : R2Base<Derived>::fill(const CScalar & a, const TensorOptions & options)
      43              : {
      44           26 :   return R2Base<Derived>::fill(Scalar(a, options));
      45              : }
      46              : 
      47              : template <class Derived>
      48              : Derived
      49           49 : R2Base<Derived>::fill(const Scalar & a)
      50              : {
      51           49 :   auto zero = Scalar::zeros_like(a);
      52          294 :   return Derived(base_stack({base_stack({a, zero, zero}, -1),
      53          196 :                              base_stack({zero, a, zero}, -1),
      54          196 :                              base_stack({zero, zero, a}, -1)},
      55          294 :                             -2));
      56          735 : }
      57              : 
      58              : template <class Derived>
      59              : Derived
      60            2 : R2Base<Derived>::fill(const CScalar & a11,
      61              :                       const CScalar & a22,
      62              :                       const CScalar & a33,
      63              :                       const TensorOptions & options)
      64              : {
      65            2 :   return R2Base<Derived>::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
      66              : }
      67              : 
      68              : template <class Derived>
      69              : Derived
      70            5 : R2Base<Derived>::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
      71              : {
      72            5 :   auto zero = Scalar::zeros_like(a11);
      73           30 :   return Derived(base_stack({base_stack({a11, zero, zero}, -1),
      74           20 :                              base_stack({zero, a22, zero}, -1),
      75           20 :                              base_stack({zero, zero, a33}, -1)},
      76           30 :                             -2));
      77           75 : }
      78              : 
      79              : template <class Derived>
      80              : Derived
      81            2 : R2Base<Derived>::fill(const CScalar & a11,
      82              :                       const CScalar & a22,
      83              :                       const CScalar & a33,
      84              :                       const CScalar & a23,
      85              :                       const CScalar & a13,
      86              :                       const CScalar & a12,
      87              :                       const TensorOptions & options)
      88              : {
      89            4 :   return R2Base<Derived>::fill(Scalar(a11, options),
      90            4 :                                Scalar(a22, options),
      91            4 :                                Scalar(a33, options),
      92            4 :                                Scalar(a23, options),
      93            4 :                                Scalar(a13, options),
      94            6 :                                Scalar(a12, options));
      95              : }
      96              : 
      97              : template <class Derived>
      98              : Derived
      99            5 : R2Base<Derived>::fill(const Scalar & a11,
     100              :                       const Scalar & a22,
     101              :                       const Scalar & a33,
     102              :                       const Scalar & a23,
     103              :                       const Scalar & a13,
     104              :                       const Scalar & a12)
     105              : {
     106           30 :   return Derived(base_stack({base_stack({a11, a12, a13}, -1),
     107           20 :                              base_stack({a12, a22, a23}, -1),
     108           20 :                              base_stack({a13, a23, a33}, -1)},
     109           30 :                             -2));
     110           75 : }
     111              : 
     112              : template <class Derived>
     113              : Derived
     114           34 : R2Base<Derived>::fill(const CScalar & a11,
     115              :                       const CScalar & a12,
     116              :                       const CScalar & a13,
     117              :                       const CScalar & a21,
     118              :                       const CScalar & a22,
     119              :                       const CScalar & a23,
     120              :                       const CScalar & a31,
     121              :                       const CScalar & a32,
     122              :                       const CScalar & a33,
     123              :                       const TensorOptions & options)
     124              : {
     125           68 :   return R2Base<Derived>::fill(Scalar(a11, options),
     126           68 :                                Scalar(a12, options),
     127           68 :                                Scalar(a13, options),
     128           68 :                                Scalar(a21, options),
     129           68 :                                Scalar(a22, options),
     130           68 :                                Scalar(a23, options),
     131           68 :                                Scalar(a31, options),
     132           68 :                                Scalar(a32, options),
     133          102 :                                Scalar(a33, options));
     134              : }
     135              : 
     136              : template <class Derived>
     137              : Derived
     138          140 : R2Base<Derived>::fill(const Scalar & a11,
     139              :                       const Scalar & a12,
     140              :                       const Scalar & a13,
     141              :                       const Scalar & a21,
     142              :                       const Scalar & a22,
     143              :                       const Scalar & a23,
     144              :                       const Scalar & a31,
     145              :                       const Scalar & a32,
     146              :                       const Scalar & a33)
     147              : {
     148          840 :   return Derived(base_stack({base_stack({a11, a12, a13}, -1),
     149          560 :                              base_stack({a21, a22, a23}, -1),
     150          560 :                              base_stack({a31, a32, a33}, -1)},
     151          840 :                             -2));
     152         2100 : }
     153              : 
     154              : template <class Derived>
     155              : Derived
     156          305 : R2Base<Derived>::skew(const Vec & v)
     157              : {
     158          305 :   const auto z = Scalar::zeros_like(v(0));
     159         1830 :   return Derived(base_stack({base_stack({z, -v(2), v(1)}, -1),
     160         1220 :                              base_stack({v(2), z, -v(0)}, -1),
     161         1220 :                              base_stack({-v(1), v(0), z}, -1)},
     162         1830 :                             -2));
     163         4575 : }
     164              : 
     165              : template <class Derived>
     166              : Derived
     167          381 : R2Base<Derived>::identity(const TensorOptions & options)
     168              : {
     169          381 :   return Derived(at::eye(3, options), 0);
     170              : }
     171              : 
     172              : template <class Derived>
     173              : Derived
     174           18 : R2Base<Derived>::rotate(const Rot & r) const
     175              : {
     176           18 :   return rotate(r.euler_rodrigues());
     177              : }
     178              : 
     179              : template <class Derived>
     180              : Derived
     181           43 : R2Base<Derived>::rotate(const R2 & R) const
     182              : {
     183           43 :   return R * R2(*this) * R.transpose();
     184              : }
     185              : 
     186              : template <class Derived>
     187              : R3
     188           11 : R2Base<Derived>::drotate(const Rot & r) const
     189              : {
     190           11 :   auto R = r.euler_rodrigues();
     191           11 :   auto F = r.deuler_rodrigues();
     192              : 
     193           44 :   return R3(at::einsum("...itl,...tm,...jm", {F, *this, R}) +
     194          110 :             at::einsum("...ik,...kt,...jtl", {R, *this, F}));
     195           33 : }
     196              : 
     197              : template <class Derived>
     198              : R4
     199            9 : R2Base<Derived>::drotate(const R2 & R) const
     200              : {
     201            9 :   auto I = R2::identity(R.options());
     202           27 :   return R4(at::einsum("...ik,...jl", {I, R * this->transpose()}) +
     203           90 :             at::einsum("...jk,...il", {I, R * R2(*this)}));
     204           45 : }
     205              : 
     206              : template <class Derived>
     207              : Scalar
     208           81 : R2Base<Derived>::operator()(Size i, Size j) const
     209              : {
     210          243 :   return PrimitiveTensor<Derived, 3, 3>::base_index({i, j});
     211           81 : }
     212              : 
     213              : template <class Derived>
     214              : Scalar
     215            4 : R2Base<Derived>::det() const
     216              : {
     217            4 :   const auto comps = at::split(this->base_flatten(), 1, -1);
     218            4 :   const auto & a = comps[0];
     219            4 :   const auto & b = comps[1];
     220            4 :   const auto & c = comps[2];
     221            4 :   const auto & d = comps[3];
     222            4 :   const auto & e = comps[4];
     223            4 :   const auto & f = comps[5];
     224            4 :   const auto & g = comps[6];
     225            4 :   const auto & h = comps[7];
     226            4 :   const auto & i = comps[8];
     227            4 :   const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
     228            8 :   return Scalar(det.reshape(this->batch_sizes().concrete()), this->batch_sizes());
     229            4 : }
     230              : 
     231              : template <class Derived>
     232              : Scalar
     233            2 : R2Base<Derived>::inner(const R2 & other) const
     234              : {
     235            2 :   return base_sum(this->base_flatten() * other.base_flatten());
     236              : }
     237              : 
     238              : template <class Derived>
     239              : Derived
     240           97 : R2Base<Derived>::inverse() const
     241              : {
     242           97 :   const auto comps = at::split(this->base_flatten(), 1, -1);
     243           97 :   const auto & a = comps[0];
     244           97 :   const auto & b = comps[1];
     245           97 :   const auto & c = comps[2];
     246           97 :   const auto & d = comps[3];
     247           97 :   const auto & e = comps[4];
     248           97 :   const auto & f = comps[5];
     249           97 :   const auto & g = comps[6];
     250           97 :   const auto & h = comps[7];
     251           97 :   const auto & i = comps[8];
     252           97 :   const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
     253           97 :   const auto cof00 = e * i - h * f;
     254           97 :   const auto cof01 = -(d * i - g * f);
     255           97 :   const auto cof02 = d * h - g * e;
     256           97 :   const auto cof10 = -(b * i - h * c);
     257           97 :   const auto cof11 = a * i - g * c;
     258           97 :   const auto cof12 = -(a * h - g * b);
     259           97 :   const auto cof20 = b * f - e * c;
     260           97 :   const auto cof21 = -(a * f - d * c);
     261           97 :   const auto cof22 = a * e - d * b;
     262          388 :   const auto coft0 = at::cat({cof00, cof10, cof20}, -1);
     263          388 :   const auto coft1 = at::cat({cof01, cof11, cof21}, -1);
     264          388 :   const auto coft2 = at::cat({cof02, cof12, cof22}, -1);
     265          388 :   const auto coft = at::stack({coft0, coft1, coft2}, -2);
     266           97 :   const auto inv = coft / det.unsqueeze(-1);
     267          194 :   return Derived(inv, this->batch_sizes());
     268          485 : }
     269              : 
     270              : template <class Derived>
     271              : Derived
     272          198 : R2Base<Derived>::transpose() const
     273              : {
     274          198 :   return TensorBase<Derived>::base_transpose(0, 1);
     275              : }
     276              : 
     277              : template <class Derived1, class Derived2, typename, typename>
     278              : Vec
     279          177 : operator*(const Derived1 & A, const Derived2 & b)
     280              : {
     281          177 :   neml_assert_batch_broadcastable_dbg(A, b);
     282          708 :   return Vec(at::einsum("...ik,...k", {A, b}));
     283          177 : }
     284              : 
     285              : template <class Derived1, class Derived2, typename, typename>
     286              : R2
     287         1479 : operator*(const Derived1 & A, const Derived2 & B)
     288              : {
     289         1479 :   neml_assert_broadcastable_dbg(A, B);
     290         5916 :   return R2(at::einsum("...ik,...kj", {A, B}));
     291         1479 : }
     292              : 
     293              : // template instantiation
     294              : 
     295              : // derived classes
     296              : template class R2Base<R2>;
     297              : 
     298              : // products
     299              : template Vec operator*(const R2 & A, const Vec & b);
     300              : template R2 operator*(const R2 & A, const R2 & B);
     301              : } // namespace neml2
        

Generated by: LCOV version 2.0-1