31#include "neml2/models/DependencyDefinition.h"
32#include "neml2/misc/string_utils.h"
33#include "neml2/misc/errors.h"
45template <
typename Node,
typename ItemType>
55 Item(Node *
const node, ItemType item)
57 value(std::move(item))
101 const std::vector<Node *> &
resolution()
const {
return _resolution; }
107 const std::map<Item, std::set<Item>> &
item_providers()
const {
return _item_provider_graph; }
113 const std::map<Item, std::set<Item>> &
item_consumers()
const {
return _item_consumer_graph; }
119 const std::map<Node *, std::set<Node *>> &
node_providers()
const {
return _node_provider_graph; }
125 const std::map<Node *, std::set<Node *>> &
node_consumers()
const {
return _node_consumer_graph; }
128 const std::set<Node *> &
end_nodes()
const {
return _end_nodes; }
134 const std::set<Node *> &
start_nodes()
const {
return _start_nodes; }
153 bool _unique_item_provider =
true;
156 bool _unique_item_consumer =
false;
159 std::set<Node *> _nodes;
162 std::set<Item> _consumed_items;
165 std::set<Item> _provided_items;
169 std::map<Item, std::set<Item>> _item_provider_graph;
173 std::map<Item, std::set<Item>> _item_consumer_graph;
177 std::map<Node *, std::set<Node *>> _node_provider_graph;
181 std::map<Node *, std::set<Node *>> _node_consumer_graph;
184 std::set<Node *> _end_nodes;
187 std::set<Node *> _start_nodes;
190 std::set<Item> _out_items;
193 std::set<Item> _in_items;
196 std::vector<Node *> _resolution;
199 std::map<Node *, int> _status;
202 std::map<Node *, size_t> _priority;
205template <
typename Node,
typename ItemType>
210 throw NEMLException(
"Cannot add a nullptr node to the dependency resolver.");
212 _nodes.emplace(node);
216 throw NEMLException(
"Internal error: Node is not derived from DependencyDefinition.");
218 for (
const auto & item : def->consumed_items())
219 _consumed_items.emplace(node, item);
220 for (
const auto & item : def->provided_items())
221 _provided_items.emplace(node, item);
224template <
typename Node,
typename ItemType>
228 _consumed_items.emplace(
nullptr, item);
231template <
typename Node,
typename ItemType>
235 _priority[node] = priority;
238template <
typename Node,
typename ItemType>
240DependencyResolver<Node, ItemType>::build_graph()
243 _item_provider_graph.clear();
244 _item_consumer_graph.clear();
245 _node_provider_graph.clear();
246 _node_consumer_graph.clear();
247 _start_nodes.clear();
253 for (
const auto & itemi : _consumed_items)
255 std::vector<Item> providers;
257 for (
const auto & itemj : _provided_items)
260 if (itemi.value != itemj.value)
264 if (itemi.parent == itemj.parent)
268 if (_priority[itemi.parent] > _priority[itemj.parent])
271 providers.push_back(itemj);
276 if (!providers.empty())
278 if (_unique_item_provider)
279 if (providers.size() != 1)
280 throw NEMLException(
"Multiple providers have been found for item " +
282 _item_provider_graph[itemi].insert(providers[0]);
284 _node_provider_graph[itemi.parent].insert(providers[0].parent);
289 for (
const auto & itemi : _provided_items)
291 std::vector<Item> consumers;
293 for (
const auto & itemj : _consumed_items)
300 if (itemi.value != itemj.value)
304 if (itemi.parent == itemj.parent)
308 if (_priority[itemi.parent] < _priority[itemj.parent])
311 consumers.push_back(itemj);
316 if (!consumers.empty())
318 if (_unique_item_consumer)
319 if (consumers.size() != 1)
320 throw NEMLException(
"Multiple consumers have been found for item " +
322 _item_consumer_graph[itemi].insert(consumers[0]);
323 _node_consumer_graph[itemi.parent].insert(consumers[0].parent);
328 for (
const auto & node : _nodes)
329 if (_node_provider_graph.count(node) == 0)
330 _start_nodes.insert(node);
333 for (
const auto & node : _nodes)
334 if (_node_consumer_graph.count(node) == 0)
335 _end_nodes.insert(node);
338 for (
const auto & item : _consumed_items)
339 if (_item_provider_graph.count(item) == 0)
340 _in_items.insert(item);
343 for (
const auto & item : _provided_items)
344 if (_item_consumer_graph.count(item) == 0)
345 _out_items.insert(item);
348 for (
const auto & item : _consumed_items)
351 if (!_item_provider_graph.count(item))
352 throw NEMLException(
"Unable to find provider of the additional outbound item " +
354 for (
const auto & provider : _item_provider_graph[item])
356 _out_items.insert(provider);
357 _end_nodes.insert(provider.parent);
362template <
typename Node,
typename ItemType>
370 for (
const auto & node : _end_nodes)
375 for (
const auto & node : _nodes)
377 auto count = std::count(_resolution.begin(), _resolution.end(), node);
380 "Each node must appear in the dependency resolution. Node " + node->name() +
381 " is missing. This is an internal error -- consider filing a bug report.");
384 "Each node must appear in the dependency resolution once and only once. Node " +
385 node->name() +
" appeared " + std::to_string(count) +
386 " times. This indicates cyclic dependency.");
390template <
typename Node,
typename ItemType>
399 if (_node_provider_graph.count(node))
400 for (
const auto & dep : _node_provider_graph[node])
404 if (_status[dep] == 1)
406 "While resolving dependency, two nodes '" + node->name() +
"' and '" + dep->name() +
407 "' have (possibly indirect) cyclic dependency. The cyclic dependency can be "
408 "resolved by explicitly setting the node priorities.");
416 _resolution.push_back(node);
Definition DependencyDefinition.h:40
const std::map< Node *, std::set< Node * > > & node_consumers() const
Definition DependencyResolver.h:125
const std::map< Node *, std::set< Node * > > & node_providers() const
Definition DependencyResolver.h:119
const std::set< Item > & inbound_items() const
The items consumed by the overall dependency graph, i.e., the items that are not provided by any node...
Definition DependencyResolver.h:137
void set_priority(Node *, size_t)
Set a node's priority, useful for resolving cyclic dependency.
Definition DependencyResolver.h:233
const std::map< Item, std::set< Item > > & item_consumers() const
Definition DependencyResolver.h:113
DependencyResolver()=default
void add_additional_outbound_item(const ItemType &item)
Add an additional outbound item that the dependency graph provides.
Definition DependencyResolver.h:226
bool & unique_item_provider()
Definition DependencyResolver.h:140
void resolve()
Resolve nodal dependency and find an evaluation order.
Definition DependencyResolver.h:364
const std::set< Node * > & end_nodes() const
End nodes which are not consumed by anyone else.
Definition DependencyResolver.h:128
const std::vector< Node * > & resolution() const
The resolved (nodal) evaluation order following which all consumed items of the current node.
Definition DependencyResolver.h:101
const std::set< Item > & outbound_items() const
The items provided by the overall dependency graph, i.e., the items that are not consumed by any node...
Definition DependencyResolver.h:131
const std::set< Node * > & start_nodes() const
Start nodes which do not consume anyone else.
Definition DependencyResolver.h:134
void add_node(Node *)
Add a node (defining consumed/provided items) in the dependency graph.
Definition DependencyResolver.h:207
const std::map< Item, std::set< Item > > & item_providers() const
Definition DependencyResolver.h:107
bool & unique_item_consumer()
Definition DependencyResolver.h:143
std::string stringify(const T &t)
Definition string_utils.h:70
Definition DiagnosticsInterface.h:31
Node *const parent
Node which defines this item.
Definition DependencyResolver.h:62
bool operator<(const Item &other) const
An arbitrary comparator so that items can be sorted (for consistency).
Definition DependencyResolver.h:80
bool operator==(const Item &other) const
Test for equality between two items.
Definition DependencyResolver.h:68
bool operator!=(const Item &other) const
Test for inequality between two items.
Definition DependencyResolver.h:74
const ItemType value
The consumed/provided item.
Definition DependencyResolver.h:65
Item(Node *const node, ItemType item)
Definition DependencyResolver.h:55