Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Create predecessors() and successors() on ir.Node #2022

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
50 changes: 45 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Hashable,
Iterable,
Iterator,
NamedTuple,
OrderedDict,
Sequence,
SupportsInt,
Expand Down Expand Up @@ -1055,6 +1056,18 @@
return f'"{string}"'


class Usage(NamedTuple):
"""A usage of a value in a node.

Attributes:
node: The node that uses the value.
idx: The input index of the value in the node.
"""

node: Node
idx: int


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
"""IR Node.

Expand Down Expand Up @@ -1293,6 +1306,25 @@
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated, in a determinsitic order."""

Check warning on line 1310 in onnxscript/ir/_core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "determinsitic" is a misspelling of "deterministic" Raw Output: ./onnxscript/ir/_core.py:1310:72: "determinsitic" is a misspelling of "deterministic"
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
# Use the ordered nature of a dictionary to deduplicate the nodes
predecessors: dict[Node, None] = {}
for value in self.inputs:
if value is not None and (producer := value.producer()) is not None:
predecessors[producer] = None

Check warning on line 1315 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1315

Added line #L1315 was not covered by tests
return tuple(predecessors)
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def successors(self) -> Sequence[Node]:
"""Return the successor nodes of the node, deduplicated, in a determinsitic order."""

Check warning on line 1319 in onnxscript/ir/_core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "determinsitic" is a misspelling of "deterministic" Raw Output: ./onnxscript/ir/_core.py:1319:70: "determinsitic" is a misspelling of "deterministic"
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
# Use the ordered nature of a dictionary to deduplicate the nodes
successors: dict[Node, None] = {}

Check warning on line 1321 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1321

Added line #L1321 was not covered by tests
for value in self.outputs:
assert value is not None, "Bug: Output values are not expected to be None"

Check warning on line 1323 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1323

Added line #L1323 was not covered by tests
for usage in value.uses():
successors[usage.node] = None
return tuple(successors)

Check warning on line 1326 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1325-L1326

Added lines #L1325 - L1326 were not covered by tests
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
Expand Down Expand Up @@ -1564,7 +1596,7 @@
# Use a collection of (Node, int) to store uses. This is needed
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
Expand Down Expand Up @@ -1595,31 +1627,39 @@
"""
return self._producer

def consumers(self) -> Sequence[Node]:
"""Return the nodes (deduplicated) that consume this value."""
return tuple({usage.node: None for usage in self._uses})

Check warning on line 1632 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1632

Added line #L1632 was not covered by tests
justinchuby marked this conversation as resolved.
Show resolved Hide resolved

def index(self) -> int | None:
"""The index of the output of the defining node."""
return self._index

def uses(self) -> Collection[tuple[Node, int]]:
def uses(self) -> Collection[Usage]:
"""Return a set of uses of the value.

The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
"""
return self._uses.keys()
# Create a tuple for the collection so that iteration on will will not
# be affected when the usage changes during graph mutation.
# This addes a small overhead but is better a user experience than
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
# having users call tuple().
return tuple(self._uses)

def _add_usage(self, use: Node, index: int) -> None:
"""Add a usage of this value.

This is an internal method. It should only be called by the Node class.
"""
self._uses[(use, index)] = None
self._uses[Usage(use, index)] = None

def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.

This is an internal method. It should only be called by the Node class.
"""
self._uses.pop((use, index))
self._uses.pop(Usage(use, index))

@property
def name(self) -> str | None:
Expand Down
5 changes: 5 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,8 @@ def setUp(self) -> None:
self.v0 = _core.Value()
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
self.node_a = _core.Node("test", "TestOpA", inputs=self.node.outputs)
self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.node), int)
Expand Down Expand Up @@ -807,6 +809,9 @@ def test_it_is_added_to_a_graph_if_specified(self):
)
self.assertIn(self.node, graph)

def test_predecessors(self):
self.assertEqual(self.node.predecessors(), ())

# TODO(justinchuby): Test all methods


Expand Down
Loading