NEML2 2.0.0
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
DependencyResolver.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "neml2/models/DependencyDefinition.h"
#include "neml2/misc/string_utils.h"
33
34namespace neml2
35{
44template <typename Node, typename ItemType>
46{
47public:
52 struct Item
53 {
54 Item(Node * const node, ItemType item)
55 : parent(node),
56 value(std::move(item))
57 {
58 }
59
61 Node * const parent;
62
64 const ItemType value;
65
67 bool operator==(const Item & other) const
68 {
69 return parent == other.parent && value == other.value;
70 }
71
73 bool operator!=(const Item & other) const
74 {
75 return parent != other.parent || value != other.value;
76 }
77
79 bool operator<(const Item & other) const
80 {
81 return parent != other.parent ? (parent < other.parent) : (value < other.value);
82 }
83 };
84
85 DependencyResolver() = default;
86
89
91 void add_additional_outbound_item(const ItemType & item);
92
95
97 void resolve();
98
100 const std::vector<Node *> & resolution() const { return _resolution; }
101
106 const std::map<Item, std::set<Item>> & item_providers() const { return _item_provider_graph; }
107
112 const std::map<Item, std::set<Item>> & item_consumers() const { return _item_consumer_graph; }
113
118 const std::map<Node *, std::set<Node *>> & node_providers() const { return _node_provider_graph; }
119
124 const std::map<Node *, std::set<Node *>> & node_consumers() const { return _node_consumer_graph; }
125
127 const std::set<Node *> & end_nodes() const { return _end_nodes; }
128
130 const std::set<Item> & outbound_items() const { return _out_items; }
131
133 const std::set<Node *> & start_nodes() const { return _start_nodes; }
134
136 const std::set<Item> & inbound_items() const { return _in_items; }
137
139 bool & unique_item_provider() { return _unique_item_provider; }
140
142 bool & unique_item_consumer() { return _unique_item_consumer; }
143
144private:
146 void build_graph();
147
149 void resolve(Node *);
150
152 bool _unique_item_provider = true;
153
155 bool _unique_item_consumer = false;
156
158 std::set<Node *> _nodes;
159
161 std::set<Item> _consumed_items;
162
164 std::set<Item> _provided_items;
165
168 std::map<Item, std::set<Item>> _item_provider_graph;
169
172 std::map<Item, std::set<Item>> _item_consumer_graph;
173
176 std::map<Node *, std::set<Node *>> _node_provider_graph;
177
180 std::map<Node *, std::set<Node *>> _node_consumer_graph;
181
183 std::set<Node *> _end_nodes;
184
186 std::set<Node *> _start_nodes;
187
189 std::set<Item> _out_items;
190
192 std::set<Item> _in_items;
193
195 std::vector<Node *> _resolution;
196
198 std::map<Node *, int> _status;
199
201 std::map<Node *, size_t> _priority;
202};
203
204template <typename Node, typename ItemType>
205void
207{
208 auto node = dynamic_cast<Node *>(def);
209 _nodes.emplace(node);
210
211 for (const auto & item : node->consumed_items())
212 _consumed_items.emplace(node, item);
213
214 for (const auto & item : node->provided_items())
215 _provided_items.emplace(node, item);
216}
217
218template <typename Node, typename ItemType>
219void
221{
222 _consumed_items.emplace(nullptr, item);
223}
224
225template <typename Node, typename ItemType>
226void
228 size_t priority)
229{
230 auto node = dynamic_cast<Node *>(def);
231 _priority[node] = priority;
232}
233
234template <typename Node, typename ItemType>
235void
236DependencyResolver<Node, ItemType>::build_graph()
237{
238 // Clear the previous graph
239 _item_provider_graph.clear();
240 _item_consumer_graph.clear();
241 _node_provider_graph.clear();
242 _node_consumer_graph.clear();
243 _start_nodes.clear();
244 _end_nodes.clear();
245 _in_items.clear();
246 _out_items.clear();
247
248 // Build the adjacency matrix for item providers and node providers
249 for (const auto & itemi : _consumed_items)
250 {
251 std::vector<Item> providers;
252
253 for (const auto & itemj : _provided_items)
254 {
255 // Match consumer with provider
256 if (itemi.value != itemj.value)
257 continue;
258
259 // No self dependency
260 if (itemi.parent == itemj.parent)
261 continue;
262
263 // Enforce priority
264 if (_priority[itemi.parent] > _priority[itemj.parent])
265 continue;
266
267 providers.push_back(itemj);
268 }
269
270 // If the user asks for unique providers, we should error if multiple providers have been
271 // found. Otherwise, just put the first provider into the graph.
272 if (!providers.empty())
273 {
274 if (_unique_item_provider)
275 if (providers.size() != 1)
276 throw NEMLException("Multiple providers have been found for item " +
277 utils::stringify(itemi.value));
278 _item_provider_graph[itemi].insert(providers[0]);
279 if (itemi.parent)
280 _node_provider_graph[itemi.parent].insert(providers[0].parent);
281 }
282 }
283
284 // Build the adjacency matrix for item consumers
285 for (const auto & itemi : _provided_items)
286 {
287 std::vector<Item> consumers;
288
289 for (const auto & itemj : _consumed_items)
290 {
291 // Skip additional outbound item
292 if (!itemj.parent)
293 continue;
294
295 // Match provider with consumer
296 if (itemi.value != itemj.value)
297 continue;
298
299 // No self dependency
300 if (itemi.parent == itemj.parent)
301 continue;
302
303 // Enforce priority
304 if (_priority[itemi.parent] < _priority[itemj.parent])
305 continue;
306
307 consumers.push_back(itemj);
308 }
309
310 // If the user asks for unique consumers, we should error if multiple consumers have been
311 // found. Otherwise, just put the first consumer into the graph.
312 if (!consumers.empty())
313 {
314 if (_unique_item_consumer)
315 if (consumers.size() != 1)
316 throw NEMLException("Multiple consumers have been found for item " +
317 utils::stringify(itemi.value));
318 _item_consumer_graph[itemi].insert(consumers[0]);
319 _node_consumer_graph[itemi.parent].insert(consumers[0].parent);
320 }
321 }
322
323 // Find start nodes
324 for (const auto & node : _nodes)
325 if (_node_provider_graph.count(node) == 0)
326 _start_nodes.insert(node);
327
328 // Find end nodes
329 for (const auto & node : _nodes)
330 if (_node_consumer_graph.count(node) == 0)
331 _end_nodes.insert(node);
332
333 // Find inbound items
334 for (const auto & item : _consumed_items)
335 if (_item_provider_graph.count(item) == 0)
336 _in_items.insert(item);
337
338 // Find outbound items
339 for (const auto & item : _provided_items)
340 if (_item_consumer_graph.count(item) == 0)
341 _out_items.insert(item);
342
343 // Additional outbound items
344 for (const auto & item : _consumed_items)
345 if (!item.parent)
346 {
347 if (!_item_provider_graph.count(item))
348 throw NEMLException("Unable to find provider of the additional outbound item " +
349 utils::stringify(item.value));
350 for (const auto & provider : _item_provider_graph[item])
351 {
352 _out_items.insert(provider);
353 _end_nodes.insert(provider.parent);
354 }
355 }
356}
357
358template <typename Node, typename ItemType>
359void
361{
362 build_graph();
363
364 _status.clear();
365 _resolution.clear();
366 for (const auto & node : _end_nodes)
367 if (!_status[node])
368 resolve(node);
369
370 // Make sure each node appears in the resolution once and only once
371 for (const auto & node : _nodes)
372 {
373 auto count = std::count(_resolution.begin(), _resolution.end(), node);
374 if (count == 0)
375 throw NEMLException(
376 "Each node must appear in the dependency resolution. Node " + node->name() +
377 " is missing. This is an internal error -- consider filing a bug report.");
378 if (count > 1)
379 throw NEMLException(
380 "Each node must appear in the dependency resolution once and only once. Node " +
381 node->name() + " appeared " + std::to_string(count) +
382 " times. This indicates cyclic dependency.");
383 }
384}
385
386template <typename Node, typename ItemType>
387void
389{
390 // Mark the current node as visiting (so that we know there is circular dependency when a
391 // "visiting" node is visited again).
392 _status[node] += 1;
393
394 // Recurse for all the dependent nodes
395 if (_node_provider_graph.count(node))
396 for (const auto & dep : _node_provider_graph[node])
397 {
398 // The dependent node must either be "not visited" or "visited".
399 // If the dependent node is "being visited", there must be cyclic dependency.
400 if (_status[dep] == 1)
401 throw NEMLException(
402 "While resolving dependency, two nodes '" + node->name() + "' and '" + dep->name() +
403 "' have (possibly indirect) cyclic dependency. The cyclic dependency can be "
404 "resolved by explicitly setting the node priorities.");
405
406 if (!_status[dep])
407 resolve(dep);
408 }
409
410 // At this point, all the dependent nodes must have been pushed into the resolution. It is
411 // therefore safe to push the current node into the resolution.
412 _resolution.push_back(node);
413
414 // Finished visiting this node
415 _status[node] += 1;
416}
417} // namespace neml2
Definition DependencyDefinition.h:40
const std::map< Node *, std::set< Node * > > & node_consumers() const
Definition DependencyResolver.h:124
const std::map< Node *, std::set< Node * > > & node_providers() const
Definition DependencyResolver.h:118
const std::set< Item > & inbound_items() const
The items consumed by the overall dependency graph, i.e., the items that are not provided by any node...
Definition DependencyResolver.h:136
const std::map< Item, std::set< Item > > & item_consumers() const
Definition DependencyResolver.h:112
void add_additional_outbound_item(const ItemType &item)
Add an additional outbound item that the dependency graph provides
Definition DependencyResolver.h:220
void set_priority(DependencyDefinition< ItemType > *, size_t)
Set a node's priority, useful for resolving cyclic dependency.
Definition DependencyResolver.h:227
bool & unique_item_provider()
Definition DependencyResolver.h:139
void resolve()
Resolve nodal dependency and find an evaluation order.
Definition DependencyResolver.h:360
const std::set< Node * > & end_nodes() const
End nodes, i.e., nodes whose provided items are not consumed by any other node (plus the providers of any additional outbound items).
Definition DependencyResolver.h:127
const std::vector< Node * > & resolution() const
The resolved (nodal) evaluation order, following which all consumed items of the current node have already been provided by preceding nodes.
Definition DependencyResolver.h:100
const std::set< Item > & outbound_items() const
The items provided by the overall dependency graph, i.e., the items that are not consumed by any node...
Definition DependencyResolver.h:130
void add_node(DependencyDefinition< ItemType > *)
Add a node (defining consumed/provided items) in the dependency graph.
Definition DependencyResolver.h:206
const std::set< Node * > & start_nodes() const
Start nodes, i.e., nodes that do not consume any item provided by another node.
Definition DependencyResolver.h:133
const std::map< Item, std::set< Item > > & item_providers() const
Definition DependencyResolver.h:106
bool & unique_item_consumer()
Definition DependencyResolver.h:142
Definition errors.h:34
std::string stringify(const T &t)
Definition string_utils.h:73
Definition DiagnosticsInterface.cxx:30
Node *const parent
Node which defines this item.
Definition DependencyResolver.h:61
bool operator<(const Item &other) const
An arbitrary comparator so that items can be sorted (for consistency)
Definition DependencyResolver.h:79
bool operator==(const Item &other) const
Test for equality between two items.
Definition DependencyResolver.h:67
bool operator!=(const Item &other) const
Test for inequality between two items.
Definition DependencyResolver.h:73
const ItemType value
The consumed/provided item.
Definition DependencyResolver.h:64
Item(Node *const node, ItemType item)
Definition DependencyResolver.h:54