Line data Source code
1 : // Copyright 2024, UChicago Argonne, LLC
2 : // All Rights Reserved
3 : // Software Name: NEML2 -- the New Engineering material Model Library, version 2
4 : // By: Argonne National Laboratory
5 : // OPEN SOURCE LICENSE (MIT)
6 : //
7 : // Permission is hereby granted, free of charge, to any person obtaining a copy
8 : // of this software and associated documentation files (the "Software"), to deal
9 : // in the Software without restriction, including without limitation the rights
10 : // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 : // copies of the Software, and to permit persons to whom the Software is
12 : // furnished to do so, subject to the following conditions:
13 : //
14 : // The above copyright notice and this permission notice shall be included in
15 : // all copies or substantial portions of the Software.
16 : //
17 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 : // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 : // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 : // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 : // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 : // THE SOFTWARE.
24 :
25 : #include "neml2/base/LabeledAxis.h"
26 : #include "neml2/tensors/shape_utils.h"
27 : #include "neml2/tensors/tensors.h"
28 : #include "neml2/misc/assertions.h"
29 :
30 : namespace neml2
31 : {
32 2144 : LabeledAxis::LabeledAxis(LabeledAxisAccessor prefix)
33 2144 : : _prefix(std::move(prefix))
34 : {
35 2144 : }
36 :
37 : LabeledAxisAccessor
38 24665 : LabeledAxis::qualify(const LabeledAxisAccessor & accessor) const
39 : {
40 24665 : return accessor.prepend(_prefix);
41 : }
42 :
43 : LabeledAxis &
44 2144 : LabeledAxis::add_subaxis(const std::string & name)
45 : {
46 2144 : neml_assert(!_setup, "Cannot modify a sub-axis after the axis has been set up.");
47 2144 : neml_assert(
48 2144 : _variables.count(name) == 0, "Cannot add a subaxis with the same name as a variable: ", name);
49 2144 : auto [subaxis, success] =
50 2144 : _subaxes.emplace(name, std::make_shared<LabeledAxis>(_prefix.append(name)));
51 2144 : if (success)
52 1550 : cache_reserved_subaxis(name);
53 4288 : return *(subaxis->second);
54 : }
55 :
56 : void
57 3956 : LabeledAxis::add_variable(const LabeledAxisAccessor & name, Size sz)
58 : {
59 3956 : neml_assert(!_setup, "Cannot modify a sub-axis after the axis has been set up.");
60 3956 : neml_assert(!name.empty(), "Cannot add a variable with empty name.");
61 :
62 3956 : if (name.size() == 1)
63 : {
64 1938 : neml_assert(_variables.count(name[0]) == 0 && _subaxes.count(name[0]) == 0,
65 : "Cannot add a variable with the same name as an existing variable or a sub-axis: '",
66 1938 : name[0],
67 : "'");
68 1938 : _variables.emplace(name[0], sz);
69 : }
70 : else
71 2018 : add_subaxis(name[0]).add_variable(name.slice(1), sz);
72 3956 : }
73 :
74 : template <typename T>
75 : void
76 317 : LabeledAxis::add_variable(const LabeledAxisAccessor & name)
77 : {
78 317 : auto sz = utils::storage_size(T::const_base_sizes);
79 317 : add_variable(name, sz);
80 317 : }
81 : #define INSTANTIATE_ADD_VARIABLE(T) \
82 : template void LabeledAxis::add_variable<T>(const LabeledAxisAccessor &)
83 : FOR_ALL_PRIMITIVETENSOR(INSTANTIATE_ADD_VARIABLE);
84 :
85 : void
86 2377 : LabeledAxis::setup_layout()
87 : {
88 : // Clear internal data that may have been constructed from previous setup_layout calls
89 2377 : _size = 0;
90 :
91 2377 : _variable_to_id_map.clear();
92 2377 : _id_to_variable_map.clear();
93 2377 : _id_to_variable_size_map.clear();
94 2377 : _id_to_variable_slice_map.clear();
95 :
96 2377 : _sorted_subaxes.clear();
97 2377 : _subaxis_to_id_map.clear();
98 2377 : _id_to_subaxis_map.clear();
99 2377 : _id_to_subaxis_size_map.clear();
100 2377 : _id_to_subaxis_slice_map.clear();
101 :
102 : // Set up variable assembly IDs and slicing indices
103 4147 : for (auto & [name, sz] : _variables)
104 : {
105 1770 : _variable_to_id_map.emplace(name, _variable_to_id_map.size());
106 1770 : _id_to_variable_map.emplace_back(name);
107 1770 : _id_to_variable_size_map.push_back(sz);
108 1770 : _id_to_variable_slice_map.emplace_back(_size, _size + sz);
109 1770 : _size += sz;
110 : }
111 :
112 : // Set up subaxes
113 3864 : for (auto & [name, axis] : _subaxes)
114 : {
115 1487 : axis->setup_layout();
116 1487 : auto sz = axis->size();
117 1487 : _sorted_subaxes.push_back(axis.get());
118 1487 : _subaxis_to_id_map.emplace(name, _subaxis_to_id_map.size());
119 1487 : _id_to_subaxis_map.push_back(name);
120 1487 : _id_to_subaxis_size_map.push_back(sz);
121 1487 : _id_to_subaxis_slice_map.emplace_back(_size, _size + sz);
122 1487 : _size += sz;
123 :
124 : // Merge variable maps
125 3715 : for (const auto & var_name : axis->_id_to_variable_map)
126 : {
127 2228 : auto var_id = axis->_variable_to_id_map.at(var_name);
128 2228 : auto full_name = var_name.prepend(name);
129 2228 : _variable_to_id_map.emplace(full_name, _variable_to_id_map.size());
130 2228 : _id_to_variable_map.push_back(full_name);
131 2228 : _id_to_variable_size_map.push_back(axis->_id_to_variable_size_map[var_id]);
132 :
133 : // Slice is relative to the sub-axis, so we need to shift it
134 2228 : const auto & slice = axis->_id_to_variable_slice_map[var_id];
135 2228 : auto offset = _id_to_subaxis_slice_map.back().first;
136 2228 : auto new_slice = std::pair<Size, Size>{slice.first + offset, slice.second + offset};
137 2228 : _id_to_variable_slice_map.push_back(new_slice);
138 2228 : }
139 : }
140 :
141 : // Finished set up
142 2377 : _setup = true;
143 2377 : }
144 :
145 : Size
146 1524 : LabeledAxis::size() const
147 : {
148 : // If the axis has been set up, return the cached size
149 1524 : if (_setup)
150 1501 : return _size;
151 :
152 : // Otherwise, calculate the size
153 23 : Size sz = 0;
154 69 : for (const auto & [name, var_sz] : _variables)
155 46 : sz += var_sz;
156 32 : for (const auto & [name, axis] : _subaxes)
157 9 : sz += axis->size();
158 23 : return sz;
159 : }
160 :
161 : Size
162 78 : LabeledAxis::size(const LabeledAxisAccessor & name) const
163 : {
164 78 : neml_assert(!name.empty(), "Cannot get the size of an item with an empty name.");
165 :
166 : // If the name has length 1, it must be a variable or a local sub-axis
167 78 : if (name.size() == 1)
168 : {
169 46 : const auto var = _variables.find(name[0]);
170 46 : if (var != _variables.end())
171 36 : return var->second;
172 :
173 10 : const auto subaxis = _subaxes.find(name[0]);
174 10 : neml_assert(subaxis != _subaxes.end(),
175 : "Item named '",
176 : name,
177 : "' is neither a variable nor a local sub-axis on axis:\n",
178 : *this);
179 10 : return subaxis->second->size();
180 : }
181 :
182 : // Otherwise, the item must be on a sub-axis
183 32 : const auto subaxis = _subaxes.find(name[0]);
184 32 : neml_assert(subaxis != _subaxes.end(),
185 : "Item named '",
186 : name,
187 : "' is neither a variable nor a sub-axis on axis:\n",
188 : *this);
189 32 : return subaxis->second->size(name.slice(1));
190 : }
191 :
192 : indexing::Slice
193 27 : LabeledAxis::slice(const LabeledAxisAccessor & name) const
194 : {
195 27 : ensure_setup_dbg();
196 27 : neml_assert(!name.empty(), "Cannot get the slice of an item with an empty name.");
197 :
198 : // If the name is a variable
199 27 : if (has_variable(name))
200 : {
201 22 : auto s = variable_slice(name);
202 22 : return {s.first, s.second};
203 : }
204 :
205 : // Otherwise, the name must be a sub-axis
206 5 : neml_assert_dbg(has_subaxis(name[0]),
207 : "Item named '",
208 : name,
209 : "' is neither a variable nor a sub-axis on axis:\n",
210 : *this);
211 5 : auto s = subaxis_slice(name);
212 5 : return {s.first, s.second};
213 : }
214 :
215 : std::size_t
216 416 : LabeledAxis::nvariable() const
217 : {
218 : // If axis has been set up, return the cached number of variables
219 416 : if (_setup)
220 407 : return _id_to_variable_map.size();
221 :
222 : // Otherwise, calculate the number of variables
223 9 : std::size_t nvar = _variables.size();
224 14 : for (const auto & [name, axis] : _subaxes)
225 5 : nvar += axis->nvariable();
226 9 : return nvar;
227 : }
228 :
229 : bool
230 15876 : LabeledAxis::has_variable(const LabeledAxisAccessor & name) const
231 : {
232 15876 : neml_assert(!name.empty(), "Variable name cannot be empty.");
233 :
234 : // If axis has been set up, return the cached existence
235 15876 : if (_setup)
236 15128 : return std::find(_id_to_variable_map.begin(), _id_to_variable_map.end(), name) !=
237 30256 : _id_to_variable_map.end();
238 :
239 : // Otherwise, check the existence of the variable
240 748 : if (name.size() == 1)
241 273 : return _variables.find(name[0]) != _variables.end();
242 :
243 475 : const auto subaxis = _subaxes.find(name[0]);
244 475 : return subaxis != _subaxes.end() && subaxis->second->has_variable(name.slice(1));
245 : }
246 :
247 : std::size_t
248 88 : LabeledAxis::variable_id(const LabeledAxisAccessor & name) const
249 : {
250 88 : ensure_setup_dbg();
251 88 : neml_assert(!name.empty(), "Cannot get the ID of a variable with an empty name.");
252 88 : const auto id = _variable_to_id_map.find(name);
253 88 : neml_assert(id != _variable_to_id_map.end(),
254 : "Variable named '",
255 : name,
256 : "' does not exist on axis:\n",
257 : *this);
258 176 : return id->second;
259 : }
260 :
261 : const std::vector<LabeledAxisAccessor> &
262 26872 : LabeledAxis::variable_names() const
263 : {
264 26872 : ensure_setup_dbg();
265 26872 : return _id_to_variable_map;
266 : }
267 :
268 : const std::vector<std::pair<Size, Size>> &
269 16 : LabeledAxis::variable_slices() const
270 : {
271 16 : ensure_setup_dbg();
272 16 : return _id_to_variable_slice_map;
273 : }
274 :
275 : const std::pair<Size, Size> &
276 40 : LabeledAxis::variable_slice(const LabeledAxisAccessor & name) const
277 : {
278 40 : ensure_setup_dbg();
279 40 : return _id_to_variable_slice_map.at(variable_id(name));
280 : }
281 :
282 : const std::vector<Size> &
283 56467 : LabeledAxis::variable_sizes() const
284 : {
285 56467 : ensure_setup_dbg();
286 56467 : return _id_to_variable_size_map;
287 : }
288 :
289 : Size
290 62 : LabeledAxis::variable_size(const LabeledAxisAccessor & name) const
291 : {
292 : // If axis has been set up, return the cached variable size
293 62 : if (_setup)
294 30 : return _id_to_variable_size_map[variable_id(name)];
295 :
296 : // Otherwise, calculate the variable size
297 32 : if (name.size() == 1)
298 : {
299 18 : const auto var = _variables.find(name[0]);
300 18 : neml_assert(
301 18 : var != _variables.end(), "Variable named '", name, "' does not exist on axis:\n", *this);
302 18 : return var->second;
303 : }
304 :
305 14 : const auto subaxis = _subaxes.find(name[0]);
306 14 : neml_assert(
307 14 : subaxis != _subaxes.end(), "Variable named '", name, "' does not exist on axis:\n", *this);
308 14 : return subaxis->second->variable_size(name.slice(1));
309 : }
310 :
311 : std::size_t
312 13 : LabeledAxis::nsubaxis() const
313 : {
314 13 : return _subaxes.size();
315 : }
316 :
317 : bool
318 27 : LabeledAxis::has_subaxis(const LabeledAxisAccessor & name) const
319 : {
320 27 : neml_assert(!name.empty(), "Sub-axis name cannot be empty.");
321 :
322 27 : const auto subaxis = _subaxes.find(name[0]);
323 :
324 27 : if (name.size() == 1)
325 23 : return subaxis != _subaxes.end();
326 :
327 4 : return subaxis->second->has_subaxis(name.slice(1));
328 : }
329 :
330 : std::size_t
331 17 : LabeledAxis::subaxis_id(const std::string & name) const
332 : {
333 17 : ensure_setup_dbg();
334 17 : const auto id = _subaxis_to_id_map.find(name);
335 17 : neml_assert(id != _subaxis_to_id_map.end(),
336 : "Sub-axis named '",
337 : name,
338 : "' does not exist on axis:\n",
339 : *this);
340 34 : return id->second;
341 : }
342 :
343 : const std::vector<const LabeledAxis *> &
344 2 : LabeledAxis::subaxes() const
345 : {
346 2 : ensure_setup_dbg();
347 2 : return _sorted_subaxes;
348 : }
349 :
350 : const LabeledAxis &
351 6659 : LabeledAxis::subaxis(const LabeledAxisAccessor & name) const
352 : {
353 6659 : neml_assert(!name.empty(), "Sub-axis name cannot be empty.");
354 :
355 6659 : const auto subaxis = _subaxes.find(name[0]);
356 6659 : neml_assert(
357 6659 : subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
358 :
359 6659 : if (name.size() == 1)
360 6655 : return *subaxis->second;
361 :
362 4 : return subaxis->second->subaxis(name.slice(1));
363 : }
364 :
365 : LabeledAxis &
366 6571 : LabeledAxis::subaxis(const LabeledAxisAccessor & name)
367 : {
368 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
369 6571 : return const_cast<LabeledAxis &>(std::as_const(*this).subaxis(name));
370 : }
371 :
372 : const std::vector<std::string> &
373 9 : LabeledAxis::subaxis_names() const
374 : {
375 9 : ensure_setup_dbg();
376 9 : return _id_to_subaxis_map;
377 : }
378 :
379 : const std::vector<std::pair<Size, Size>> &
380 3 : LabeledAxis::subaxis_slices() const
381 : {
382 3 : ensure_setup_dbg();
383 3 : return _id_to_subaxis_slice_map;
384 : }
385 :
386 : std::pair<Size, Size>
387 14 : LabeledAxis::subaxis_slice(const LabeledAxisAccessor & name) const
388 : {
389 14 : ensure_setup_dbg();
390 :
391 : // If the name has length 1, it must be a local sub-axis
392 14 : if (name.size() == 1)
393 10 : return _id_to_subaxis_slice_map[subaxis_id(name[0])];
394 :
395 : // Otherwise, the name must be on a sub-axis
396 4 : const auto subaxis = _subaxes.find(name[0]);
397 4 : neml_assert(
398 4 : subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
399 4 : const auto & slice = subaxis->second->subaxis_slice(name.slice(1));
400 4 : auto offset = _id_to_subaxis_slice_map[subaxis_id(name[0])].first;
401 4 : return {slice.first + offset, slice.second + offset};
402 : }
403 :
404 : const std::vector<Size> &
405 9 : LabeledAxis::subaxis_sizes() const
406 : {
407 9 : ensure_setup_dbg();
408 9 : return _id_to_subaxis_size_map;
409 : }
410 :
411 : Size
412 14 : LabeledAxis::subaxis_size(const LabeledAxisAccessor & name) const
413 : {
414 14 : const auto subaxis = _subaxes.find(name[0]);
415 14 : neml_assert(
416 14 : subaxis != _subaxes.end(), "Sub-axis named '", name, "' does not exist on axis:\n", *this);
417 :
418 14 : if (name.size() == 1)
419 10 : return subaxis->second->size();
420 :
421 4 : return subaxis->second->subaxis_size(name.slice(1));
422 : }
423 :
424 : bool
425 10 : LabeledAxis::equals(const LabeledAxis & other) const
426 : {
427 : // They must have the same set of variables (with the same storage sizes)
428 10 : if (_variables != other._variables)
429 2 : return false;
430 :
431 : // They must have the same number of subaxes
432 8 : if (_subaxes.size() != other._subaxes.size())
433 0 : return false;
434 :
435 : // Compare each subaxis
436 10 : for (const auto & [name, axis] : _subaxes)
437 : {
438 2 : if (other._subaxes.count(name) == 0)
439 0 : return false;
440 :
441 2 : if (*other._subaxes.at(name) != *axis)
442 0 : return false;
443 : }
444 :
445 8 : return true;
446 : }
447 :
448 : void
449 1550 : LabeledAxis::cache_reserved_subaxis(const std::string & axis_name)
450 : {
451 1550 : if (axis_name == STATE)
452 707 : _has_state = true;
453 843 : else if (axis_name == OLD_STATE)
454 61 : _has_old_state = true;
455 782 : else if (axis_name == FORCES)
456 146 : _has_forces = true;
457 636 : else if (axis_name == OLD_FORCES)
458 70 : _has_old_forces = true;
459 566 : else if (axis_name == RESIDUAL)
460 49 : _has_residual = true;
461 517 : else if (axis_name == PARAMETERS)
462 48 : _has_parameters = true;
463 1550 : }
464 :
465 : void
466 83564 : LabeledAxis::ensure_setup_dbg() const
467 : {
468 83564 : neml_assert_dbg(_setup, "The axis has not been setup yet.");
469 83564 : }
470 :
471 : std::ostream &
472 0 : operator<<(std::ostream & os, const LabeledAxis & axis)
473 : {
474 : // Get unqualified variable names
475 0 : const auto var_names = axis.variable_names();
476 :
477 : // Find the maximum variable name length
478 0 : size_t max_var_name_length = 0;
479 0 : for (const auto & var_name : var_names)
480 : {
481 0 : const auto var_name_str = utils::stringify(var_name);
482 0 : if (var_name_str.size() > max_var_name_length)
483 0 : max_var_name_length = var_name_str.size();
484 0 : }
485 :
486 : // Print variables with right alignment
487 0 : for (auto var = var_names.begin(); var != var_names.end(); var++)
488 : {
489 0 : if (axis._setup)
490 0 : os << std::setw(3) << std::right << axis.variable_id(*var) << ": ";
491 0 : os << std::setw(int(max_var_name_length)) << std::left << utils::stringify(*var);
492 0 : if (axis._setup)
493 0 : os << " [" << axis.variable_slice(*var) << "]";
494 0 : if (std::next(var) != var_names.end())
495 0 : os << std::endl;
496 : }
497 :
498 0 : return os;
499 0 : }
500 :
501 : bool
502 6 : operator==(const LabeledAxis & a, const LabeledAxis & b)
503 : {
504 6 : return a.equals(b);
505 : }
506 :
507 : bool
508 4 : operator!=(const LabeledAxis & a, const LabeledAxis & b)
509 : {
510 4 : return !a.equals(b);
511 : }
512 : } // namespace neml2
|