From 4c84dcf0ebef9b198269e9db2cef9881e026ee00 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 10:22:28 -0800 Subject: [PATCH 01/15] [IR] Create `predecessors()` and `successors()` on `ir.Node` --- onnxscript/ir/_core.py | 47 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index faffde748..7bd0980c6 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -30,6 +30,7 @@ Hashable, Iterable, Iterator, + NamedTuple, OrderedDict, Sequence, SupportsInt, @@ -1055,6 +1056,19 @@ def _quoted(string: str) -> str: return f'"{string}"' + +class Usage(NamedTuple): + """A usage of a value in a node. + + Attributes: + node: The node that uses the value. + index: The index of the value in the node. + """ + + node: Node + index: int + + class Node(_protocols.NodeProtocol, _display.PrettyPrintable): """IR Node. @@ -1293,6 +1307,31 @@ def inputs(self, _: Any) -> None: "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." ) + def predecessors(self) -> Sequence[Node]: + """Return the predecessor nodes of the node, deduplicated.""" + predecessors = [] + seen = set() + for value in self.inputs: + if value is not None and (producer := value.producer()) is not None: + if producer in seen: + continue + seen.add(producer) + predecessors.append(producer) + return predecessors + + def successors(self) -> Sequence[Node]: + """Return the successor nodes of the node, deduplicated.""" + successors = [] + seen = set() + for value in self.outputs: + assert value is not None, "Bug: Output values are not expected to be None" + for usage in value.uses(): + if usage.node in seen: + continue + seen.add(usage.node) + successors.append(usage.node) + return successors + def replace_input_with(self, index: int, value: Value | None) -> None: """Replace an input with a new value.""" if index < 0 or index >= len(self.inputs): @@ -1564,7 +1603,7 @@ def __init__( # Use a collection of (Node, int) to store uses. This is needed # because a single use can use the same value multiple times. # Use a dictionary to preserve insertion order so that the visiting order is deterministic - self._uses: dict[tuple[Node, int], None] = {} + self._uses: dict[Usage, None] = {} self.doc_string = doc_string def __repr__(self) -> str: @@ -1599,7 +1638,7 @@ def index(self) -> int | None: """The index of the output of the defining node.""" return self._index - def uses(self) -> Collection[tuple[Node, int]]: + def uses(self) -> Collection[Usage]: """Return a set of uses of the value. The set contains tuples of ``(Node, index)`` where the index is the index of the input @@ -1612,14 +1651,14 @@ def _add_usage(self, use: Node, index: int) -> None: This is an internal method. It should only be called by the Node class. """ - self._uses[(use, index)] = None + self._uses[Usage(use, index)] = None def _remove_usage(self, use: Node, index: int) -> None: """Remove a node from the uses of this value. This is an internal method. It should only be called by the Node class. """ - self._uses.pop((use, index)) + self._uses.pop(Usage(use, index)) @property def name(self) -> str | None: From 8e80f584c9ba79b70c3467be4d2ccb2307ae7850 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 10:22:41 -0800 Subject: [PATCH 02/15] lint --- onnxscript/ir/_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 7bd0980c6..2e3d31e9d 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1056,7 +1056,6 @@ def _quoted(string: str) -> str: return f'"{string}"' - class Usage(NamedTuple): """A usage of a value in a node. From 3ae78829c3e705efe2585ce72dbeed1c4cfb13b9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 10:51:05 -0800 Subject: [PATCH 03/15] Update onnxscript/ir/_core.py --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 2e3d31e9d..4aab3b0f2 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1061,7 +1061,7 @@ class Usage(NamedTuple): Attributes: node: The node that uses the value. - index: The index of the value in the node. + index: The input index of the value in the node. """ node: Node From 17931962f6842c871fc9f954892f97643b31d1a0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 11:00:43 -0800 Subject: [PATCH 04/15] update --- onnxscript/ir/_core.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 4aab3b0f2..727a9e5ca 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1307,29 +1307,23 @@ def inputs(self, _: Any) -> None: ) def predecessors(self) -> Sequence[Node]: - """Return the predecessor nodes of the node, deduplicated.""" - predecessors = [] - seen = set() + """Return the predecessor nodes of the node, deduplicated, in a determinsitic order.""" + # Use the ordered nature of a dictionary to deduplicate the nodes + predecessors = {} for value in self.inputs: if value is not None and (producer := value.producer()) is not None: - if producer in seen: - continue - seen.add(producer) - predecessors.append(producer) - return predecessors + predecessors[producer] = None + return tuple(predecessors) def successors(self) -> Sequence[Node]: - """Return the successor nodes of the node, deduplicated.""" - successors = [] - seen = set() + """Return the successor nodes of the node, deduplicated, in a determinsitic order.""" + # Use the ordered nature of a dictionary to deduplicate the nodes + successors = {} for value in self.outputs: assert value is not None, "Bug: Output values are not expected to be None" for usage in value.uses(): - if usage.node in seen: - continue - seen.add(usage.node) - successors.append(usage.node) - return successors + successors[usage.node] = None + return tuple(successors) def replace_input_with(self, index: int, value: Value | None) -> None: """Replace an input with a new value.""" From 77132687eb06f95ad7bb378f8cc87ad7bec4fe91 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 17:21:06 -0800 Subject: [PATCH 05/15] test --- onnxscript/ir/_core_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 8662a8c01..d5eba6733 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -740,6 +740,8 @@ def setUp(self) -> None: self.v0 = _core.Value() self.v1 = _core.Value() self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) + self.node_a = _core.Node("test", "TestOpA", inputs=self.node.outputs) + self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs) def test_it_is_hashable(self): self.assertIsInstance(hash(self.node), int) @@ -807,6 +809,9 @@ def test_it_is_added_to_a_graph_if_specified(self): ) self.assertIn(self.node, graph) + def test_predecessors(self): + self.assertEqual(self.node.predecessors(), ()) + # TODO(justinchuby): Test all methods From c30fc531b19f69fe987e330f29eb1952fc08a57a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 17:25:06 -0800 Subject: [PATCH 06/15] uses --- onnxscript/ir/_core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 727a9e5ca..164955ad7 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1637,7 +1637,11 @@ def uses(self) -> Collection[Usage]: The set contains tuples of ``(Node, index)`` where the index is the index of the input of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``. """ - return self._uses.keys() + # Create a tuple for the collection so that iteration on will will not + # be affected when the usage changes during graph mutation. + # This addes a small overhead but is better a user experience than + # having users call tuple(). + return tuple(self._uses.keys()) def _add_usage(self, use: Node, index: int) -> None: """Add a usage of this value. From ca7c1f9f1f7ee3dac9ee09ee392e30d81f6f4016 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 17:37:02 -0800 Subject: [PATCH 07/15] Add consumers --- onnxscript/ir/_core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 164955ad7..56046d130 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1627,6 +1627,10 @@ def producer(self) -> Node | None: """ return self._producer + def consumers(self) -> Sequence[Node]: + """Return the nodes (deduplicated) that consume this value.""" + return tuple({usage.node: None for usage in self._uses}) + def index(self) -> int | None: """The index of the output of the defining node.""" return self._index @@ -1641,7 +1645,7 @@ def uses(self) -> Collection[Usage]: # be affected when the usage changes during graph mutation. # This addes a small overhead but is better a user experience than # having users call tuple(). - return tuple(self._uses.keys()) + return tuple(self._uses) def _add_usage(self, use: Node, index: int) -> None: """Add a usage of this value. From 54348e144c60213d6ea12a902536daf01c3f662f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 17:41:09 -0800 Subject: [PATCH 08/15] type --- onnxscript/ir/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 56046d130..f5fb93e87 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1309,7 +1309,7 @@ def inputs(self, _: Any) -> None: def predecessors(self) -> Sequence[Node]: """Return the predecessor nodes of the node, deduplicated, in a determinsitic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes - predecessors = {} + predecessors: dict[Node, None] = {} for value in self.inputs: if value is not None and (producer := value.producer()) is not None: predecessors[producer] = None @@ -1318,7 +1318,7 @@ def predecessors(self) -> Sequence[Node]: def successors(self) -> Sequence[Node]: """Return the successor nodes of the node, deduplicated, in a determinsitic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes - successors = {} + successors: dict[Node, None] = {} for value in self.outputs: assert value is not None, "Bug: Output values are not expected to be None" for usage in value.uses(): From e20898967c3a874a8edcbcb9b6b7dd1118c3ac0a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:22:46 -0800 Subject: [PATCH 09/15] usage --- onnxscript/ir/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index f5fb93e87..2dc775c68 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1061,11 +1061,11 @@ class Usage(NamedTuple): Attributes: node: The node that uses the value. - index: The input index of the value in the node. + idx: The input index of the value in the node. """ node: Node - index: int + idx: int class Node(_protocols.NodeProtocol, _display.PrettyPrintable): From a1f1de613b68570623f7b2a39d32612b088943d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:31:26 -0800 Subject: [PATCH 10/15] Update onnxscript/ir/_core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 2dc775c68..06e5652a4 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1307,7 +1307,7 @@ def inputs(self, _: Any) -> None: ) def predecessors(self) -> Sequence[Node]: - """Return the predecessor nodes of the node, deduplicated, in a determinsitic order.""" + """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes predecessors: dict[Node, None] = {} for value in self.inputs: From c7d16c5f3f09784eb1be2a50a1ffe4635b3c5a98 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:31:33 -0800 Subject: [PATCH 11/15] Update onnxscript/ir/_core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 06e5652a4..243e5fc2b 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1316,7 +1316,7 @@ def predecessors(self) -> Sequence[Node]: return tuple(predecessors) def successors(self) -> Sequence[Node]: - """Return the successor nodes of the node, deduplicated, in a determinsitic order.""" + """Return the successor nodes of the node, deduplicated, in a deterministic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes successors: dict[Node, None] = {} for value in self.outputs: From 85e1585441b1d4ac0c784d823f8109ee31688a92 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:32:16 -0800 Subject: [PATCH 12/15] Update onnxscript/ir/_core.py --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 243e5fc2b..14d07cb9f 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1643,7 +1643,7 @@ def uses(self) -> Collection[Usage]: """ # Create a tuple for the collection so that iteration on will will not # be affected when the usage changes during graph mutation. - # This addes a small overhead but is better a user experience than + # This adds a small overhead but is better a user experience than # having users call tuple(). return tuple(self._uses) From 4bcfc746ce363704fd7ddc5331e847d4cee9b07e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:36:51 -0800 Subject: [PATCH 13/15] test --- onnxscript/ir/_core_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index d5eba6733..5fc4e9ac2 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -740,7 +740,7 @@ def setUp(self) -> None: self.v0 = _core.Value() self.v1 = _core.Value() self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) - self.node_a = _core.Node("test", "TestOpA", inputs=self.node.outputs) + self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]]) self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs) def test_it_is_hashable(self): @@ -811,6 +811,20 @@ def test_it_is_added_to_a_graph_if_specified(self): def test_predecessors(self): self.assertEqual(self.node.predecessors(), ()) + self.assertEqual(self.node_a.predecessors(), (self.node,)) + self.assertEqual(self.node_b.predecessors(), (self.node,)) + + def test_predecessors_are_unique(self): + # node_b has three inputs from node, but only one predecessor + self.assertEqual(self.node_b.predecessors(), self.node_a.predecessors()) + + def test_successors(self): + self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) + self.assertEqual(self.node_a.successors(), ()) + self.assertEqual(self.node_b.successors(), ()) + + def test_successors_are_unique(self): + self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) # TODO(justinchuby): Test all methods From b9f76cc6866cc674da4b3affd33dd61b4c51db8a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:44:11 -0800 Subject: [PATCH 14/15] test --- onnxscript/ir/_core_test.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 5fc4e9ac2..b6e1049c5 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self): class ValueTest(unittest.TestCase): + def setUp(self) -> None: + self.v0 = _core.Value(name="v0") + self.v1 = _core.Value(name="v1") + self.node = _core.Node( + "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2 + ) + def test_initialize(self): _ = _core.Value() @@ -732,14 +739,28 @@ def test_meta(self): value.metadata_props["test"] = "any string" self.assertEqual(value.metadata_props["test"], "any string") + def test_producer(self): + self.assertEqual(self.v0.producer(), None) + self.assertEqual(self.v1.producer(), None) + self.assertEqual(self.node.outputs[0].producer(), self.node) + self.assertEqual(self.node.outputs[1].producer(), self.node) + + def test_consumers(self): + self.assertEqual(self.v0.consumers(), (self.node,)) + self.assertEqual(self.v1.consumers(), (self.node,)) + self.assertEqual(self.node.outputs[0].consumers(), ()) + self.assertEqual(self.node.outputs[1].consumers(), ()) + # TODO(justinchuby): Test all methods class NodeTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = _core.Value() - self.v1 = _core.Value() - self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) + self.v0 = _core.Value(name="v0") + self.v1 = _core.Value(name="v1") + self.node = _core.Node( + "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3 + ) self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]]) self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs) @@ -750,7 +771,7 @@ def test_it_is_hashable(self): def test_init_with_values(self): self.assertEqual(self.node.domain, "test") self.assertEqual(self.node.op_type, "TestOp") - self.assertEqual(self.node.inputs, (self.v0, self.v1)) + self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1)) self.assertEqual(len(self.node.outputs), 3) self.assertEqual(self.node.attributes, {}) From 185b64dd6940f7f30af852704f5112b72c4509b1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 18 Jan 2025 19:54:50 -0800 Subject: [PATCH 15/15] test_predecessors_are_unique --- onnxscript/ir/_core_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index b6e1049c5..9b6cc94f6 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -837,7 +837,7 @@ def test_predecessors(self): def test_predecessors_are_unique(self): # node_b has three inputs from node, but only one predecessor - self.assertEqual(self.node_b.predecessors(), self.node_a.predecessors()) + self.assertEqual(self.node_b.predecessors(), (self.node,)) def test_successors(self): self.assertEqual(self.node.successors(), (self.node_a, self.node_b))