NEML2 2.1.0
Loading...
Searching...
No Matches
DependencyResolver.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
30
31#include "neml2/models/DependencyDefinition.h"
32#include "neml2/misc/string_utils.h"
33#include "neml2/misc/errors.h"
34
35namespace neml2
36{
/**
 * @brief Resolve the evaluation order of a set of nodes from the items they consume and provide.
 *
 * Nodes are registered with add_node() and must derive from DependencyDefinition<ItemType>, which
 * reports the items a node consumes and provides. resolve() matches consumed items against
 * provided items (by value), builds item- and node-level provider/consumer graphs, and computes an
 * evaluation order (see resolution()) in which each node comes after the nodes providing the
 * items it consumes.
 *
 * @tparam Node Node type. Nodes are held by raw pointer (the resolver never deletes them) and must
 *              provide `name()`, which is used in error messages.
 * @tparam ItemType Type of the consumed/provided items. Must be copyable, comparable with `==`,
 *                  `!=` and `<`, and printable with utils::stringify.
 */
template <typename Node, typename ItemType>
class DependencyResolver
{
public:
  /**
   * @brief A consumed or provided item, tagged with the node that defines it.
   *
   * Items whose parent is nullptr represent additional outbound items (see
   * add_additional_outbound_item()).
   */
  struct Item
  {
    Item(Node * const node, ItemType item)
      : parent(node),
        value(std::move(item))
    {
    }

    /// Node which defines this item (nullptr for additional outbound items)
    Node * const parent;

    /// The consumed/provided item
    const ItemType value;

    /// Test for equality between two items
    bool operator==(const Item & other) const
    {
      return parent == other.parent && value == other.value;
    }

    /// Test for inequality between two items (always the negation of operator==)
    bool operator!=(const Item & other) const { return !(*this == other); }

    /**
     * @brief An arbitrary comparator so that items can be sorted (for consistency).
     *
     * Items are ordered by parent node first, then by value. Parents are compared with std::less
     * because the built-in `<` on pointers to unrelated objects has an unspecified result, whereas
     * std::less is guaranteed to be a strict total order. std::set<Item> relies on that.
     */
    bool operator<(const Item & other) const
    {
      if (parent != other.parent)
        return std::less<Node *>()(parent, other.parent);
      return value < other.value;
    }
  };

  DependencyResolver() = default;

  /// Add a node (defining consumed/provided items) in the dependency graph
  void add_node(Node *);

  /// Add an additional outbound item that the dependency graph provides. During resolution, the
  /// provider of this item is treated as an outbound item and its node as an end node.
  void add_additional_outbound_item(const ItemType & item);

  /**
   * @brief Set a node's priority, useful for resolving cyclic dependency.
   *
   * A consumer is only linked to providers whose priority is not lower than its own. Nodes whose
   * priority was never set have priority 0.
   */
  void set_priority(Node *, std::size_t);

  /// Resolve nodal dependency and find an evaluation order
  void resolve();

  /// The resolved (nodal) evaluation order, following which all consumed items of the current
  /// node are provided by nodes evaluated before it
  const std::vector<Node *> & resolution() const { return _resolution; }

  /// Item-level provider graph: maps each consumed item to the item providing it (only the first
  /// matching provider is recorded)
  const std::map<Item, std::set<Item>> & item_providers() const { return _item_provider_graph; }

  /// Item-level consumer graph: maps each provided item to the item consuming it (only the first
  /// matching consumer is recorded)
  const std::map<Item, std::set<Item>> & item_consumers() const { return _item_consumer_graph; }

  /// Node-level provider graph: maps each node to the nodes providing the items it consumes
  const std::map<Node *, std::set<Node *>> & node_providers() const { return _node_provider_graph; }

  /// Node-level consumer graph: maps each node to the nodes consuming the items it provides
  const std::map<Node *, std::set<Node *>> & node_consumers() const { return _node_consumer_graph; }

  /// End nodes, i.e., nodes whose provided items are not consumed by any other node (plus the
  /// providers of additional outbound items)
  const std::set<Node *> & end_nodes() const { return _end_nodes; }

  /// The items provided by the overall dependency graph, i.e., the items that are not consumed by
  /// any node, plus the providers of additional outbound items
  const std::set<Item> & outbound_items() const { return _out_items; }

  /// Start nodes, i.e., nodes that do not consume items provided by any other node
  const std::set<Node *> & start_nodes() const { return _start_nodes; }

  /// The items consumed by the overall dependency graph, i.e., the consumed items for which no
  /// provider was found
  const std::set<Item> & inbound_items() const { return _in_items; }

  /// Whether each consumed item must have a unique provider (default true). If true, resolve()
  /// throws when multiple providers are found; otherwise the first provider found is used.
  bool & unique_item_provider() { return _unique_item_provider; }

  /// Whether each provided item must have a unique consumer (default false). If true, resolve()
  /// throws when multiple consumers are found; otherwise the first consumer found is recorded.
  bool & unique_item_consumer() { return _unique_item_consumer; }

private:
  /// Rebuild the item- and node-level graphs and the derived start/end node and in/out item sets
  void build_graph();

  /// Depth-first visit: resolve the providers of a node, then append the node to the resolution
  void resolve(Node *);

  /// See unique_item_provider()
  bool _unique_item_provider = true;

  /// See unique_item_consumer()
  bool _unique_item_consumer = false;

  /// All registered nodes
  std::set<Node *> _nodes;

  /// All consumed items, including additional outbound items (which have a nullptr parent)
  std::set<Item> _consumed_items;

  /// All provided items
  std::set<Item> _provided_items;

  /// See item_providers()
  std::map<Item, std::set<Item>> _item_provider_graph;

  /// See item_consumers()
  std::map<Item, std::set<Item>> _item_consumer_graph;

  /// See node_providers()
  std::map<Node *, std::set<Node *>> _node_provider_graph;

  /// See node_consumers()
  std::map<Node *, std::set<Node *>> _node_consumer_graph;

  /// See end_nodes()
  std::set<Node *> _end_nodes;

  /// See start_nodes()
  std::set<Node *> _start_nodes;

  /// See outbound_items()
  std::set<Item> _out_items;

  /// See inbound_items()
  std::set<Item> _in_items;

  /// See resolution()
  std::vector<Node *> _resolution;

  /// Visiting status used while resolving: 0 = not visited, 1 = being visited, 2 = visited
  std::map<Node *, int> _status;

  /// Node priorities (see set_priority())
  std::map<Node *, std::size_t> _priority;
};
204
205template <typename Node, typename ItemType>
206void
208{
209 if (!node)
210 throw NEMLException("Cannot add a nullptr node to the dependency resolver.");
211
212 _nodes.emplace(node);
213
214 auto * def = dynamic_cast<DependencyDefinition<ItemType> *>(node);
215 if (!def)
216 throw NEMLException("Internal error: Node is not derived from DependencyDefinition.");
217
218 for (const auto & item : def->consumed_items())
219 _consumed_items.emplace(node, item);
220 for (const auto & item : def->provided_items())
221 _provided_items.emplace(node, item);
222}
223
224template <typename Node, typename ItemType>
225void
227{
228 _consumed_items.emplace(nullptr, item);
229}
230
231template <typename Node, typename ItemType>
232void
234{
235 _priority[node] = priority;
236}
237
238template <typename Node, typename ItemType>
239void
240DependencyResolver<Node, ItemType>::build_graph()
241{
242 // Clear the previous graph
243 _item_provider_graph.clear();
244 _item_consumer_graph.clear();
245 _node_provider_graph.clear();
246 _node_consumer_graph.clear();
247 _start_nodes.clear();
248 _end_nodes.clear();
249 _in_items.clear();
250 _out_items.clear();
251
252 // Build the adjacency matrix for item providers and node providers
253 for (const auto & itemi : _consumed_items)
254 {
255 std::vector<Item> providers;
256
257 for (const auto & itemj : _provided_items)
258 {
259 // Match consumer with provider
260 if (itemi.value != itemj.value)
261 continue;
262
263 // No self dependency
264 if (itemi.parent == itemj.parent)
265 continue;
266
267 // Enforce priority
268 if (_priority[itemi.parent] > _priority[itemj.parent])
269 continue;
270
271 providers.push_back(itemj);
272 }
273
274 // If the user asks for unique providers, we should error if multiple providers have been
275 // found. Otherwise, just put the first provider into the graph.
276 if (!providers.empty())
277 {
278 if (_unique_item_provider)
279 if (providers.size() != 1)
280 throw NEMLException("Multiple providers have been found for item " +
281 utils::stringify(itemi.value));
282 _item_provider_graph[itemi].insert(providers[0]);
283 if (itemi.parent)
284 _node_provider_graph[itemi.parent].insert(providers[0].parent);
285 }
286 }
287
288 // Build the adjacency matrix for item consumers
289 for (const auto & itemi : _provided_items)
290 {
291 std::vector<Item> consumers;
292
293 for (const auto & itemj : _consumed_items)
294 {
295 // Skip additional outbound item
296 if (!itemj.parent)
297 continue;
298
299 // Match provider with consumer
300 if (itemi.value != itemj.value)
301 continue;
302
303 // No self dependency
304 if (itemi.parent == itemj.parent)
305 continue;
306
307 // Enforce priority
308 if (_priority[itemi.parent] < _priority[itemj.parent])
309 continue;
310
311 consumers.push_back(itemj);
312 }
313
314 // If the user asks for unique consumers, we should error if multiple consumers have been
315 // found. Otherwise, just put the first consumer into the graph.
316 if (!consumers.empty())
317 {
318 if (_unique_item_consumer)
319 if (consumers.size() != 1)
320 throw NEMLException("Multiple consumers have been found for item " +
321 utils::stringify(itemi.value));
322 _item_consumer_graph[itemi].insert(consumers[0]);
323 _node_consumer_graph[itemi.parent].insert(consumers[0].parent);
324 }
325 }
326
327 // Find start nodes
328 for (const auto & node : _nodes)
329 if (_node_provider_graph.count(node) == 0)
330 _start_nodes.insert(node);
331
332 // Find end nodes
333 for (const auto & node : _nodes)
334 if (_node_consumer_graph.count(node) == 0)
335 _end_nodes.insert(node);
336
337 // Find inbound items
338 for (const auto & item : _consumed_items)
339 if (_item_provider_graph.count(item) == 0)
340 _in_items.insert(item);
341
342 // Find outbound items
343 for (const auto & item : _provided_items)
344 if (_item_consumer_graph.count(item) == 0)
345 _out_items.insert(item);
346
347 // Additional outbound items
348 for (const auto & item : _consumed_items)
349 if (!item.parent)
350 {
351 if (!_item_provider_graph.count(item))
352 throw NEMLException("Unable to find provider of the additional outbound item " +
353 utils::stringify(item.value));
354 for (const auto & provider : _item_provider_graph[item])
355 {
356 _out_items.insert(provider);
357 _end_nodes.insert(provider.parent);
358 }
359 }
360}
361
362template <typename Node, typename ItemType>
363void
365{
366 build_graph();
367
368 _status.clear();
369 _resolution.clear();
370 for (const auto & node : _end_nodes)
371 if (!_status[node])
372 resolve(node);
373
374 // Make sure each node appears in the resolution once and only once
375 for (const auto & node : _nodes)
376 {
377 auto count = std::count(_resolution.begin(), _resolution.end(), node);
378 if (count == 0)
379 throw NEMLException(
380 "Each node must appear in the dependency resolution. Node " + node->name() +
381 " is missing. This is an internal error -- consider filing a bug report.");
382 if (count > 1)
383 throw NEMLException(
384 "Each node must appear in the dependency resolution once and only once. Node " +
385 node->name() + " appeared " + std::to_string(count) +
386 " times. This indicates cyclic dependency.");
387 }
388}
389
390template <typename Node, typename ItemType>
391void
393{
394 // Mark the current node as visiting (so that we know there is circular dependency when a
395 // "visiting" node is visited again).
396 _status[node] += 1;
397
398 // Recurse for all the dependent nodes
399 if (_node_provider_graph.count(node))
400 for (const auto & dep : _node_provider_graph[node])
401 {
402 // The dependent node must either be "not visited" or "visited".
403 // If the dependent node is "being visited", there must be cyclic dependency.
404 if (_status[dep] == 1)
405 throw NEMLException(
406 "While resolving dependency, two nodes '" + node->name() + "' and '" + dep->name() +
407 "' have (possibly indirect) cyclic dependency. The cyclic dependency can be "
408 "resolved by explicitly setting the node priorities.");
409
410 if (!_status[dep])
411 resolve(dep);
412 }
413
414 // At this point, all the dependent nodes must have been pushed into the resolution. It is
415 // therefore safe to push the current node into the resolution.
416 _resolution.push_back(node);
417
418 // Finished visiting this node
419 _status[node] += 1;
420}
421} // namespace neml2
Definition DependencyDefinition.h:40
const std::map< Node *, std::set< Node * > > & node_consumers() const
Definition DependencyResolver.h:125
const std::map< Node *, std::set< Node * > > & node_providers() const
Definition DependencyResolver.h:119
const std::set< Item > & inbound_items() const
The items consumed by the overall dependency graph, i.e., the items that are not provided by any node...
Definition DependencyResolver.h:137
void set_priority(Node *, size_t)
Set a node's priority, useful for resolving cyclic dependency.
Definition DependencyResolver.h:233
const std::map< Item, std::set< Item > > & item_consumers() const
Definition DependencyResolver.h:113
void add_additional_outbound_item(const ItemType &item)
Add an additional outbound item that the dependency graph provides.
Definition DependencyResolver.h:226
bool & unique_item_provider()
Definition DependencyResolver.h:140
void resolve()
Resolve nodal dependency and find an evaluation order.
Definition DependencyResolver.h:364
const std::set< Node * > & end_nodes() const
End nodes, i.e., nodes whose provided items are not consumed by any other node.
Definition DependencyResolver.h:128
const std::vector< Node * > & resolution() const
The resolved (nodal) evaluation order, following which all consumed items of the current node are provided by nodes evaluated before it.
Definition DependencyResolver.h:101
const std::set< Item > & outbound_items() const
The items provided by the overall dependency graph, i.e., the items that are not consumed by any node...
Definition DependencyResolver.h:131
const std::set< Node * > & start_nodes() const
Start nodes, i.e., nodes that do not consume items provided by any other node.
Definition DependencyResolver.h:134
void add_node(Node *)
Add a node (defining consumed/provided items) in the dependency graph.
Definition DependencyResolver.h:207
const std::map< Item, std::set< Item > > & item_providers() const
Definition DependencyResolver.h:107
bool & unique_item_consumer()
Definition DependencyResolver.h:143
Definition errors.h:34
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.h:31
Node *const parent
Node which defines this item.
Definition DependencyResolver.h:62
bool operator<(const Item &other) const
An arbitrary comparator so that items can be sorted (for consistency).
Definition DependencyResolver.h:80
bool operator==(const Item &other) const
Test for equality between two items.
Definition DependencyResolver.h:68
bool operator!=(const Item &other) const
Test for inequality between two items.
Definition DependencyResolver.h:74
const ItemType value
The consumed/provided item.
Definition DependencyResolver.h:65
Item(Node *const node, ItemType item)
Definition DependencyResolver.h:55