LCOV - code coverage report
Current view: top level - tensors - SR2.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 97 97
Test Date: 2025-06-29 01:25:44 Functions: 100.0 % 24 24

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/tensors/SR2.h"
      26              : #include "neml2/tensors/Scalar.h"
      27              : #include "neml2/tensors/R2.h"
      28              : #include "neml2/tensors/R3.h"
      29              : #include "neml2/tensors/SFR3.h"
      30              : #include "neml2/tensors/SSR4.h"
      31              : #include "neml2/tensors/SFFR4.h"
      32              : #include "neml2/tensors/Rot.h"
      33              : #include "neml2/tensors/SWR4.h"
      34              : #include "neml2/tensors/WR2.h"
      35              : #include "neml2/tensors/R4.h"
      36              : #include "neml2/tensors/assertions.h"
      37              : 
      38              : #include "neml2/tensors/mandel_notation.h"
      39              : #include "neml2/tensors/functions/sqrt.h"
      40              : #include "neml2/tensors/functions/sum.h"
      41              : #include "neml2/tensors/functions/stack.h"
      42              : #include "neml2/tensors/functions/linalg/vecdot.h"
      43              : 
      44              : namespace neml2
      45              : {
      46              : 
      47           91 : SR2::SR2(const R2 & T)
      48           91 :   : SR2(full_to_mandel((T + T.transpose()) / 2.0))
      49              : {
      50           91 : }
      51              : 
      52              : SR2
      53            3 : SR2::fill(const CScalar & a, const TensorOptions & options)
      54              : {
      55            3 :   return SR2::fill(Scalar(a, options));
      56              : }
      57              : 
      58              : SR2
      59          119 : SR2::fill(const Scalar & a)
      60              : {
      61          119 :   auto zero = Scalar::zeros_like(a);
      62         1071 :   return SR2(base_stack({a, a, a, zero, zero, zero}, -1));
      63          238 : }
      64              : 
      65              : SR2
      66           18 : SR2::fill(const CScalar & a11,
      67              :           const CScalar & a22,
      68              :           const CScalar & a33,
      69              :           const TensorOptions & options)
      70              : {
      71           18 :   return SR2::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
      72              : }
      73              : 
      74              : SR2
      75           49 : SR2::fill(const Scalar & a11, const Scalar & a22, const Scalar & a33)
      76              : {
      77           49 :   auto zero = Scalar::zeros_like(a11);
      78          441 :   return SR2(base_stack({a11, a22, a33, zero, zero, zero}, -1));
      79           98 : }
      80              : 
      81              : SR2
      82           23 : SR2::fill(const CScalar & a11,
      83              :           const CScalar & a22,
      84              :           const CScalar & a33,
      85              :           const CScalar & a23,
      86              :           const CScalar & a13,
      87              :           const CScalar & a12,
      88              :           const TensorOptions & options)
      89              : {
      90           46 :   return SR2::fill(Scalar(a11, options),
      91           46 :                    Scalar(a22, options),
      92           46 :                    Scalar(a33, options),
      93           46 :                    Scalar(a23, options),
      94           46 :                    Scalar(a13, options),
      95           69 :                    Scalar(a12, options));
      96              : }
      97              : 
      98              : SR2
      99          110 : SR2::fill(const Scalar & a11,
     100              :           const Scalar & a22,
     101              :           const Scalar & a33,
     102              :           const Scalar & a23,
     103              :           const Scalar & a13,
     104              :           const Scalar & a12)
     105              : {
     106          990 :   return SR2(base_stack(
     107          550 :       {a11, a22, a33, mandel_factor(3) * a23, mandel_factor(4) * a13, mandel_factor(5) * a12}, -1));
     108          440 : }
     109              : 
     110              : SR2
     111           28 : SR2::identity(const TensorOptions & options)
     112              : {
     113          196 :   return SR2::create({1, 1, 1, 0, 0, 0}, options);
     114           28 : }
     115              : 
     116              : SSR4
     117           56 : SR2::identity_map(const TensorOptions & options)
     118              : {
     119           56 :   return SSR4::identity_sym(options);
     120              : }
     121              : 
     122              : SR2
     123            8 : SR2::rotate(const Rot & r) const
     124              : {
     125            8 :   return R2(*this).rotate(r);
     126              : }
     127              : 
     128              : SR2
     129           22 : SR2::rotate(const R2 & R) const
     130              : {
     131           22 :   return R2(*this).rotate(R);
     132              : }
     133              : 
     134              : SFR3
     135            4 : SR2::drotate(const Rot & r) const
     136              : {
     137            4 :   auto dR = R2(*this).drotate(r);
     138            8 :   return full_to_mandel(dR);
     139            4 : }
     140              : 
     141              : SFFR4
     142            8 : SR2::drotate(const R2 & R) const
     143              : {
     144            8 :   auto dR = R2(*this).drotate(R);
     145           16 :   return full_to_mandel(dR);
     146            8 : }
     147              : 
     148              : Scalar
     149            9 : SR2::operator()(Size i, Size j) const
     150              : {
     151            9 :   Size a = mandel_reverse_index[i][j];
     152           27 :   return base_index({a}) / mandel_factor(a);
     153            9 : }
     154              : 
     155              : Scalar
     156          132 : SR2::tr() const
     157              : {
     158          660 :   return Scalar(base_sum(base_index({indexing::Slice(0, 3)}), -1), batch_sizes());
     159          264 : }
     160              : 
     161              : SR2
     162          101 : SR2::vol() const
     163              : {
     164          202 :   return SR2::fill(tr()) / 3;
     165              : }
     166              : 
     167              : SR2
     168           70 : SR2::dev() const
     169              : {
     170           70 :   return *this - vol();
     171              : }
     172              : 
     173              : Scalar
     174           71 : SR2::inner(const SR2 & other) const
     175              : {
     176           71 :   return linalg::vecdot(*this, other);
     177              : }
     178              : 
     179              : Scalar
     180           61 : SR2::norm_sq() const
     181              : {
     182           61 :   return inner(*this);
     183              : }
     184              : 
     185              : Scalar
     186           59 : SR2::norm(const CScalar & eps) const
     187              : {
     188           59 :   return neml2::sqrt(norm_sq() + eps);
     189              : }
     190              : 
     191              : SSR4
     192           23 : SR2::outer(const SR2 & other) const
     193              : {
     194           23 :   neml_assert_broadcastable_dbg(*this, other);
     195          115 :   return SSR4(at::einsum("...i,...j", {*this, other}), utils::broadcast_batch_dim(*this, other));
     196           23 : }
     197              : 
     198              : Scalar
     199            2 : SR2::det() const
     200              : {
     201            2 :   const auto comps = at::split(*this, 1, -1);
     202            2 :   const auto & a = comps[0];
     203            2 :   const auto & e = comps[1];
     204            2 :   const auto & i = comps[2];
     205            2 :   const auto f = comps[3] / mandel_factor(3);
     206            2 :   const auto c = comps[4] / mandel_factor(4);
     207            2 :   const auto b = comps[5] / mandel_factor(5);
     208            2 :   const auto det = a * (e * i - f * f) - b * (b * i - c * f) + c * (b * f - e * c);
     209            6 :   return Scalar(det.reshape(batch_sizes().concrete()), batch_sizes());
     210            2 : }
     211              : 
     212              : SR2
     213            2 : SR2::inverse() const
     214              : {
     215            2 :   const auto comps = at::split(*this, 1, -1);
     216            2 :   const auto & a = comps[0];
     217            2 :   const auto & e = comps[1];
     218            2 :   const auto & i = comps[2];
     219            2 :   const auto f = comps[3] / mandel_factor(3);
     220            2 :   const auto c = comps[4] / mandel_factor(4);
     221            2 :   const auto b = comps[5] / mandel_factor(5);
     222            2 :   const auto det = a * (e * i - f * f) - b * (b * i - c * f) + c * (b * f - e * c);
     223            2 :   const auto cof00 = e * i - f * f;
     224            2 :   const auto cof01 = -(b * i - c * f);
     225            2 :   const auto cof02 = b * f - c * e;
     226            2 :   const auto cof11 = a * i - c * c;
     227            2 :   const auto cof12 = -(a * f - c * b);
     228            2 :   const auto cof22 = a * e - b * b;
     229           14 :   const auto cof = at::cat({cof00,
     230              :                             cof11,
     231              :                             cof22,
     232              :                             mandel_factor(3) * cof12,
     233              :                             mandel_factor(4) * cof02,
     234              :                             mandel_factor(5) * cof01},
     235           16 :                            -1);
     236            2 :   const auto inv = cof / det;
     237            4 :   return SR2(inv, batch_sizes());
     238            6 : }
     239              : 
     240              : SR2
     241            2 : SR2::transpose() const
     242              : {
     243            2 :   return SR2(*this, batch_sizes());
     244              : }
     245              : 
     246              : } // namespace neml2
        

Generated by: LCOV version 2.0-1