LCOV - code coverage report
Current view: top level - models - Variable.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 81.8 % 247 202
Test Date: 2025-06-29 01:25:44 Functions: 32.4 % 284 92

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/models/Variable.h"
      26              : #include "neml2/models/Model.h"
      27              : #include "neml2/models/DependencyResolver.h"
      28              : #include "neml2/models/map_types.h"
      29              : #include "neml2/tensors/tensors.h"
      30              : #include "neml2/tensors/assertions.h"
      31              : #include "neml2/tensors/functions/bmm.h"
      32              : #include "neml2/jit/utils.h"
      33              : #include "neml2/jit/TraceableTensorShape.h"
      34              : 
      35              : namespace neml2
      36              : {
      37         1579 : VariableBase::VariableBase(VariableName name_in, Model * owner, TensorShapeRef list_shape)
      38         1579 :   : _name(std::move(name_in)),
      39         1579 :     _owner(owner),
      40         3158 :     _list_sizes(list_shape)
      41              : {
      42         1579 : }
      43              : 
      44              : const Model &
      45        11906 : VariableBase::owner() const
      46              : {
      47        11906 :   neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
      48        11906 :   return *_owner;
      49              : }
      50              : 
      51              : Model &
      52        10043 : VariableBase::owner()
      53              : {
      54        10043 :   neml_assert_dbg(_owner, "Owner of variable '", name(), "' has not been defined.");
      55        10043 :   return *_owner;
      56              : }
      57              : 
      58              : bool
      59          826 : VariableBase::is_state() const
      60              : {
      61          826 :   return _name.is_state();
      62              : }
      63              : 
      64              : bool
      65          163 : VariableBase::is_old_state() const
      66              : {
      67          163 :   return _name.is_old_state();
      68              : }
      69              : 
      70              : bool
      71          167 : VariableBase::is_force() const
      72              : {
      73          167 :   return _name.is_force();
      74              : }
      75              : 
      76              : bool
      77           47 : VariableBase::is_old_force() const
      78              : {
      79           47 :   return _name.is_old_force();
      80              : }
      81              : 
      82              : bool
      83          134 : VariableBase::is_residual() const
      84              : {
      85          134 :   return _name.is_residual();
      86              : }
      87              : 
      88              : bool
      89           88 : VariableBase::is_parameter() const
      90              : {
      91           88 :   return _name.is_parameter();
      92              : }
      93              : 
      94              : bool
      95          388 : VariableBase::is_solve_dependent() const
      96              : {
      97          388 :   return is_state() || is_residual() || is_parameter();
      98              : }
      99              : 
     100              : bool
     101         1427 : VariableBase::is_dependent() const
     102              : {
     103         1427 :   return !currently_solving_nonlinear_system() || is_solve_dependent();
     104              : }
     105              : 
     106              : TensorOptions
     107          304 : VariableBase::options() const
     108              : {
     109          304 :   return tensor().options();
     110              : }
     111              : 
     112              : Dtype
     113           46 : VariableBase::scalar_type() const
     114              : {
     115           46 :   return tensor().scalar_type();
     116              : }
     117              : 
     118              : Device
     119            0 : VariableBase::device() const
     120              : {
     121            0 :   return tensor().device();
     122              : }
     123              : 
     124              : Size
     125            0 : VariableBase::dim() const
     126              : {
     127            0 :   return tensor().dim();
     128              : }
     129              : 
     130              : TensorShapeRef
     131            0 : VariableBase::sizes() const
     132              : {
     133            0 :   return tensor().sizes();
     134              : }
     135              : 
     136              : Size
     137            0 : VariableBase::size(Size dim) const
     138              : {
     139            0 :   return tensor().size(dim);
     140              : }
     141              : 
     142              : bool
     143            0 : VariableBase::batched() const
     144              : {
     145            0 :   return tensor().batched();
     146              : }
     147              : 
     148              : Size
     149        35906 : VariableBase::batch_dim() const
     150              : {
     151        35906 :   return tensor().batch_dim();
     152              : }
     153              : 
     154              : Size
     155       632828 : VariableBase::list_dim() const
     156              : {
     157       632828 :   return Size(list_sizes().size());
     158              : }
     159              : 
     160              : Size
     161            0 : VariableBase::base_dim() const
     162              : {
     163            0 :   return Size(base_sizes().size());
     164              : }
     165              : 
     166              : TraceableTensorShape
     167            1 : VariableBase::batch_sizes() const
     168              : {
     169            1 :   return tensor().batch_sizes();
     170              : }
     171              : 
     172              : TensorShapeRef
     173       684232 : VariableBase::list_sizes() const
     174              : {
     175       684232 :   return _list_sizes;
     176              : }
     177              : 
     178              : TraceableSize
     179            0 : VariableBase::batch_size(Size dim) const
     180              : {
     181            0 :   return tensor().batch_size(dim);
     182              : }
     183              : 
     184              : Size
     185            0 : VariableBase::base_size(Size dim) const
     186              : {
     187            0 :   return base_sizes()[dim];
     188              : }
     189              : 
     190              : Size
     191            1 : VariableBase::list_size(Size dim) const
     192              : {
     193            1 :   return list_sizes()[dim];
     194              : }
     195              : 
     196              : Size
     197            0 : VariableBase::base_storage() const
     198              : {
     199            0 :   return utils::storage_size(base_sizes());
     200              : }
     201              : 
     202              : Size
     203         2585 : VariableBase::assembly_storage() const
     204              : {
     205         2585 :   return utils::storage_size(list_sizes()) * utils::storage_size(base_sizes());
     206              : }
     207              : 
     208              : bool
     209            0 : VariableBase::requires_grad() const
     210              : {
     211            0 :   return tensor().requires_grad();
     212              : }
     213              : 
     214              : Derivative
     215          745 : VariableBase::d(const VariableBase & var)
     216              : {
     217          745 :   neml_assert_dbg(owning(),
     218              :                   "Cannot assign derivative to a referencing variable '",
     219          745 :                   name(),
     220              :                   "' with respect to '",
     221          745 :                   var.name(),
     222              :                   "'.");
     223          745 :   return Derivative({assembly_storage(), var.assembly_storage()}, &_derivs[var.name()]);
     224              : }
     225              : 
     226              : Derivative
     227          206 : VariableBase::d(const VariableBase & var1, const VariableBase & var2)
     228              : {
     229          206 :   neml_assert_dbg(owning(),
     230              :                   "Cannot assign second derivative to a referencing variable '",
     231          206 :                   name(),
     232              :                   "' with respect to '",
     233          206 :                   var1.name(),
     234              :                   "' and '",
     235          206 :                   var2.name(),
     236              :                   "'.");
     237          824 :   return Derivative({assembly_storage(), var1.assembly_storage(), var2.assembly_storage()},
     238          206 :                     &_sec_derivs[var1.name()][var2.name()]);
     239              : }
     240              : 
     241              : void
     242            0 : VariableBase::request_AD(const VariableBase & u)
     243              : {
     244            0 :   owner().request_AD(*this, u);
     245            0 : }
     246              : 
     247              : void
     248            4 : VariableBase::request_AD(const std::vector<const VariableBase *> & us)
     249              : {
     250           18 :   for (const auto & u : us)
     251              :   {
     252           14 :     neml_assert(u, "Cannot request AD for a null variable.");
     253           14 :     owner().request_AD(*this, *u);
     254              :   }
     255            4 : }
     256              : 
     257              : void
     258            0 : VariableBase::request_AD(const VariableBase & u1, const VariableBase & u2)
     259              : {
     260            0 :   owner().request_AD(*this, u1, u2);
     261            0 : }
     262              : 
     263              : void
     264            3 : VariableBase::request_AD(const std::vector<const VariableBase *> & u1s,
     265              :                          const std::vector<const VariableBase *> & u2s)
     266              : {
     267           15 :   for (const auto & u1 : u1s)
     268           60 :     for (const auto & u2 : u2s)
     269              :     {
     270           48 :       neml_assert(u1, "Cannot request AD for a null variable.");
     271           48 :       neml_assert(u2, "Cannot request AD for a null variable.");
     272           48 :       owner().request_AD(*this, *u1, *u2);
     273              :     }
     274            3 : }
     275              : 
     276              : void
     277        18315 : VariableBase::clear()
     278              : {
     279        18315 :   neml_assert_dbg(owning(), "Cannot clear a referencing variable '", name(), "'.");
     280        18315 :   _derivs.clear();
     281        18315 :   _sec_derivs.clear();
     282        18315 : }
     283              : 
     284              : void
     285           75 : VariableBase::apply_chain_rule(const DependencyResolver<Model, VariableName> & dep)
     286              : {
     287           96 :   for (const auto & [model, var] : dep.outbound_items())
     288           96 :     if (var == name())
     289              :     {
     290           75 :       _derivs = total_derivatives(dep, model, var);
     291           75 :       return;
     292              :     }
     293              : }
     294              : 
     295              : void
     296           16 : VariableBase::apply_second_order_chain_rule(const DependencyResolver<Model, VariableName> & dep)
     297              : {
     298           16 :   for (const auto & [model, var] : dep.outbound_items())
     299           16 :     if (var == name())
     300              :     {
     301           16 :       _sec_derivs = total_second_derivatives(dep, model, var);
     302           16 :       return;
     303              :     }
     304              : }
     305              : 
     306              : void
     307          975 : assign_or_add(Tensor & dest, const Tensor & val)
     308              : {
     309          975 :   if (dest.defined())
     310           69 :     dest = dest + val;
     311              :   else
     312          906 :     dest = val;
     313          975 : }
     314              : 
     315              : ValueMap
     316          397 : VariableBase::total_derivatives(const DependencyResolver<Model, VariableName> & dep,
     317              :                                 Model * model,
     318              :                                 const VariableName & yvar) const
     319              : {
     320          397 :   ValueMap derivs;
     321              : 
     322         1056 :   for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
     323              :   {
     324          659 :     if (dep.inbound_items().count({model, uvar}))
     325          446 :       assign_or_add(derivs[uvar], dy_du);
     326              :     else
     327          426 :       for (const auto & depu : dep.item_providers().at({model, uvar}))
     328          611 :         for (const auto & [xvar, du_dx] : total_derivatives(dep, depu.parent, uvar))
     329          611 :           assign_or_add(derivs[xvar], bmm(dy_du, du_dx));
     330              :   }
     331              : 
     332          397 :   return derivs;
     333            0 : }
     334              : 
     335              : DerivMap
     336           48 : VariableBase::total_second_derivatives(const DependencyResolver<Model, VariableName> & dep,
     337              :                                        Model * model,
     338              :                                        const VariableName & yvar) const
     339              : {
     340           48 :   DerivMap sec_derivs;
     341              : 
     342           83 :   for (const auto & [u1var, d2y_du1] : model->output_variable(yvar).second_derivatives())
     343          131 :     for (const auto & [u2var, d2y_du1u2] : d2y_du1)
     344              :     {
     345           96 :       if (dep.inbound_items().count({model, u1var}) && dep.inbound_items().count({model, u2var}))
     346           25 :         assign_or_add(sec_derivs[u1var][u2var], d2y_du1u2);
     347           71 :       else if (dep.inbound_items().count({model, u1var}))
     348           38 :         for (const auto & depu2 : dep.item_providers().at({model, u2var}))
     349           38 :           for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
     350           19 :             assign_or_add(sec_derivs[u1var][x2var],
     351           95 :                           Tensor(at::einsum("...ijq,...qk", {d2y_du1u2, du2_dxk}),
     352           19 :                                  utils::broadcast_batch_dim(d2y_du1u2, du2_dxk)));
     353           52 :       else if (dep.inbound_items().count({model, u2var}))
     354           36 :         for (const auto & depu1 : dep.item_providers().at({model, u1var}))
     355           36 :           for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
     356           18 :             assign_or_add(sec_derivs[x1var][u2var],
     357           90 :                           Tensor(at::einsum("...ipk,...pj", {d2y_du1u2, du1_dxj}),
     358           18 :                                  utils::broadcast_batch_dim(d2y_du1u2, du1_dxj)));
     359              :       else
     360           68 :         for (const auto & depu1 : dep.item_providers().at({model, u1var}))
     361           72 :           for (const auto & [x1var, du1_dxj] : total_derivatives(dep, depu1.parent, u1var))
     362           76 :             for (const auto & depu2 : dep.item_providers().at({model, u2var}))
     363           84 :               for (const auto & [x2var, du2_dxk] : total_derivatives(dep, depu2.parent, u2var))
     364           46 :                 assign_or_add(
     365           46 :                     sec_derivs[x1var][x2var],
     366          276 :                     Tensor(at::einsum("...ipq,...pj,...qk", {d2y_du1u2, du1_dxj, du2_dxk}),
     367           72 :                            utils::broadcast_batch_dim(d2y_du1u2, du1_dxj, du2_dxk)));
     368              :     }
     369              : 
     370          124 :   for (const auto & [uvar, dy_du] : model->output_variable(yvar).derivatives())
     371           76 :     if (!dep.inbound_items().count({model, uvar}))
     372           64 :       for (const auto & depu : dep.item_providers().at({model, uvar}))
     373           47 :         for (const auto & [x1var, d2u_dx1] : total_second_derivatives(dep, depu.parent, uvar))
     374           38 :           for (const auto & [x2var, d2u_dx1x2] : d2u_dx1)
     375           23 :             assign_or_add(sec_derivs[x1var][x2var],
     376          115 :                           Tensor(at::einsum("...ip,...pjk", {dy_du, d2u_dx1x2}),
     377           32 :                                  utils::broadcast_batch_dim(dy_du, d2u_dx1x2)));
     378              : 
     379           48 :   return sec_derivs;
     380          106 : }
     381              : 
     382              : template <typename T>
     383              : TensorType
     384         1468 : Variable<T>::type() const
     385              : {
     386         1468 :   return TensorTypeEnum<T>::value;
     387              : }
     388              : 
     389              : template <typename T>
     390              : std::unique_ptr<VariableBase>
     391          477 : Variable<T>::clone(const VariableName & name, Model * owner) const
     392              : {
     393              :   if constexpr (std::is_same_v<T, Tensor>)
     394              :   {
     395              :     return std::move(std::make_unique<Variable<Tensor>>(
     396              :         name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes(), base_sizes()));
     397              :   }
     398              :   else
     399              :   {
     400         1431 :     return std::move(std::make_unique<Variable<T>>(
     401         1908 :         name.empty() ? this->name() : name, owner ? owner : _owner, list_sizes()));
     402              :   }
     403              : }
     404              : 
     405              : template <typename T>
     406              : void
     407          724 : Variable<T>::ref(const VariableBase & var, bool ref_is_mutable)
     408              : {
     409          724 :   neml_assert(!_ref || ref() == var.ref(),
     410              :               "Variable '",
     411          724 :               name(),
     412              :               "' cannot reference another variable '",
     413          724 :               var.name(),
     414              :               "' after it has been assigned a reference. \nThe "
     415              :               "existing reference '",
     416          724 :               ref()->name(),
     417              :               "' was declared by model '",
     418          724 :               ref()->owner().name(),
     419              :               "'. \nThe new reference is declared by model '",
     420          724 :               var.owner().name(),
     421              :               "'.");
     422          724 :   neml_assert(&var != this, "Variable '", name(), "' cannot reference itself.");
     423          724 :   neml_assert(var.ref() != this,
     424              :               "Variable '",
     425          724 :               name(),
     426              :               "' cannot reference a variable that is referencing itself.");
     427          724 :   const auto * var_ptr = dynamic_cast<const Variable<T> *>(var.ref());
     428          724 :   neml_assert(var_ptr,
     429              :               "Variable ",
     430          724 :               name(),
     431              :               " of type ",
     432          724 :               type(),
     433              :               " failed to reference another variable named ",
     434          724 :               var.name(),
     435              :               " of type ",
     436          724 :               var.type(),
     437              :               ": Dynamic cast failure.");
     438          724 :   _ref = var_ptr;
     439          724 :   _ref_is_mutable |= ref_is_mutable;
     440          724 : }
     441              : 
     442              : template <typename T>
     443              : void
     444        18315 : Variable<T>::zero(const TensorOptions & options)
     445              : {
     446        18315 :   if (owning())
     447              :   {
     448              :     if constexpr (std::is_same_v<T, Tensor>)
     449              :       _value = T::zeros(list_sizes(), base_sizes(), options);
     450              :     else
     451        18315 :       _value = T::zeros(list_sizes(), options);
     452              :   }
     453              :   else
     454              :   {
     455            0 :     neml_assert_dbg(_ref_is_mutable,
     456              :                     "Model '",
     457            0 :                     owner().name(),
     458              :                     "' is trying to zero a variable '",
     459            0 :                     name(),
     460              :                     "' declared by model '",
     461            0 :                     ref()->owner().name(),
     462              :                     "' , but the referenced variable is not mutable.");
     463              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     464            0 :     const_cast<VariableBase *>(ref())->zero(options);
     465              :   }
     466        18315 : }
     467              : 
     468              : template <typename T>
     469              : void
     470        19643 : Variable<T>::set(const Tensor & val)
     471              : {
     472        19643 :   if (owning())
     473        15006 :     _value = T(val.base_reshape(utils::add_shapes(list_sizes(), base_sizes())),
     474        30012 :                utils::add_traceable_shapes(val.batch_sizes(), list_sizes()));
     475              :   else
     476              :   {
     477         4637 :     neml_assert_dbg(_ref_is_mutable,
     478              :                     "Model '",
     479         4637 :                     owner().name(),
     480              :                     "' is trying to assign value to a variable '",
     481         4637 :                     name(),
     482              :                     "' declared by model '",
     483         4637 :                     ref()->owner().name(),
     484              :                     "' , but the referenced variable is not mutable.");
     485              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     486         4637 :     const_cast<VariableBase *>(ref())->set(val);
     487              :   }
     488        19643 : }
     489              : 
     490              : template <typename T>
     491              : void
     492        13558 : Variable<T>::set(const ATensor & val, bool force)
     493              : {
     494        13558 :   if (owning())
     495              :   {
     496              :     if constexpr (std::is_same_v<T, Tensor>)
     497              :       _value = T(val, val.dim() - base_dim());
     498              :     else
     499         8214 :       _value = T(val);
     500              :   }
     501              :   else
     502              :   {
     503         5344 :     neml_assert_dbg(_ref_is_mutable || force,
     504              :                     "Model '",
     505         5344 :                     owner().name(),
     506              :                     "' is trying to assign value to a variable '",
     507         5344 :                     name(),
     508              :                     "' declared by model '",
     509         5344 :                     ref()->owner().name(),
     510              :                     "' , but the referenced variable is not mutable.");
     511              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     512         5344 :     const_cast<VariableBase *>(ref())->set(val);
     513              :   }
     514        13558 : }
     515              : 
     516              : template <typename T>
     517              : Tensor
     518            0 : Variable<T>::get() const
     519              : {
     520            0 :   return tensor().base_flatten();
     521              : }
     522              : 
     523              : template <typename T>
     524              : Tensor
     525      1045983 : Variable<T>::tensor() const
     526              : {
     527      1045983 :   if (owning())
     528              :   {
     529       632828 :     neml_assert_dbg(_value.defined(), "Variable '", name(), "' has undefined value.");
     530       632828 :     auto batch_sizes = _value.batch_sizes().slice(0, _value.batch_dim() - list_dim());
     531       632828 :     return Tensor(_value, batch_sizes);
     532       632828 :   }
     533              : 
     534       413155 :   return ref()->tensor();
     535              : }
     536              : 
     537              : template <typename T>
     538              : void
     539            6 : Variable<T>::requires_grad_(bool req)
     540              : {
     541            6 :   if (owning())
     542            6 :     _value.requires_grad_(req);
     543              :   else
     544              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     545            0 :     const_cast<VariableBase *>(ref())->requires_grad_(req);
     546            6 : }
     547              : 
     548              : template <typename T>
     549              : void
     550          618 : Variable<T>::operator=(const Tensor & val)
     551              : {
     552          618 :   if (owning())
     553          618 :     _value = T(val);
     554              :   else
     555              :   {
     556            0 :     neml_assert_dbg(_ref_is_mutable,
     557              :                     "Model '",
     558            0 :                     owner().name(),
     559              :                     "' is trying to assign value to a variable '",
     560            0 :                     name(),
     561              :                     "' declared by model '",
     562            0 :                     ref()->owner().name(),
     563              :                     "' , but the referenced variable is not mutable.");
     564              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     565            0 :     *const_cast<VariableBase *>(ref()) = val;
     566              :   }
     567          618 : }
     568              : 
     569              : template <typename T>
     570              : void
     571        18315 : Variable<T>::clear()
     572              : {
     573        18315 :   if (owning())
     574              :   {
     575        18315 :     VariableBase::clear();
     576        18315 :     _value = T();
     577              :   }
     578              :   else
     579              :   {
     580            0 :     neml_assert_dbg(_ref_is_mutable,
     581              :                     "Model '",
     582            0 :                     owner().name(),
     583              :                     "' is trying to clear a variable '",
     584            0 :                     name(),
     585              :                     "' declared by model '",
     586            0 :                     ref()->owner().name(),
     587              :                     "' , but the referenced variable is not mutable.");
     588              :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     589            0 :     const_cast<VariableBase *>(ref())->clear();
     590              :   }
     591        18315 : }
     592              : 
     593              : #define INSTANTIATE_VARIABLE(T) template class Variable<T>
     594              : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_VARIABLE);
     595              : 
     596              : Derivative &
     597          951 : Derivative::operator=(const Tensor & val)
     598              : {
     599          951 :   *_deriv = val.base_reshape(_base_sizes);
     600          951 :   return *this;
     601              : }
     602              : }
        

Generated by: LCOV version 2.0-1