LCOV - code coverage report
Current view: top level - base - LabeledAxis.cxx (source / functions) Coverage Total Hit
Test: coverage.info Lines: 91.4 % 244 223
Test Date: 2025-10-02 16:03:03 Functions: 61.8 % 55 34

            Line data    Source code
       1              : // Copyright 2024, UChicago Argonne, LLC
       2              : // All Rights Reserved
       3              : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
       4              : // By: Argonne National Laboratory
       5              : // OPEN SOURCE LICENSE (MIT)
       6              : //
       7              : // Permission is hereby granted, free of charge, to any person obtaining a copy
       8              : // of this software and associated documentation files (the "Software"), to deal
       9              : // in the Software without restriction, including without limitation the rights
      10              : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      11              : // copies of the Software, and to permit persons to whom the Software is
      12              : // furnished to do so, subject to the following conditions:
      13              : //
      14              : // The above copyright notice and this permission notice shall be included in
      15              : // all copies or substantial portions of the Software.
      16              : //
      17              : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      18              : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      19              : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
      20              : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
      21              : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
      22              : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
      23              : // THE SOFTWARE.
      24              : 
      25              : #include "neml2/base/LabeledAxis.h"
      26              : #include "neml2/tensors/shape_utils.h"
      27              : #include "neml2/tensors/tensors.h"
      28              : #include "neml2/misc/assertions.h"
      29              : 
      30              : namespace neml2
      31              : {
      32         2368 : LabeledAxis::LabeledAxis(LabeledAxisAccessor prefix)
      33         2368 :   : _prefix(std::move(prefix))
      34              : {
      35         2368 : }
      36              : 
      37              : LabeledAxisAccessor
      38        24809 : LabeledAxis::qualify(const LabeledAxisAccessor & accessor) const
      39              : {
      40        24809 :   return accessor.prepend(_prefix);
      41              : }
      42              : 
      43              : LabeledAxis &
      44         2368 : LabeledAxis::add_subaxis(const std::string & name)
      45              : {
      46         2368 :   neml_assert(!_setup, "Cannot modify a sub-axis after the axis has been set up.");
      47         2368 :   neml_assert(
      48         2368 :       _variables.count(name) == 0, "Cannot add a subaxis with the same name as a variable: ", name);
      49         2368 :   auto [subaxis, success] =
      50         2368 :       _subaxes.emplace(name, std::make_shared<LabeledAxis>(_prefix.append(name)));
      51         2368 :   if (success)
      52         1701 :     cache_reserved_subaxis(name);
      53         4736 :   return *(subaxis->second);
      54              : }
      55              : 
      56              : void
      57         4365 : LabeledAxis::add_variable(const LabeledAxisAccessor & name, Size sz)
      58              : {
      59         4365 :   neml_assert(!_setup, "Cannot modify a sub-axis after the axis has been set up.");
      60         4365 :   neml_assert(!name.empty(), "Cannot add a variable with empty name.");
      61              : 
      62         4365 :   if (name.size() == 1)
      63              :   {
      64         2123 :     neml_assert(_variables.count(name[0]) == 0 && _subaxes.count(name[0]) == 0,
      65              :                 "Cannot add a variable with the same name as an existing variable or a sub-axis: '",
      66         2123 :                 name[0],
      67              :                 "'");
      68         2123 :     _variables.emplace(name[0], sz);
      69              :   }
      70              :   else
      71         2242 :     add_subaxis(name[0]).add_variable(name.slice(1), sz);
      72         4365 : }
      73              : 
      74              : template <typename T>
      75              : void
      76          317 : LabeledAxis::add_variable(const LabeledAxisAccessor & name)
      77              : {
      78          317 :   auto sz = utils::storage_size(T::const_base_sizes);
      79          317 :   add_variable(name, sz);
      80          317 : }
      81              : #define INSTANTIATE_ADD_VARIABLE(T)                                                                \
      82              :   template void LabeledAxis::add_variable<T>(const LabeledAxisAccessor &)
      83              : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_ADD_VARIABLE);
      84              : 
      85              : void
      86         2626 : LabeledAxis::setup_layout()
      87              : {
      88              :   // Clear internal data that may have been constructed from previous setup_layout calls
      89         2626 :   _size = 0;
      90              : 
      91         2626 :   _variable_to_id_map.clear();
      92         2626 :   _id_to_variable_map.clear();
      93         2626 :   _id_to_variable_size_map.clear();
      94         2626 :   _id_to_variable_slice_map.clear();
      95              : 
      96         2626 :   _sorted_subaxes.clear();
      97         2626 :   _subaxis_to_id_map.clear();
      98         2626 :   _id_to_subaxis_map.clear();
      99         2626 :   _id_to_subaxis_size_map.clear();
     100         2626 :   _id_to_subaxis_slice_map.clear();
     101              : 
     102              :   // Set up variable assembly IDs and slicing indices
     103         4581 :   for (auto & [name, sz] : _variables)
     104              :   {
     105         1955 :     _variable_to_id_map.emplace(name, _variable_to_id_map.size());
     106         1955 :     _id_to_variable_map.emplace_back(name);
     107         1955 :     _id_to_variable_size_map.push_back(sz);
     108         1955 :     _id_to_variable_slice_map.emplace_back(_size, _size + sz);
     109         1955 :     _size += sz;
     110              :   }
     111              : 
     112              :   // Set up subaxes
     113         4264 :   for (auto & [name, axis] : _subaxes)
     114              :   {
     115         1638 :     axis->setup_layout();
     116         1638 :     auto sz = axis->size();
     117         1638 :     _sorted_subaxes.push_back(axis.get());
     118         1638 :     _subaxis_to_id_map.emplace(name, _subaxis_to_id_map.size());
     119         1638 :     _id_to_subaxis_map.push_back(name);
     120         1638 :     _id_to_subaxis_size_map.push_back(sz);
     121         1638 :     _id_to_subaxis_slice_map.emplace_back(_size, _size + sz);
     122         1638 :     _size += sz;
     123              : 
     124              :     // Merge variable maps
     125         4090 :     for (const auto & var_name : axis->_id_to_variable_map)
     126              :     {
     127         2452 :       auto var_id = axis->_variable_to_id_map.at(var_name);
     128         2452 :       auto full_name = var_name.prepend(name);
     129         2452 :       _variable_to_id_map.emplace(full_name, _variable_to_id_map.size());
     130         2452 :       _id_to_variable_map.push_back(full_name);
     131         2452 :       _id_to_variable_size_map.push_back(axis->_id_to_variable_size_map[var_id]);
     132              : 
     133              :       // Slice is relative to the sub-axis, so we need to shift it
     134         2452 :       const auto & slice = axis->_id_to_variable_slice_map[var_id];
     135         2452 :       auto offset = _id_to_subaxis_slice_map.back().first;
     136         2452 :       auto new_slice = std::pair<Size, Size>{slice.first + offset, slice.second + offset};
     137         2452 :       _id_to_variable_slice_map.push_back(new_slice);
     138         2452 :     }
     139              :   }
     140              : 
     141              :   // Finished set up
     142         2626 :   _setup = true;
     143         2626 : }
     144              : 
     145              : Size
     146         1675 : LabeledAxis::size() const
     147              : {
     148              :   // If the axis has been set up, return the cached size
     149         1675 :   if (_setup)
     150         1652 :     return _size;
     151              : 
     152              :   // Otherwise, calculate the size
     153           23 :   Size sz = 0;
     154           69 :   for (const auto & [name, var_sz] : _variables)
     155           46 :     sz += var_sz;
     156           32 :   for (const auto & [name, axis] : _subaxes)
     157            9 :     sz += axis->size();
     158           23 :   return sz;
     159              : }
     160              : 
     161              : Size
     162           78 : LabeledAxis::size(const LabeledAxisAccessor & name) const
     163              : {
     164           78 :   neml_assert(!name.empty(), "Cannot get the size of an item with an empty name.");
     165              : 
     166              :   // If the name has length 1, it must be a variable or a local sub-axis
     167           78 :   if (name.size() == 1)
     168              :   {
     169           46 :     const auto var = _variables.find(name[0]);
     170           46 :     if (var != _variables.end())
     171           36 :       return var->second;
     172              : 
     173           10 :     const auto subaxis = _subaxes.find(name[0]);
     174           10 :     neml_assert(subaxis != _subaxes.end(),
     175              :                 "Item named '",
     176              :                 name,
     177              :                 "' is neither a variable nor a local sub-axis on axis:\n",
     178              :                 *this);
     179           10 :     return subaxis->second->size();
     180              :   }
     181              : 
     182              :   // Otherwise, the item must be on a sub-axis
     183           32 :   const auto subaxis = _subaxes.find(name[0]);
     184           32 :   neml_assert(subaxis != _subaxes.end(),
     185              :               "Item named '",
     186              :               name,
     187              :               "' is neither a variable nor a sub-axis on axis:\n",
     188              :               *this);
     189           32 :   return subaxis->second->size(name.slice(1));
     190              : }
     191              : 
     192              : indexing::Slice
     193           27 : LabeledAxis::slice(const LabeledAxisAccessor & name) const
     194              : {
     195           27 :   ensure_setup_dbg();
     196           27 :   neml_assert(!name.empty(), "Cannot get the slice of an item with an empty name.");
     197              : 
     198              :   // If the name is a variable
     199           27 :   if (has_variable(name))
     200              :   {
     201           22 :     auto s = variable_slice(name);
     202           22 :     return {s.first, s.second};
     203              :   }
     204              : 
     205              :   // Otherwise, the name must be a sub-axis
     206            5 :   neml_assert_dbg(has_subaxis(name[0]),
     207              :                   "Item named '",
     208              :                   name,
     209              :                   "' is neither a variable nor a sub-axis on axis:\n",
     210              :                   *this);
     211            5 :   auto s = subaxis_slice(name);
     212            5 :   return {s.first, s.second};
     213              : }
     214              : 
     215              : std::size_t
     216          478 : LabeledAxis::nvariable() const
     217              : {
     218              :   // If axis has been set up, return the cached number of variables
     219          478 :   if (_setup)
     220          469 :     return _id_to_variable_map.size();
     221              : 
     222              :   // Otherwise, calculate the number of variables
     223            9 :   std::size_t nvar = _variables.size();
     224           14 :   for (const auto & [name, axis] : _subaxes)
     225            5 :     nvar += axis->nvariable();
     226            9 :   return nvar;
     227              : }
     228              : 
     229              : bool
     230        16950 : LabeledAxis::has_variable(const LabeledAxisAccessor & name) const
     231              : {
     232        16950 :   neml_assert(!name.empty(), "Variable name cannot be empty.");
     233              : 
     234              :   // If axis has been set up, return the cached existence
     235        16950 :   if (_setup)
     236        15971 :     return std::find(_id_to_variable_map.begin(), _id_to_variable_map.end(), name) !=
     237        31942 :            _id_to_variable_map.end();
     238              : 
     239              :   // Otherwise, check the existence of the variable
     240          979 :   if (name.size() == 1)
     241          361 :     return _variables.find(name[0]) != _variables.end();
     242              : 
     243          618 :   const auto subaxis = _subaxes.find(name[0]);
     244          618 :   return subaxis != _subaxes.end() && subaxis->second->has_variable(name.slice(1));
     245              : }
     246              : 
     247              : std::size_t
     248          102 : LabeledAxis::variable_id(const LabeledAxisAccessor & name) const
     249              : {
     250          102 :   ensure_setup_dbg();
     251          102 :   neml_assert(!name.empty(), "Cannot get the ID of a variable with an empty name.");
     252          102 :   const auto id = _variable_to_id_map.find(name);
     253          102 :   neml_assert(id != _variable_to_id_map.end(),
     254              :               "Variable named '",
     255              :               name,
     256              :               "' does not exist on axis:\n",
     257              :               *this);
     258          204 :   return id->second;
     259              : }
     260              : 
     261              : const std::vector<LabeledAxisAccessor> &
     262        28972 : LabeledAxis::variable_names() const
     263              : {
     264        28972 :   ensure_setup_dbg();
     265        28972 :   return _id_to_variable_map;
     266              : }
     267              : 
     268              : const std::vector<std::pair<Size, Size>> &
     269           16 : LabeledAxis::variable_slices() const
     270              : {
     271           16 :   ensure_setup_dbg();
     272           16 :   return _id_to_variable_slice_map;
     273              : }
     274              : 
     275              : const std::pair<Size, Size> &
     276           40 : LabeledAxis::variable_slice(const LabeledAxisAccessor & name) const
     277              : {
     278           40 :   ensure_setup_dbg();
     279           40 :   return _id_to_variable_slice_map.at(variable_id(name));
     280              : }
     281              : 
     282              : const std::vector<Size> &
     283        56794 : LabeledAxis::variable_sizes() const
     284              : {
     285        56794 :   ensure_setup_dbg();
     286        56794 :   return _id_to_variable_size_map;
     287              : }
     288              : 
     289              : Size
     290           76 : LabeledAxis::variable_size(const LabeledAxisAccessor & name) const
     291              : {
     292              :   // If axis has been set up, return the cached variable size
     293           76 :   if (_setup)
     294           44 :     return _id_to_variable_size_map[variable_id(name)];
     295              : 
     296              :   // Otherwise, calculate the variable size
     297           32 :   if (name.size() == 1)
     298              :   {
     299           18 :     const auto var = _variables.find(name[0]);
     300           18 :     neml_assert(
     301           18 :         var != _variables.end(), "Variable named '", name, "' does not exist on axis:\n", *this);
     302           18 :     return var->second;
     303              :   }
     304              : 
     305           14 :   const auto subaxis = _subaxes.find(name[0]);
     306           14 :   neml_assert(
     307           14 :       subaxis != _subaxes.end(), "Variable named '", name, "' does not exist on axis:\n", *this);
     308           14 :   return subaxis->second->variable_size(name.slice(1));
     309              : }
     310              : 
     311              : std::size_t
     312           13 : LabeledAxis::nsubaxis() const
     313              : {
     314           13 :   return _subaxes.size();
     315              : }
     316              : 
     317              : bool
     318           27 : LabeledAxis::has_subaxis(const LabeledAxisAccessor & name) const
     319              : {
     320           27 :   neml_assert(!name.empty(), "Sub-axis name cannot be empty.");
     321              : 
     322           27 :   const auto subaxis = _subaxes.find(name[0]);
     323              : 
     324           27 :   if (name.size() == 1)
     325           23 :     return subaxis != _subaxes.end();
     326              : 
     327            4 :   return subaxis->second->has_subaxis(name.slice(1));
     328              : }
     329              : 
     330              : std::size_t
     331           17 : LabeledAxis::subaxis_id(const std::string & name) const
     332              : {
     333           17 :   ensure_setup_dbg();
     334           17 :   const auto id = _subaxis_to_id_map.find(name);
     335           17 :   neml_assert(id != _subaxis_to_id_map.end(),
     336              :               "Sub-axis named '",
     337              :               name,
     338              :               "' does not exist on axis:\n",
     339              :               *this);
     340           34 :   return id->second;
     341              : }
     342              : 
     343              : const std::vector<const LabeledAxis *> &
     344            2 : LabeledAxis::subaxes() const
     345              : {
     346            2 :   ensure_setup_dbg();
     347            2 :   return _sorted_subaxes;
     348              : }
     349              : 
     350              : const LabeledAxis &
     351         6693 : LabeledAxis::subaxis(const LabeledAxisAccessor & name) const
     352              : {
     353         6693 :   neml_assert(!name.empty(), "Sub-axis name cannot be empty.");
     354              : 
     355         6693 :   const auto subaxis = _subaxes.find(name[0]);
     356         6693 :   neml_assert(
     357         6693 :       subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
     358              : 
     359         6693 :   if (name.size() == 1)
     360         6689 :     return *subaxis->second;
     361              : 
     362            4 :   return subaxis->second->subaxis(name.slice(1));
     363              : }
     364              : 
     365              : LabeledAxis &
     366         6605 : LabeledAxis::subaxis(const LabeledAxisAccessor & name)
     367              : {
     368              :   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
     369         6605 :   return const_cast<LabeledAxis &>(std::as_const(*this).subaxis(name));
     370              : }
     371              : 
     372              : const std::vector<std::string> &
     373            9 : LabeledAxis::subaxis_names() const
     374              : {
     375            9 :   ensure_setup_dbg();
     376            9 :   return _id_to_subaxis_map;
     377              : }
     378              : 
     379              : const std::vector<std::pair<Size, Size>> &
     380            3 : LabeledAxis::subaxis_slices() const
     381              : {
     382            3 :   ensure_setup_dbg();
     383            3 :   return _id_to_subaxis_slice_map;
     384              : }
     385              : 
     386              : std::pair<Size, Size>
     387           14 : LabeledAxis::subaxis_slice(const LabeledAxisAccessor & name) const
     388              : {
     389           14 :   ensure_setup_dbg();
     390              : 
     391              :   // If the name has length 1, it must be a local sub-axis
     392           14 :   if (name.size() == 1)
     393           10 :     return _id_to_subaxis_slice_map[subaxis_id(name[0])];
     394              : 
     395              :   // Otherwise, the name must be on a sub-axis
     396            4 :   const auto subaxis = _subaxes.find(name[0]);
     397            4 :   neml_assert(
     398            4 :       subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
     399            4 :   const auto & slice = subaxis->second->subaxis_slice(name.slice(1));
     400            4 :   auto offset = _id_to_subaxis_slice_map[subaxis_id(name[0])].first;
     401            4 :   return {slice.first + offset, slice.second + offset};
     402              : }
     403              : 
     404              : const std::vector<Size> &
     405            9 : LabeledAxis::subaxis_sizes() const
     406              : {
     407            9 :   ensure_setup_dbg();
     408            9 :   return _id_to_subaxis_size_map;
     409              : }
     410              : 
     411              : Size
     412           14 : LabeledAxis::subaxis_size(const LabeledAxisAccessor & name) const
     413              : {
     414           14 :   const auto subaxis = _subaxes.find(name[0]);
     415           14 :   neml_assert(
     416           14 :       subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
     417              : 
     418           14 :   if (name.size() == 1)
     419           10 :     return subaxis->second->size();
     420              : 
     421            4 :   return subaxis->second->subaxis_size(name.slice(1));
     422              : }
     423              : 
     424              : bool
     425           10 : LabeledAxis::equals(const LabeledAxis & other) const
     426              : {
     427              :   // They must have the same set of variables (with the same storage sizes)
     428           10 :   if (_variables != other._variables)
     429            2 :     return false;
     430              : 
     431              :   // They must have the same number of subaxes
     432            8 :   if (_subaxes.size() != other._subaxes.size())
     433            0 :     return false;
     434              : 
     435              :   // Compare each subaxis
     436           10 :   for (const auto & [name, axis] : _subaxes)
     437              :   {
     438            2 :     if (other._subaxes.count(name) == 0)
     439            0 :       return false;
     440              : 
     441            2 :     if (*other._subaxes.at(name) != *axis)
     442            0 :       return false;
     443              :   }
     444              : 
     445            8 :   return true;
     446              : }
     447              : 
     448              : void
     449         1701 : LabeledAxis::cache_reserved_subaxis(const std::string & axis_name)
     450              : {
     451         1701 :   if (axis_name == STATE)
     452          786 :     _has_state = true;
     453          915 :   else if (axis_name == OLD_STATE)
     454           66 :     _has_old_state = true;
     455          849 :   else if (axis_name == FORCES)
     456          162 :     _has_forces = true;
     457          687 :   else if (axis_name == OLD_FORCES)
     458           75 :     _has_old_forces = true;
     459          612 :   else if (axis_name == RESIDUAL)
     460           53 :     _has_residual = true;
     461          559 :   else if (axis_name == PARAMETERS)
     462           57 :     _has_parameters = true;
     463         1701 : }
     464              : 
     465              : void
     466        86005 : LabeledAxis::ensure_setup_dbg() const
     467              : {
     468        86005 :   neml_assert_dbg(_setup, "The axis has not been setup yet.");
     469        86005 : }
     470              : 
     471              : std::ostream &
     472            0 : operator<<(std::ostream & os, const LabeledAxis & axis)
     473              : {
     474              :   // Get unqualified variable names
     475            0 :   const auto var_names = axis.variable_names();
     476              : 
     477              :   // Find the maximum variable name length
     478            0 :   size_t max_var_name_length = 0;
     479            0 :   for (const auto & var_name : var_names)
     480              :   {
     481            0 :     const auto var_name_str = utils::stringify(var_name);
     482            0 :     if (var_name_str.size() > max_var_name_length)
     483            0 :       max_var_name_length = var_name_str.size();
     484            0 :   }
     485              : 
     486              :   // Print variables with right alignment
     487            0 :   for (auto var = var_names.begin(); var != var_names.end(); var++)
     488              :   {
     489            0 :     if (axis._setup)
     490            0 :       os << std::setw(3) << std::right << axis.variable_id(*var) << ": ";
     491            0 :     os << std::setw(int(max_var_name_length)) << std::left << utils::stringify(*var);
     492            0 :     if (axis._setup)
     493            0 :       os << " [" << axis.variable_slice(*var) << "]";
     494            0 :     if (std::next(var) != var_names.end())
     495            0 :       os << std::endl;
     496              :   }
     497              : 
     498            0 :   return os;
     499            0 : }
     500              : 
     501              : bool
     502            6 : operator==(const LabeledAxis & a, const LabeledAxis & b)
     503              : {
     504            6 :   return a.equals(b);
     505              : }
     506              : 
     507              : bool
     508            4 : operator!=(const LabeledAxis & a, const LabeledAxis & b)
     509              : {
     510            4 :   return !a.equals(b);
     511              : }
     512              : } // namespace neml2
        

Generated by: LCOV version 2.0-1