Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute topological sort with dynamic key #145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions pytools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
.. autofunction:: compute_sccs
.. autoclass:: CycleError
.. autofunction:: compute_topological_order
.. autofunction:: compute_topological_order_with_dynamic_key
.. autofunction:: compute_transitive_closure
.. autofunction:: contains_cycle
.. autofunction:: compute_induced_subgraph
.. autoclass:: TopologicalOrderState

Type Variables Used
-------------------
Expand All @@ -46,7 +48,8 @@
"""

from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable,
Set, MutableSet, Dict, Iterator, Tuple)
Set, MutableSet, Dict, Iterator, Tuple, Generic)
from dataclasses import dataclass


T = TypeVar("T")
Expand Down Expand Up @@ -207,18 +210,44 @@ def __lt__(self, other):
return self.key < other.key


def compute_topological_order(graph: Mapping[T, Iterable[T]],
key: Optional[Callable[[T], Any]] = None) -> List[T]:
"""Compute a topological order of nodes in a directed graph.
@dataclass(frozen=True)
class TopologicalOrderState(Generic[T]):
"""
.. attribute:: scheduled_nodes

A :class:`list` of nodes that have been scheduled.

.. warning::

- Mutable updates to :attr:`scheduled_nodes`
results in an undefined behavior.
"""
scheduled_nodes: List[T]


def compute_topological_order_with_dynamic_key(
graph: Mapping[T, Iterable[T]],
trigger_key_update: Callable[[TopologicalOrderState[T]], bool],
get_key: Callable[[TopologicalOrderState[T]], Callable[[T], Any]]
) -> List[T]:
"""
Computes a topological order of nodes in a directed graph with support for
a dynamic keying function.

:arg graph: A :class:`collections.abc.Mapping` representing a directed
graph. The dictionary contains one key representing each node in the
graph, and this key maps to a :class:`collections.abc.Iterable` of its
successor nodes.

:arg key: A custom key function may be supplied to determine the order in
break-even cases. Expects a function of one argument that is used to
extract a comparison key from each node of the *graph*.
:arg trigger_key_update: A function called after scheduling a node in
*graph* that takes in an instance of :class:`TopologicalOrderState`
corresponding to the scheduling state at that point and returns whether
the comparison keys corresponding to the nodes be updated.

:arg get_key: A callable called when *trigger_key_update*
returns *True*. Takes in an instance of :class:`TopologicalOrderState`
and returns another callable that accepts node as an argument and returns the
comparison key corresponding to the node.
Comment on lines +242 to +250
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this be a single callable that either performs the update or says "I've not done anything"?


:returns: A :class:`list` representing a valid topological ordering of the
nodes in the directed graph.
Expand All @@ -228,10 +257,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
* Requires the keys of the mapping *graph* to be hashable.
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.

.. versionadded:: 2020.2
.. versionadded:: 2022.2
"""
# all nodes have the same keys when not provided
keyfunc = key if key is not None else (lambda x: 0)

from heapq import heapify, heappop, heappush

Expand All @@ -248,6 +275,8 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],

# }}}

keyfunc = get_key(TopologicalOrderState(scheduled_nodes=[]))

total_num_nodes = len(nodes_to_num_predecessors)

# heap: list of instances of HeapEntry(n) where 'n' is a node in
Expand All @@ -263,6 +292,14 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],
node_to_be_scheduled = heappop(heap).node
order.append(node_to_be_scheduled)

state = TopologicalOrderState(scheduled_nodes=order)

if trigger_key_update(state):
keyfunc = get_key(state)
heap = [HeapEntry(entry.node, keyfunc(entry.node))
for entry in heap]
heapify(heap)

# discard 'node_to_be_scheduled' from the predecessors of its
# successors since it's been scheduled
for child in graph.get(node_to_be_scheduled, ()):
Expand All @@ -277,6 +314,38 @@ def compute_topological_order(graph: Mapping[T, Iterable[T]],

return order


def compute_topological_order(graph: Mapping[T, Iterable[T]],
key: Optional[Callable[[T], Any]] = None) -> List[T]:
"""Compute a topological order of nodes in a directed graph.

:arg graph: A :class:`collections.abc.Mapping` representing a directed
graph. The dictionary contains one key representing each node in the
graph, and this key maps to a :class:`collections.abc.Iterable` of its
successor nodes.

:arg key: A custom key function may be supplied to determine the order in
break-even cases. Expects a function of one argument that is used to
extract a comparison key from each node of the *graph*.

:returns: A :class:`list` representing a valid topological ordering of the
nodes in the directed graph.

.. note::

* Requires the keys of the mapping *graph* to be hashable.
* Implements `Kahn's algorithm <https://w.wiki/YDy>`__.

.. versionadded:: 2020.2
"""
# all nodes have the same keys when not provided
keyfunc = key if key is not None else (lambda x: 0)

return compute_topological_order_with_dynamic_key(
graph,
trigger_key_update=lambda _: False,
get_key=lambda _: keyfunc)

# }}}


Expand Down
2 changes: 1 addition & 1 deletion pytools/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (2022, 1, 12)
VERSION = (2022, 2)
VERSION_STATUS = ""
VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS
44 changes: 44 additions & 0 deletions test/test_pytools.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,50 @@ class Eq3(Tag):
assert hash(eq1) != hash(eq3)


def test_compute_topological_order_with_dynamic_key():
from pytools.graph import compute_topological_order_with_dynamic_key

dag = {"A": {"C"},
"B": {"E"},
"C": {"D"},
"D": set(),
"E": set(),
}

colors = {"A": "red",
"B": "red",
"C": "blue",
"D": "red",
"E": "blue"}

# {{{ set a dynamic key to greedily schedule continuous chunks of blue/red nodes

def trigger_key_update(state):
if len(state.scheduled_nodes) == 1:
# initially we preferred blue.
return colors[state.scheduled_nodes[0]] == "red"
else:
return (colors[state.scheduled_nodes[-1]]
!= colors[state.scheduled_nodes[-2]])

def get_key(state):
if len(state.scheduled_nodes) == 0:
# initial state => prefer blue.
return lambda x: (colors[x] != "blue",
x)
else:
return lambda x: (colors[x] != colors[state.scheduled_nodes[-1]],
x)

# }}}

sorted_nodes = compute_topological_order_with_dynamic_key(
dag,
trigger_key_update, get_key)

assert sorted_nodes == ["A", "B", "C", "E", "D"]


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down