Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Create predecessors() and successors() on ir.Node #2022

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
46 changes: 42 additions & 4 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Hashable,
Iterable,
Iterator,
NamedTuple,
OrderedDict,
Sequence,
SupportsInt,
Expand Down Expand Up @@ -1055,6 +1056,18 @@
return f'"{string}"'


class Usage(NamedTuple):
"""A usage of a value in a node.

Attributes:
node: The node that uses the value.
index: The input index of the value in the node.
"""

node: Node
index: int
Fixed Show fixed Hide fixed


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
"""IR Node.

Expand Down Expand Up @@ -1293,6 +1306,31 @@
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def predecessors(self) -> Sequence[Node]:
"""Return the predecessor nodes of the node, deduplicated."""
predecessors = []
seen = set()

Check warning on line 1312 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1311-L1312

Added lines #L1311 - L1312 were not covered by tests
for value in self.inputs:
if value is not None and (producer := value.producer()) is not None:
if producer in seen:
continue
seen.add(producer)
predecessors.append(producer)
return predecessors

Check warning on line 1319 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1316-L1319

Added lines #L1316 - L1319 were not covered by tests

def successors(self) -> Sequence[Node]:
"""Return the successor nodes of the node, deduplicated."""
successors = []
seen = set()

Check warning on line 1324 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1323-L1324

Added lines #L1323 - L1324 were not covered by tests
for value in self.outputs:
assert value is not None, "Bug: Output values are not expected to be None"

Check warning on line 1326 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1326

Added line #L1326 was not covered by tests
for usage in value.uses():
if usage.node in seen:
continue
seen.add(usage.node)
successors.append(usage.node)
return successors

Check warning on line 1332 in onnxscript/ir/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_core.py#L1329-L1332

Added lines #L1329 - L1332 were not covered by tests

def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
Expand Down Expand Up @@ -1564,7 +1602,7 @@
# Use a collection of (Node, int) to store uses. This is needed
# because a single use can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
Expand Down Expand Up @@ -1599,7 +1637,7 @@
"""The index of the output of the defining node."""
return self._index

def uses(self) -> Collection[tuple[Node, int]]:
def uses(self) -> Collection[Usage]:
"""Return a set of uses of the value.

The set contains tuples of ``(Node, index)`` where the index is the index of the input
Expand All @@ -1612,14 +1650,14 @@

This is an internal method. It should only be called by the Node class.
"""
self._uses[(use, index)] = None
self._uses[Usage(use, index)] = None

def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.

This is an internal method. It should only be called by the Node class.
"""
self._uses.pop((use, index))
self._uses.pop(Usage(use, index))

@property
def name(self) -> str | None:
Expand Down
Loading