LCOV - code coverage report
Current view: top level - models - Variable.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 82.1 % 252 207
Test Date: 2025-10-02 16:03:03 Functions: 32.6 % 285 93

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/models/Variable.h"
      26              : #include "neml2/models/Model.h"
      27              : #include "neml2/models/DependencyResolver.h"
      28              : #include "neml2/tensors/tensors.h"
      29              : #include "neml2/misc/assertions.h"
      30              : #include "neml2/tensors/functions/bmm.h"
      31              : #include "neml2/jit/utils.h"
      32              : #include "neml2/jit/TraceableTensorShape.h"
      33              : 
      34              : namespace neml2
      35              : {
      36         1764 : VariableBase::VariableBase(VariableName name_in, Model * owner, TensorShapeRef list_shape)
      37         1764 :   : _name(std::move(name_in)),
      38         1764 :     _owner(owner),
      39         3528 :     _list_sizes(list_shape)
      40              : {
      41         1764 : }
      42              : 
      43              : const Model &
      44        12303 : VariableBase::owner() const
      45              : {
      46        12303 :   neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
      47        12303 :   return *_owner;
      48              : }
      49              : 
      50              : Model &
      51        10217 : VariableBase::owner()
      52              : {
      53        10217 :   neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
      54        10217 :   return *_owner;
      55              : }
      56              : 
      57              : bool
      58          845 : VariableBase::is_state() const
      59              : {
      60          845 :   return _name.is_state();
      61              : }
      62              : 
      63              : bool
      64          166 : VariableBase::is_old_state() const
      65              : {
      66          166 :   return _name.is_old_state();
      67              : }
      68              : 
      69              : bool
      70          171 : VariableBase::is_force() const
      71              : {
      72          171 :   return _name.is_force();
      73              : }
      74              : 
      75              : bool
      76           48 : VariableBase::is_old_force() const
      77              : {
      78           48 :   return _name.is_old_force();
      79              : }
      80              : 
      81              : bool
      82          136 : VariableBase::is_residual() const
      83              : {
      84          136 :   return _name.is_residual();
      85              : }
      86              : 
      87              : bool
      88           90 : VariableBase::is_parameter() const
      89              : {
      90           90 :   return _name.is_parameter();
      91              : }
      92              : 
      93              : bool
      94          397 : VariableBase::is_solve_dependent() const
      95              : {
      96          397 :   return is_state() || is_residual() || is_parameter();
      97              : }
      98              : 
      99              : bool
     100         1573 : VariableBase::is_dependent() const
     101              : {
     102         1573 :   return !currently_solving_nonlinear_system() || is_solve_dependent();
     103              : }
     104              : 
     105              : TensorOptions
     106          399 : VariableBase::options() const
     107              : {
     108          399 :   return tensor().options();
     109              : }
     110              : 
     111              : Dtype
     112           48 : VariableBase::scalar_type() const
     113              : {
     114           48 :   return tensor().scalar_type();
     115              : }
     116              : 
     117              : Device
     118            0 : VariableBase::device() const
     119              : {
     120            0 :   return tensor().device();
     121              : }
     122              : 
     123              : Size
     124            0 : VariableBase::dim() const
     125              : {
     126            0 :   return tensor().dim();
     127              : }
     128              : 
     129              : TensorShapeRef
     130            0 : VariableBase::sizes() const
     131              : {
     132            0 :   return tensor().sizes();
     133              : }
     134              : 
     135              : Size
     136            0 : VariableBase::size(Size dim) const
     137              : {
     138            0 :   return tensor().size(dim);
     139              : }
     140              : 
     141              : bool
     142            0 : VariableBase::batched() const
     143              : {
     144            0 :   return tensor().batched();
     145              : }
     146              : 
     147              : Size
     148        36971 : VariableBase::batch_dim() const
     149              : {
     150        36971 :   return tensor().batch_dim();
     151              : }
     152              : 
     153              : Size
     154       674697 : VariableBase::list_dim() const
     155              : {
     156       674697 :   return Size(list_sizes().size());
     157              : }
     158              : 
     159              : Size
     160            0 : VariableBase::base_dim() const
     161              : {
     162            0 :   return Size(base_sizes().size());
     163              : }
     164              : 
     165              : TraceableTensorShape
     166            1 : VariableBase::batch_sizes() const
     167              : {
     168            1 :   return tensor().batch_sizes();
     169              : }
     170              : 
     171              : TensorShapeRef
     172       729996 : VariableBase::list_sizes() const
     173              : {
     174       729996 :   return _list_sizes;
     175              : }
     176              : 
     177              : TraceableSize
     178            0 : VariableBase::batch_size(Size dim) const
     179              : {
     180            0 :   return tensor().batch_size(dim);
     181              : }
     182              : 
     183              : Size
     184            0 : VariableBase::base_size(Size dim) const
     185              : {
     186            0 :   return base_sizes()[dim];
     187              : }
     188              : 
     189              : Size
     190            1 : VariableBase::list_size(Size dim) const
     191              : {
     192            1 :   return list_sizes()[dim];
     193              : }
     194              : 
     195              : Size
     196            0 : VariableBase::base_storage() const
     197              : {
     198            0 :   return utils::storage_size(base_sizes());
     199              : }
     200              : 
     201              : Size
     202         2869 : VariableBase::assembly_storage() const
     203              : {
     204         2869 :   return utils::storage_size(list_sizes()) * utils::storage_size(base_sizes());
     205              : }
     206              : 
     207              : bool
     208            0 : VariableBase::requires_grad() const
     209              : {
     210            0 :   return tensor().requires_grad();
     211              : }
     212              : 
     213              : Derivative
     214          816 : VariableBase::d(const VariableBase & var)
     215              : {
     216          816 :   neml_assert_dbg(owning(),
     217              :                   "Cannot assign derivative to a referencing variable '",
     218          816 :                   name(),
     219              :                   "' with respect to '",
     220          816 :                   var.name(),
     221              :                   "'.");
     222          816 :   return Derivative({assembly_storage(), var.assembly_storage()}, &_derivs[var.name()]);
     223              : }
     224              : 
     225              : Derivative
     226          233 : VariableBase::d(const VariableBase & var1, const VariableBase & var2)
     227              : {
     228          233 :   neml_assert_dbg(owning(),
     229              :                   "Cannot assign second derivative to a referencing variable '",
     230          233 :                   name(),
     231              :                   "' with respect to '",
     232          233 :                   var1.name(),
     233              :                   "' and '",
     234          233 :                   var2.name(),
     235              :                   "'.");
     236          932 :   return Derivative({assembly_storage(), var1.assembly_storage(), var2.assembly_storage()},
     237          233 :                     &_sec_derivs[var1.name()][var2.name()]);
     238              : }
     239              : 
     240              : void
     241            0 : VariableBase::request_AD(const VariableBase & u)
     242              : {
     243            0 :   owner().request_AD(*this, u);
     244            0 : }
     245              : 
     246              : void
     247            4 : VariableBase::request_AD(const std::vector<const VariableBase *> & us)
     248              : {
     249           18 :   for (const auto & u : us)
     250              :   {
     251           14 :     neml_assert(u, "Cannot request AD for a null variable.");
     252           14 :     owner().request_AD(*this, *u);
     253              :   }
     254            4 : }
     255              : 
     256              : void
     257            0 : VariableBase::request_AD(const VariableBase & u1, const VariableBase & u2)
     258              : {
     259            0 :   owner().request_AD(*this, u1, u2);
     260            0 : }
     261              : 
     262              : void
     263            3 : VariableBase::request_AD(const std::vector<const VariableBase *> & u1s,
     264              :                          const std::vector<const VariableBase *> & u2s)
     265              : {
     266           15 :   for (const auto & u1 : u1s)
     267           60 :     for (const auto & u2 : u2s)
     268              :     {
     269           48 :       neml_assert(u1, "Cannot request AD for a null variable.");
     270           48 :       neml_assert(u2, "Cannot request AD for a null variable.");
     271           48 :       owner().request_AD(*this, *u1, *u2);
     272              :     }
     273            3 : }
     274              : 
     275              : void
     276        20175 : VariableBase::clear()
     277              : {
     278        20175 :   neml_assert_dbg(owning(), "Cannot clear a referencing variable '", name(), "'.");
     279        20175 :   clear_derivatives();
     280        20175 : }
     281              : 
     282              : void
     283        20781 : VariableBase::clear_derivatives()
     284              : {
     285        20781 :   _derivs.clear();
     286        20781 :   _sec_derivs.clear();
     287        20781 : }
     288              : 
     289              : void
     290           85 : VariableBase::apply_chain_rule(const DependencyResolver<Model, VariableName> & dep)
     291              : {
     292          109 :   for (const auto & [model, var] : dep.outbound_items())
     293          109 :     if (var == name())
     294              :     {
     295           85 :       _derivs = total_derivatives(dep, model, var);
     296           85 :       return;
     297              :     }
     298              : }
     299              : 
     300              : void
     301           16 : VariableBase::apply_second_order_chain_rule(const DependencyResolver<Model, VariableName> & dep)
     302              : {
     303           16 :   for (const auto & [model, var] : dep.outbound_items())
     304           16 :     if (var == name())
     305              :     {
     306           16 :       _sec_derivs = total_second_derivatives(dep, model, var);
     307           16 :       return;
     308              :     }
     309              : }
     310              : 
     311              : static void
     312         1025 : assign_or_add(Tensor & dest, const Tensor & val)
     313              : {
     314         1025 :   if (dest.defined())
     315           72 :     dest = dest + val;
     316              :   else
     317          953 :     dest = val;
     318         1025 : }
     319              : 
     320              : ValueMap
     321          418 : VariableBase::total_derivatives(const DependencyResolver<Model, VariableName> & dep,
     322              :                                 Model * model,
     323              :                                 const VariableName & yvar) const
     324              : {
     325          418 :   ValueMap derivs;
     326              : 
     327         1121 :   for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
     328              :   {
     329          703 :     if (dep.inbound_items().count({model, uvar}))
     330          479 :       assign_or_add(derivs[uvar], dy_du);
     331              :     else
     332          448 :       for (const auto & depu : dep.item_providers().at({model, uvar}))
     333          639 :         for (const auto & [xvar, du_dx] : total_derivatives(dep, depu.parent, uvar))
     334          639 :           assign_or_add(derivs[xvar], bmm(dy_du, du_dx));
     335              :   }
     336              : 
     337          418 :   return derivs;
     338            0 : }
     339              : 
     340              : DerivMap
     341           48 : VariableBase::total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
     342              :                                        Model * model,
     343              :                                        const VariableName & yvar) const
     344              : {
     345           48 :   DerivMap sec_derivs;
     346              : 
     347           83 :   for (const auto & [u1var, d2y_du1] : model->output_variable(yvar).second_derivatives())
     348          131 :     for (const auto & [u2var, d2y_du1u2] : d2y_du1)
     349              :     {
     350           96 :       if (dep.inbound_items().count({model, u1var}) && dep.inbound_items().count({model, u2var}))
     351           25 :         assign_or_add(sec_derivs[u1var][u2var], d2y_du1u2);
     352           71 :       else if (dep.inbound_items().count({model, u1var}))
     353           38 :         for (const auto & depu2 : dep.item_providers().at({model, u2var}))
     354           38 :           for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
     355           19 :             assign_or_add(sec_derivs[u1var][x2var],
     356           95 :                           Tensor(at::einsum("...ijq,...qk", {d2y_du1u2, du2_dxk}),
     357           19 :                                  utils::broadcast_batch_dim(d2y_du1u2, du2_dxk)));
     358           52 :       else if (dep.inbound_items().count({model, u2var}))
     359           36 :         for (const auto & depu1 : dep.item_providers().at({model, u1var}))
     360           36 :           for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
     361           18 :             assign_or_add(sec_derivs[x1var][u2var],
     362           90 :                           Tensor(at::einsum("...ipk,...pj", {d2y_du1u2, du1_dxj}),
     363           18 :                                  utils::broadcast_batch_dim(d2y_du1u2, du1_dxj)));
     364              :       else
     365           68 :         for (const auto & depu1 : dep.item_providers().at({model, u1var}))
     366           72 :           for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
     367           76 :             for (const auto & depu2 : dep.item_providers().at({model, u2var}))
     368           84 :               for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
     369           46 :                 assign_or_add(
     370           46 :                     sec_derivs[x1var][x2var],
     371          276 :                     Tensor(at::einsum("...ipq,...pj,...qk", {d2y_du1u2, du1_dxj, du2_dxk}),
     372           72 :                            utils::broadcast_batch_dim(d2y_du1u2, du1_dxj, du2_dxk)));
     373              :     }
     374              : 
     375          124 :   for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
     376           76 :     if (!dep.inbound_items().count({model, uvar}))
     377           64 :       for (const auto & depu : dep.item_providers().at({model, uvar}))
     378           47 :         for (const auto & [x1var, d2u_dx1] : total_second_derivatives(dep, depu.parent, uvar))
     379           38 :           for (const auto & [x2var, d2u_dx1x2] : d2u_dx1)
     380           23 :             assign_or_add(sec_derivs[x1var][x2var],
     381          115 :                           Tensor(at::einsum("...ip,...pjk", {dy_du, d2u_dx1x2}),
     382           32 :                                  utils::broadcast_batch_dim(dy_du, d2u_dx1x2)));
     383              : 
     384           48 :   return sec_derivs;
     385          106 : }
     386              : 
     387              : template <typename T>
     388              : TensorType
     389         1630 : Variable<T>::type() const
     390              : {
     391         1630 :   return TensorTypeEnum<T>::value;
     392              : }
     393              : 
     394              : template <typename T>
     395              : std::unique_ptr<VariableBase>
     396          538 : Variable<T>::clone(const VariableName & name, Model * owner) const
     397              : {
     398              :   if constexpr (std::is_same_v<T, Tensor>)
     399              :   {
     400              :     return std::move(std::make_unique<Variable<Tensor>>(
     401              :         name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes(), base_sizes()));
     402              :   }
     403              :   else
     404              :   {
     405         1614 :     return std::move(std::make_unique<Variable<T>>(
     406         2152 :         name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes()));
     407              :   }
     408              : }
     409              : 
     410              : template <typename T>
     411              : void
     412          805 : Variable<T>::ref(const VariableBase & var, bool ref_is_mutable)
     413              : {
     414          805 :   neml_assert(!_ref || ref() == var.ref(),
     415              :               "Variable '",
     416          805 :               name(),
     417              :               "' cannot reference another variable '",
     418          805 :               var.name(),
     419              :               "' after it has been assigned a reference. \nThe "
     420              :               "existing reference '",
     421          805 :               ref()->name(),
     422              :               "' was declared by model '",
     423          805 :               ref()->owner().name(),
     424              :               "'. \nThe new reference is declared by model '",
     425          805 :               var.owner().name(),
     426              :               "'.");
     427          805 :   neml_assert(&var != this, "Variable '", name(), "' cannot reference itself.");
     428          805 :   neml_assert(var.ref() != this,
     429              :               "Variable '",
     430          805 :               name(),
     431              :               "' cannot reference a variable that is referencing itself.");
     432          805 :   const auto * var_ptr = dynamic_cast<const Variable<T> *>(var.ref());
     433          805 :   neml_assert(var_ptr,
     434              :               "Variable ",
     435          805 :               name(),
     436              :               " of type ",
     437          805 :               type(),
     438              :               " failed to reference another variable named ",
     439          805 :               var.name(),
     440              :               " of type ",
     441          805 :               var.type(),
     442              :               ": Dynamic cast failure.");
     443          805 :   _ref = var_ptr;
     444          805 :   _ref_is_mutable |= ref_is_mutable;
     445          805 : }
     446              : 
     447              : template <typename T>
     448              : void
     449        20179 : Variable<T>::zero(const TensorOptions & options)
     450              : {
     451        20179 :   if (owning())
     452              :   {
     453              :     if constexpr (std::is_same_v<T, Tensor>)
     454              :       _value = T::zeros(list_sizes(), base_sizes(), options);
     455              :     else
     456        20179 :       _value = T::zeros(list_sizes(), options);
     457              :   }
     458              :   else
     459              :   {
     460            0 :     neml_assert_dbg(_ref_is_mutable,
     461              :                     "Model '",
     462            0 :                     owner().name(),
     463              :                     "' is trying to zero a variable '",
     464            0 :                     name(),
     465              :                     "' declared by model '",
     466            0 :                     ref()->owner().name(),
     467              :                     "' , but the referenced variable is not mutable.");
     468              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     469            0 :     const_cast<VariableBase *>(ref())->zero(options);
     470              :   }
     471        20179 : }
     472              : 
     473              : template <typename T>
     474              : void
     475        20513 : Variable<T>::set(const Tensor & val)
     476              : {
     477        20513 :   if (owning())
     478        15849 :     _value = T(val.base_reshape(utils::add_shapes(list_sizes(), base_sizes())),
     479        31698 :                utils::add_traceable_shapes(val.batch_sizes(), list_sizes()));
     480              :   else
     481              :   {
     482         4664 :     neml_assert_dbg(_ref_is_mutable,
     483              :                     "Model '",
     484         4664 :                     owner().name(),
     485              :                     "' is trying to assign value to a variable '",
     486         4664 :                     name(),
     487              :                     "' declared by model '",
     488         4664 :                     ref()->owner().name(),
     489              :                     "' , but the referenced variable is not mutable.");
     490              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     491         4664 :     const_cast<VariableBase *>(ref())->set(val);
     492              :   }
     493        20513 : }
     494              : 
     495              : template <typename T>
     496              : void
     497        14389 : Variable<T>::set(const ATensor & val, bool force)
     498              : {
     499        14389 :   if (owning())
     500              :   {
     501              :     if constexpr (std::is_same_v<T, Tensor>)
     502              :       _value = T(val, val.dim() - base_dim());
     503              :     else
     504         8898 :       _value = T(val);
     505              :   }
     506              :   else
     507              :   {
     508         5491 :     neml_assert_dbg(_ref_is_mutable || force,
     509              :                     "Model '",
     510         5491 :                     owner().name(),
     511              :                     "' is trying to assign value to a variable '",
     512         5491 :                     name(),
     513              :                     "' declared by model '",
     514         5491 :                     ref()->owner().name(),
     515              :                     "' , but the referenced variable is not mutable.");
     516              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     517         5491 :     const_cast<VariableBase *>(ref())->set(val);
     518              :   }
     519        14389 : }
     520              : 
     521              : template <typename T>
     522              : Tensor
     523            0 : Variable<T>::get() const
     524              : {
     525            0 :   return tensor().base_flatten();
     526              : }
     527              : 
     528              : template <typename T>
     529              : Tensor
     530      1092351 : Variable<T>::tensor() const
     531              : {
     532      1092351 :   if (owning())
     533              :   {
     534       674697 :     neml_assert_dbg(_value.defined(), "Variable '", name(), "' has undefined value.");
     535       674697 :     auto batch_sizes = _value.batch_sizes().slice(0, _value.batch_dim() - list_dim());
     536       674697 :     return Tensor(_value, batch_sizes);
     537       674697 :   }
     538              : 
     539       417654 :   return ref()->tensor();
     540              : }
     541              : 
     542              : template <typename T>
     543              : void
     544            6 : Variable<T>::requires_grad_(bool req)
     545              : {
     546            6 :   if (owning())
     547            6 :     _value.requires_grad_(req);
     548              :   else
     549              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     550            0 :     const_cast<VariableBase *>(ref())->requires_grad_(req);
     551            6 : }
     552              : 
     553              : template <typename T>
     554              : void
     555          684 : Variable<T>::operator=(const Tensor & val)
     556              : {
     557          684 :   if (owning())
     558          684 :     _value = T(val);
     559              :   else
     560              :   {
     561            0 :     neml_assert_dbg(_ref_is_mutable,
     562              :                     "Model '",
     563            0 :                     owner().name(),
     564              :                     "' is trying to assign value to a variable '",
     565            0 :                     name(),
     566              :                     "' declared by model '",
     567            0 :                     ref()->owner().name(),
     568              :                     "' , but the referenced variable is not mutable.");
     569              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     570            0 :     *const_cast<VariableBase *>(ref()) = val;
     571              :   }
     572          684 : }
     573              : 
     574              : template <typename T>
     575              : void
     576        20175 : Variable<T>::clear()
     577              : {
     578        20175 :   if (owning())
     579              :   {
     580        20175 :     VariableBase::clear();
     581        20175 :     _value = T();
     582              :   }
     583              :   else
     584              :   {
     585            0 :     neml_assert_dbg(_ref_is_mutable,
     586              :                     "Model '",
     587            0 :                     owner().name(),
     588              :                     "' is trying to clear a variable '",
     589            0 :                     name(),
     590              :                     "' declared by model '",
     591            0 :                     ref()->owner().name(),
     592              :                     "' , but the referenced variable is not mutable.");
     593              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     594            0 :     const_cast<VariableBase *>(ref())->clear();
     595              :   }
     596        20175 : }
     597              : 
     598              : #define INSTANTIATE_VARIABLE(T) template class Variable<T>
     599              : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_VARIABLE);
     600              : 
     601              : Derivative &
     602         1049 : Derivative::operator=(const Tensor & val)
     603              : {
     604         1049 :   if (!_deriv->defined())
     605         1047 :     *_deriv = val.base_reshape(_base_sizes);
     606              :   else
     607            2 :     *_deriv = *_deriv + val.base_reshape(_base_sizes);
     608         1049 :   return *this;
     609              : }
     610              : }
        

Generated by: LCOV version 2.0-1