Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Create predecessors() and successors() on ir.Node #2022

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
50 changes: 45 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Hashable,
Iterable,
Iterator,
NamedTuple,
OrderedDict,
Sequence,
SupportsInt,
Expand Down Expand Up @@ -1055,6 +1056,18 @@ def _quoted(string: str) -> str:
return f'"{string}"'


class Usage(NamedTuple):
"""A usage of a value in a node.

Attributes:
node: The node that uses the value.
idx: The input index of the value in the node.
"""

node: Node
idx: int


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
"""IR Node.

Expand Down Expand Up @@ -1293,6 +1306,25 @@ def inputs(self, _: Any) -> None:
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
predecessors: dict[Node, None] = {}
for value in self.inputs:
if value is not None and (producer := value.producer()) is not None:
predecessors[producer] = None
return tuple(predecessors)
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def successors(self) -> Sequence[Node]:
"""Return the successor nodes of the node, deduplicated, in a deterministic order."""
# Use the ordered nature of a dictionary to deduplicate the nodes
successors: dict[Node, None] = {}
for value in self.outputs:
assert value is not None, "Bug: Output values are not expected to be None"
for usage in value.uses():
successors[usage.node] = None
return tuple(successors)
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
Expand Down Expand Up @@ -1564,7 +1596,7 @@ def __init__(
# Use a collection of (Node, int) to store uses. This is needed
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
Expand Down Expand Up @@ -1595,31 +1627,39 @@ def producer(self) -> Node | None:
"""
return self._producer

def consumers(self) -> Sequence[Node]:
"""Return the nodes (deduplicated) that consume this value."""
return tuple({usage.node: None for usage in self._uses})
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def index(self) -> int | None:
"""The index of the output of the defining node."""
return self._index

def uses(self) -> Collection[tuple[Node, int]]:
def uses(self) -> Collection[Usage]:
"""Return a set of uses of the value.

The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
"""
return self._uses.keys()
# Create a tuple for the collection so that iteration on will will not
# be affected when the usage changes during graph mutation.
# This adds a small overhead but is better a user experience than
# having users call tuple().
return tuple(self._uses)

def _add_usage(self, use: Node, index: int) -> None:
"""Add a usage of this value.

This is an internal method. It should only be called by the Node class.
"""
self._uses[(use, index)] = None
self._uses[Usage(use, index)] = None

def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.

This is an internal method. It should only be called by the Node class.
"""
self._uses.pop((use, index))
self._uses.pop(Usage(use, index))

@property
def name(self) -> str | None:
Expand Down
48 changes: 44 additions & 4 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self):


class ValueTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value(name="v0")
self.v1 = _core.Value(name="v1")
self.node = _core.Node(
"test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2
)

def test_initialize(self):
_ = _core.Value()

Expand All @@ -732,14 +739,30 @@ def test_meta(self):
value.metadata_props["test"] = "any string"
self.assertEqual(value.metadata_props["test"], "any string")

def test_producer(self):
self.assertEqual(self.v0.producer(), None)
self.assertEqual(self.v1.producer(), None)
self.assertEqual(self.node.outputs[0].producer(), self.node)
self.assertEqual(self.node.outputs[1].producer(), self.node)

def test_consumers(self):
self.assertEqual(self.v0.consumers(), (self.node,))
self.assertEqual(self.v1.consumers(), (self.node,))
self.assertEqual(self.node.outputs[0].consumers(), ())
self.assertEqual(self.node.outputs[1].consumers(), ())

# TODO(justinchuby): Test all methods


class NodeTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value()
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
self.v0 = _core.Value(name="v0")
self.v1 = _core.Value(name="v1")
self.node = _core.Node(
"test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3
)
self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]])
self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.node), int)
Expand All @@ -748,7 +771,7 @@ def test_it_is_hashable(self):
def test_init_with_values(self):
self.assertEqual(self.node.domain, "test")
self.assertEqual(self.node.op_type, "TestOp")
self.assertEqual(self.node.inputs, (self.v0, self.v1))
self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1))
self.assertEqual(len(self.node.outputs), 3)
self.assertEqual(self.node.attributes, {})

Expand Down Expand Up @@ -807,6 +830,23 @@ def test_it_is_added_to_a_graph_if_specified(self):
)
self.assertIn(self.node, graph)

def test_predecessors(self):
self.assertEqual(self.node.predecessors(), ())
self.assertEqual(self.node_a.predecessors(), (self.node,))
self.assertEqual(self.node_b.predecessors(), (self.node,))

def test_predecessors_are_unique(self):
# node_b has three inputs from node, but only one predecessor
self.assertEqual(self.node_b.predecessors(), self.node_a.predecessors())

def test_successors(self):
self.assertEqual(self.node.successors(), (self.node_a, self.node_b))
self.assertEqual(self.node_a.successors(), ())
self.assertEqual(self.node_b.successors(), ())

def test_successors_are_unique(self):
self.assertEqual(self.node.successors(), (self.node_a, self.node_b))

# TODO(justinchuby): Test all methods


Expand Down
Loading