Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[venom]: memmerge bytecodesize #4422

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions tests/unit/compiler/venom/test_memmerging.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,32 @@ def _check_no_change(pre):
_check_pre_post(pre, pre)


# Checks that a load which reads memory written earlier in the same block
# only keeps the conflicting load/store pair un-merged, while the remaining
# adjacent pairs are still fused into a single mcopy.
def test_memmerging_tmp():
# the expected output contains mcopy; when version_check(begin="cancun")
# is false the test returns early and therefore passes without checking
# anything (presumably because mcopy needs Cancun -- TODO confirm)
if not version_check(begin="cancun"):
return

# input IR: three mload/mstore pairs. The third load reads address 448,
# which the first pair has just written.
pre = """
main:
%1 = mload 352
mstore 448, %1
%2 = mload 416
mstore 64, %2
%3 = mload 448 ; barrier, flushes mload 416 from list of potential copies
mstore 96, %3
stop
"""

# expected IR: the 352 -> 448 pair is left as is; the 416 -> 64 and
# 448 -> 96 pairs (two adjacent 32-byte words) become one 64-byte
# mcopy to 64 from 416.
post = """
main:
%1 = mload 352
mstore 448, %1
mcopy 64, 416, 64
stop
"""

# _check_pre_post is defined outside the visible part of this file;
# presumably it runs the pass on `pre` and compares against `post`.
_check_pre_post(pre, post)


# for parametrizing tests
LOAD_COPY = [("dload", "dloadbytes"), ("calldataload", "calldatacopy")]

Expand Down Expand Up @@ -627,11 +653,20 @@ def test_memmerging_write_after_write():
%3 = mload 32
%4 = mload 132
mstore 1000, %1
mstore 1000, %2 ; BARRIER
mstore 1000, %2 ; partial BARRIER
mstore 1032, %4
mstore 1032, %3 ; BARRIER
"""
_check_no_change(pre)

post = """
_global:
%1 = mload 0
%3 = mload 32
mstore 1000, %1
mcopy 1000, 100, 64
mstore 1032, %3 ; BARRIER
"""
_check_pre_post(pre, post)


def test_memmerging_write_after_write_mstore_and_mcopy():
Expand All @@ -645,9 +680,9 @@ def test_memmerging_write_after_write_mstore_and_mcopy():
pre = """
_global:
%1 = mload 0
%2 = mload 132
%2 = mload 32
mstore 1000, %1
mcopy 1000, 100, 16 ; write barrier
mcopy 1000, 100, 64 ; write barrier
mstore 1032, %2
mcopy 1016, 116, 64
stop
Expand Down
92 changes: 58 additions & 34 deletions vyper/venom/passes/memmerging.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@
self.analyses_cache.invalidate_analysis(DFGAnalysis)
self.analyses_cache.invalidate_analysis(LivenessAnalysis)

def _optimize_copy(self, bb: IRBasicBlock, copy_opcode: str, load_opcode: str):
for copy in self._copies:
def _optimize_copy(
self, bb: IRBasicBlock, copies: list[_Copy], copy_opcode: str, load_opcode: str
):
for copy in copies:
copy.insts.sort(key=bb.instructions.index)

if copy_opcode == "mcopy":
Expand Down Expand Up @@ -157,34 +159,29 @@

bb.mark_for_removal(inst)

self._copies.clear()
self._loads.clear()
# self._copies.clear()
# self._loads.clear()

def _write_after_write_hazard(self, new_copy: _Copy) -> bool:
def _write_after_write_hazard(self, new_copy: _Copy) -> list[_Copy]:
res = []
for copy in self._copies:
# note, these are the same:
# - new_copy.overwrites(copy.dst_interval())
# - copy.overwrites(new_copy.dst_interval())
if new_copy.overwrites(copy.dst_interval()) and not (
copy.can_merge(new_copy) or new_copy.can_merge(copy)
):
return True
return False
res.append(copy)

def _read_after_write_hazard(self, new_copy: _Copy) -> bool:
new_copies = self._copies + [new_copy]
return res

# new copy would overwrite memory that
# needs to be read to optimize copy
if any(new_copy.overwrites(copy.src_interval()) for copy in new_copies):
return True
def _read_after_write_hazard(self, new_copy: _Copy) -> list[_Copy]:
# new_copies = self._copies + [new_copy]

# existing copies would overwrite memory that the
# new copy would need
if self._overwrites(new_copy.src_interval()):
return True

return False
res = [copy for copy in self._copies if new_copy.overwrites(copy.src_interval())] + [
copy for copy in self._copies if copy.overwrites(new_copy.src_interval())
]
return res

# Return the index in self._copies at which new_copy would be inserted,
# found with bisect_left keyed on each copy's dst. For an existing entry
# with the same dst, bisect_left yields the position before it.
# NOTE(review): bisect_left only gives a meaningful answer if self._copies
# is kept sorted by dst -- confirm that _add_copy maintains this ordering.
def _find_insertion_point(self, new_copy: _Copy):
return bisect_left(self._copies, new_copy.dst, key=lambda c: c.dst)
Expand Down Expand Up @@ -218,7 +215,24 @@
self._copies = []

def _barrier():
self._optimize_copy(bb, copy_opcode, load_opcode)
self._optimize_copy(bb, self._copies, copy_opcode, load_opcode)
self._copies.clear()
self._loads.clear()

# Partial barrier for a read of `interval`: select every pending copy in
# self._copies for which c.overwrites(interval) is true and flush only
# those through _barrier_for; the other pending copies stay queued.
def _load_barrier(interval: _Interval):
copies = [c for c in self._copies if c.overwrites(interval)]
_barrier_for(copies)

# Flush just the given copies instead of the whole pending list: run
# _optimize_copy on them, drop them from self._copies, and forget the
# load results that fed them.
def _barrier_for(copies: list[_Copy]):
# process only this subset; self._copies as a whole is not cleared here
self._optimize_copy(bb, copies, copy_opcode, load_opcode)
for c in copies:
# guarded removal: the same copy can appear more than once in `copies`
# (_read_after_write_hazard concatenates two filtered lists).
# NOTE(review): such a duplicate is still handed to _optimize_copy
# twice above -- confirm that this is harmless.
if c in self._copies:
self._copies.remove(c)
# drop the tracked load outputs belonging to this copy from self._loads
for inst in c.insts:
if inst.opcode == load_opcode:
# presumably narrows the type of inst.output for the type checker
assert isinstance(inst.output, IRVariable)
if inst.output in self._loads:
del self._loads[inst.output]

# copy is necessary because there is a possibility
# of insertion in optimizations
Expand All @@ -233,7 +247,7 @@

# we will read from this memory, so we need to put a barrier
if not allow_dst_overlaps_src and self._overwrites(read_interval):
_barrier()
_load_barrier(read_interval)

assert inst.output is not None
self._loads[inst.output] = src_op.value
Expand All @@ -254,16 +268,21 @@
assert load_inst is not None # help mypy
n_copy = _Copy(dst.value, src_ptr, 32, [inst, load_inst])

if self._write_after_write_hazard(n_copy):
_barrier()
write_hazards = self._write_after_write_hazard(n_copy)
if len(write_hazards) > 0:
_barrier_for(write_hazards)
# no continue needed, we have not invalidated the loads dict

# check that the new copy does not overwrite existing data
if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy):
_barrier()
# this continue is necessary because we have invalidated
# the _loads dict, so src_ptr is no longer valid.
continue
if not allow_dst_overlaps_src:
read_hazards = self._read_after_write_hazard(n_copy)
if len(read_hazards) > 0:
_barrier_for(read_hazards)
# this continue is necessary because we have invalidated
# the _loads dict, so src_ptr is no longer valid.
continue
if n_copy.overwrites_self_src():
continue

Check warning on line 285 in vyper/venom/passes/memmerging.py

View check run for this annotation

Codecov / codecov/patch

vyper/venom/passes/memmerging.py#L285

Added line #L285 was not covered by tests
self._add_copy(n_copy)

elif inst.opcode == copy_opcode:
Expand All @@ -274,11 +293,16 @@
length, src, dst = inst.operands
n_copy = _Copy(dst.value, src.value, length.value, [inst])

if self._write_after_write_hazard(n_copy):
_barrier()
write_hazards = self._write_after_write_hazard(n_copy)
if len(write_hazards) > 0:
_barrier_for(write_hazards)
# check that the new copy does not overwrite existing data
if not allow_dst_overlaps_src and self._read_after_write_hazard(n_copy):
_barrier()
if not allow_dst_overlaps_src:
read_hazards = self._read_after_write_hazard(n_copy)
if len(read_hazards) > 0:
_barrier_for(read_hazards)
if n_copy.overwrites_self_src():
continue

Check warning on line 305 in vyper/venom/passes/memmerging.py

View check run for this annotation

Codecov / codecov/patch

vyper/venom/passes/memmerging.py#L305

Added line #L305 was not covered by tests
self._add_copy(n_copy)

elif _volatile_memory(inst):
Expand Down Expand Up @@ -327,7 +351,7 @@
_barrier()
continue
n_copy = _Copy.memzero(dst.value, 32, [inst])
assert not self._write_after_write_hazard(n_copy)
assert len(self._write_after_write_hazard(n_copy)) == 0
self._add_copy(n_copy)
elif inst.opcode == "calldatacopy":
length, var, dst = inst.operands
Expand All @@ -343,7 +367,7 @@
_barrier()
continue
n_copy = _Copy.memzero(dst.value, length.value, [inst])
assert not self._write_after_write_hazard(n_copy)
assert len(self._write_after_write_hazard(n_copy)) == 0
self._add_copy(n_copy)
elif _volatile_memory(inst):
_barrier()
Expand Down
Loading