LCOV - code coverage report
Current view: top level - tensors - R2Base.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 97.8 % 136 133
Test Date: 2025-10-02 16:03:03 Functions: 95.8 % 24 23

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/tensors/R2Base.h"
      26              : #include "neml2/tensors/R2.h"
      27              : #include "neml2/tensors/Scalar.h"
      28              : #include "neml2/tensors/Vec.h"
      29              : #include "neml2/tensors/SR2.h"
      30              : #include "neml2/tensors/R3.h"
      31              : #include "neml2/tensors/R4.h"
      32              : #include "neml2/tensors/Rot.h"
      33              : #include "neml2/tensors/WR2.h"
      34              : #include "neml2/tensors/assertions.h"
      35              : #include "neml2/tensors/functions/stack.h"
      36              : #include "neml2/tensors/functions/sum.h"
      37              : 
      38              : namespace neml2
      39              : {
      40              : template <class Derived>
      41              : Derived
      42           26 : R2Base<Derived>::fill(const CScalar & a, const TensorOptions & options)
      43              : {
      44           26 :   return R2Base<Derived>::fill(Scalar(a, options));
      45              : }
      46              : 
      47              : template <class Derived>
      48              : Derived
      49           51 : R2Base<Derived>::fill(const Scalar & a)
      50              : {
      51           51 :   auto zero = Scalar::zeros_like(a);
      52          306 :   return Derived(base_stack({base_stack({a, zero, zero}, -1),
      53          204 :                              base_stack({zero, a, zero}, -1),
      54          204 :                              base_stack({zero, zero, a}, -1)},
      55          306 :                             -2));
      56          765 : }
      57              : 
      58              : template <class Derived>
      59              : Derived
      60            2 : R2Base<Derived>::fill(const CScalar & a11,
      61              :                       const CScalar & a22,
      62              :                       const CScalar & a33,
      63              :                       const TensorOptions & options)
      64              : {
      65            2 :   return R2Base<Derived>::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
      66              : }
      67              : 
      68              : template <class Derived>
      69              : Derived
      70            5 : R2Base<Derived>::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
      71              : {
      72            5 :   auto zero = Scalar::zeros_like(a11);
      73           30 :   return Derived(base_stack({base_stack({a11, zero, zero}, -1),
      74           20 :                              base_stack({zero, a22, zero}, -1),
      75           20 :                              base_stack({zero, zero, a33}, -1)},
      76           30 :                             -2));
      77           75 : }
      78              : 
      79              : template <class Derived>
      80              : Derived
      81            2 : R2Base<Derived>::fill(const CScalar & a11,
      82              :                       const CScalar & a22,
      83              :                       const CScalar & a33,
      84              :                       const CScalar & a23,
      85              :                       const CScalar & a13,
      86              :                       const CScalar & a12,
      87              :                       const TensorOptions & options)
      88              : {
      89            4 :   return R2Base<Derived>::fill(Scalar(a11, options),
      90            4 :                                Scalar(a22, options),
      91            4 :                                Scalar(a33, options),
      92            4 :                                Scalar(a23, options),
      93            4 :                                Scalar(a13, options),
      94            6 :                                Scalar(a12, options));
      95              : }
      96              : 
      97              : template <class Derived>
      98              : Derived
      99            5 : R2Base<Derived>::fill(const Scalar & a11,
     100              :                       const Scalar & a22,
     101              :                       const Scalar & a33,
     102              :                       const Scalar & a23,
     103              :                       const Scalar & a13,
     104              :                       const Scalar & a12)
     105              : {
     106           30 :   return Derived(base_stack({base_stack({a11, a12, a13}, -1),
     107           20 :                              base_stack({a12, a22, a23}, -1),
     108           20 :                              base_stack({a13, a23, a33}, -1)},
     109           30 :                             -2));
     110           75 : }
     111              : 
     112              : template <class Derived>
     113              : Derived
     114           35 : R2Base<Derived>::fill(const CScalar & a11,
     115              :                       const CScalar & a12,
     116              :                       const CScalar & a13,
     117              :                       const CScalar & a21,
     118              :                       const CScalar & a22,
     119              :                       const CScalar & a23,
     120              :                       const CScalar & a31,
     121              :                       const CScalar & a32,
     122              :                       const CScalar & a33,
     123              :                       const TensorOptions & options)
     124              : {
     125           70 :   return R2Base<Derived>::fill(Scalar(a11, options),
     126           70 :                                Scalar(a12, options),
     127           70 :                                Scalar(a13, options),
     128           70 :                                Scalar(a21, options),
     129           70 :                                Scalar(a22, options),
     130           70 :                                Scalar(a23, options),
     131           70 :                                Scalar(a31, options),
     132           70 :                                Scalar(a32, options),
     133          105 :                                Scalar(a33, options));
     134              : }
     135              : 
     136              : template <class Derived>
     137              : Derived
     138          143 : R2Base<Derived>::fill(const Scalar & a11,
     139              :                       const Scalar & a12,
     140              :                       const Scalar & a13,
     141              :                       const Scalar & a21,
     142              :                       const Scalar & a22,
     143              :                       const Scalar & a23,
     144              :                       const Scalar & a31,
     145              :                       const Scalar & a32,
     146              :                       const Scalar & a33)
     147              : {
     148          858 :   return Derived(base_stack({base_stack({a11, a12, a13}, -1),
     149          572 :                              base_stack({a21, a22, a23}, -1),
     150          572 :                              base_stack({a31, a32, a33}, -1)},
     151          858 :                             -2));
     152         2145 : }
     153              : 
     154              : template <class Derived>
     155              : Derived
     156          305 : R2Base<Derived>::skew(const Vec & v)
     157              : {
     158          305 :   const auto z = Scalar::zeros_like(v(0));
     159         1830 :   return Derived(base_stack({base_stack({z, -v(2), v(1)}, -1),
     160         1220 :                              base_stack({v(2), z, -v(0)}, -1),
     161         1220 :                              base_stack({-v(1), v(0), z}, -1)},
     162         1830 :                             -2));
     163         4575 : }
     164              : 
     165              : template <class Derived>
     166              : Derived
     167          381 : R2Base<Derived>::identity(const TensorOptions & options)
     168              : {
     169          381 :   return Derived(at::eye(3, options), 0);
     170              : }
     171              : 
     172              : template <class Derived>
     173              : Derived
     174           18 : R2Base<Derived>::rotate(const Rot & r) const
     175              : {
     176           18 :   return rotate(r.euler_rodrigues());
     177              : }
     178              : 
     179              : template <class Derived>
     180              : Derived
     181           43 : R2Base<Derived>::rotate(const R2 & R) const
     182              : {
     183           43 :   return R * R2(*this) * R.transpose();
     184              : }
     185              : 
     186              : template <class Derived>
     187              : R3
     188           11 : R2Base<Derived>::drotate(const Rot & r) const
     189              : {
     190           11 :   auto R = r.euler_rodrigues();
     191           11 :   auto F = r.deuler_rodrigues();
     192              : 
     193           44 :   return R3(at::einsum("...itl,...tm,...jm", {F, *this, R}) +
     194          110 :             at::einsum("...ik,...kt,...jtl", {R, *this, F}));
     195           33 : }
     196              : 
     197              : template <class Derived>
     198              : R4
     199            9 : R2Base<Derived>::drotate(const R2 & R) const
     200              : {
     201            9 :   auto I = R2::identity(R.options());
     202           27 :   return R4(at::einsum("...ik,...jl", {I, R * this->transpose()}) +
     203           90 :             at::einsum("...jk,...il", {I, R * R2(*this)}));
     204           45 : }
     205              : 
     206              : template <class Derived>
     207              : Scalar
     208           81 : R2Base<Derived>::operator()(Size i, Size j) const
     209              : {
     210          243 :   return this->base_index({i, j});
     211           81 : }
     212              : 
     213              : template <class Derived>
     214              : Vec
     215            0 : R2Base<Derived>::row(Size i) const
     216              : {
     217            0 :   return Vec(this->base_index({i, indexing::Slice()}), this->batch_sizes());
     218            0 : }
     219              : 
     220              : template <class Derived>
     221              : Vec
     222           60 : R2Base<Derived>::col(Size i) const
     223              : {
     224          300 :   return Vec(this->base_index({indexing::Slice(), i}), this->batch_sizes());
     225          120 : }
     226              : 
     227              : template <class Derived>
     228              : Scalar
     229            4 : R2Base<Derived>::det() const
     230              : {
     231            4 :   const auto comps = at::split(this->base_flatten(), 1, -1);
     232            4 :   const auto & a = comps[0];
     233            4 :   const auto & b = comps[1];
     234            4 :   const auto & c = comps[2];
     235            4 :   const auto & d = comps[3];
     236            4 :   const auto & e = comps[4];
     237            4 :   const auto & f = comps[5];
     238            4 :   const auto & g = comps[6];
     239            4 :   const auto & h = comps[7];
     240            4 :   const auto & i = comps[8];
     241            4 :   const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
     242            8 :   return Scalar(det.reshape(this->batch_sizes().concrete()), this->batch_sizes());
     243            4 : }
     244              : 
     245              : template <class Derived>
     246              : Scalar
     247            4 : R2Base<Derived>::inner(const R2 & other) const
     248              : {
     249            4 :   return base_sum(this->base_flatten() * other.base_flatten());
     250              : }
     251              : 
     252              : template <class Derived>
     253              : R4
     254           15 : R2Base<Derived>::outer(const R2 & other) const
     255              : {
     256           15 :   return this->base_unsqueeze(-1).base_unsqueeze(-1) * other.base_unsqueeze(0).base_unsqueeze(0);
     257              : }
     258              : 
     259              : template <class Derived>
     260              : Derived
     261           97 : R2Base<Derived>::inverse() const
     262              : {
     263           97 :   const auto comps = at::split(this->base_flatten(), 1, -1);
     264           97 :   const auto & a = comps[0];
     265           97 :   const auto & b = comps[1];
     266           97 :   const auto & c = comps[2];
     267           97 :   const auto & d = comps[3];
     268           97 :   const auto & e = comps[4];
     269           97 :   const auto & f = comps[5];
     270           97 :   const auto & g = comps[6];
     271           97 :   const auto & h = comps[7];
     272           97 :   const auto & i = comps[8];
     273           97 :   const auto det = a * (e * i - h * f) - b * (d * i - g * f) + c * (d * h - e * g);
     274           97 :   const auto cof00 = e * i - h * f;
     275           97 :   const auto cof01 = -(d * i - g * f);
     276           97 :   const auto cof02 = d * h - g * e;
     277           97 :   const auto cof10 = -(b * i - h * c);
     278           97 :   const auto cof11 = a * i - g * c;
     279           97 :   const auto cof12 = -(a * h - g * b);
     280           97 :   const auto cof20 = b * f - e * c;
     281           97 :   const auto cof21 = -(a * f - d * c);
     282           97 :   const auto cof22 = a * e - d * b;
     283          388 :   const auto coft0 = at::cat({cof00, cof10, cof20}, -1);
     284          388 :   const auto coft1 = at::cat({cof01, cof11, cof21}, -1);
     285          388 :   const auto coft2 = at::cat({cof02, cof12, cof22}, -1);
     286          388 :   const auto coft = at::stack({coft0, coft1, coft2}, -2);
     287           97 :   const auto inv = coft / det.unsqueeze(-1);
     288          194 :   return Derived(inv, this->batch_sizes());
     289          485 : }
     290              : 
     291              : template <class Derived>
     292              : Derived
     293          262 : R2Base<Derived>::transpose() const
     294              : {
     295          262 :   return TensorBase<Derived>::base_transpose(0, 1);
     296              : }
     297              : 
     298              : template <class Derived1, class Derived2, typename, typename>
     299              : Vec
     300          185 : operator*(const Derived1 & A, const Derived2 & b)
     301              : {
     302          185 :   neml_assert_batch_broadcastable_dbg(A, b);
     303          740 :   return Vec(at::einsum("...ik,...k", {A, b}));
     304          185 : }
     305              : 
     306              : template <class Derived1, class Derived2, typename, typename>
     307              : R2
     308         1479 : operator*(const Derived1 & A, const Derived2 & B)
     309              : {
     310         1479 :   neml_assert_broadcastable_dbg(A, B);
     311         5916 :   return R2(at::einsum("...ik,...kj", {A, B}));
     312         1479 : }
     313              : 
     314              : // template instantiation
     315              : 
     316              : // derived classes
     317              : template class R2Base<R2>;
     318              : 
     319              : // products
     320              : template Vec operator*(const R2 & A, const Vec & b);
     321              : template R2 operator*(const R2 & A, const R2 & B);
     322              : } // namespace neml2
        

Generated by: LCOV version 2.0-1