Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix assertions for certain precompiles #4451

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion tests/functional/builtins/codegen/test_ecrecover.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import contextlib

from eth_account import Account
from eth_account._utils.signing import to_bytes32

from tests.utils import ZERO_ADDRESS
from tests.utils import ZERO_ADDRESS, check_precompile_asserts
from vyper.compiler.settings import OptimizationLevel


def test_ecrecover_test(get_contract):
Expand Down Expand Up @@ -86,3 +89,40 @@ def test_ecrecover() -> bool:
"""
c = get_contract(code)
assert c.test_ecrecover() is True


def test_ecrecover_oog_handling(env, get_contract, tx_failed, optimize, experimental_codegen):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def do_ecrecover(hash: bytes32, v: uint256, r:uint256, s:uint256) -> address:
return ecrecover(hash, v, r, s)
"""
check_precompile_asserts(code)

c = get_contract(code)

h = b"\x35" * 32
local_account = Account.from_key(b"\x46" * 32)
sig = local_account.signHash(h)
v, r, s = sig.v, sig.r, sig.s

assert c.do_ecrecover(h, v, r, s) == local_account.address

gas_used = env.last_result.gas_used

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# if optimizations are off, enough gas is used by the contract
# that the gas provided to ecrecover (63/64ths rule) is enough
# for it to succeed
ctx = contextlib.nullcontext
else:
# in other cases, the gas forwarded is small enough for ecrecover
# to fail with oog, which we handle by reverting.
ctx = tx_failed

with ctx():
# provide enough spare gas for the top-level call to not oog but
# not enough for ecrecover to succeed
c.do_ecrecover(h, v, r, s, gas=gas_used)
60 changes: 59 additions & 1 deletion tests/functional/codegen/types/test_dynamic_array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import itertools
from typing import Any, Callable

import pytest

from tests.utils import decimal_to_int
from tests.utils import check_precompile_asserts, decimal_to_int
from vyper.compiler import compile_code
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
ArgumentException,
ArrayIndexException,
Expand Down Expand Up @@ -1901,3 +1903,59 @@ def foo():
c = get_contract(code)
with tx_failed():
c.foo()


def test_dynarray_copy_oog(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """

@external
def foo(a: DynArray[uint256, 4000]) -> uint256:
b: DynArray[uint256, 4000] = a
return b[0]
"""
check_precompile_asserts(code)

c = get_contract(code)
dynarray = [2] * 4000
assert c.foo(dynarray) == 2

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(dynarray, gas=gas_used)


def test_dynarray_copy_oog2(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000], y: String[1000000]) -> DynArray[String[1000000], 2]:
z: DynArray[String[1000000], 2] = [x, y]
# Some code
return z
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata0 = "a" * 10
calldata1 = "b" * 1000000
assert c.foo(calldata0, calldata1) == [calldata0, calldata1]

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata0, calldata1, gas=gas_used)
71 changes: 70 additions & 1 deletion tests/functional/codegen/types/test_lists.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextlib
import itertools

import pytest

from tests.utils import decimal_to_int
from tests.utils import check_precompile_asserts, decimal_to_int
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.opcodes import version_check
from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch


Expand Down Expand Up @@ -848,3 +851,69 @@ def foo() -> {return_type}:
return MY_CONSTANT[0][0]
"""
assert_compile_failed(lambda: get_contract(code), TypeMismatch)


def test_array_copy_oog(env, get_contract, tx_failed, optimize, experimental_codegen, request):
# GHSA-vgf2-gvx8-xwc3
code = """
@internal
def bar(x: uint256[3000]) -> uint256[3000]:
a: uint256[3000] = x
return a

@external
def foo(x: uint256[3000]) -> uint256:
s: uint256[3000] = self.bar(x)
return s[0]
"""
check_precompile_asserts(code)

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# fails in get_contract due to code too large
request.node.add_marker(pytest.mark.xfail(strict=True))
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

c = get_contract(code)
array = [2] * 3000
assert c.foo(array) == array[0]

# get the minimum gas for the contract complete execution
gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed
with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(array, gas=gas_used)


def test_array_copy_oog2(env, get_contract, tx_failed, optimize, experimental_codegen, request):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
def foo(x: uint256[2500]) -> uint256:
s: uint256[2500] = x
t: uint256[2500] = s
return t[0]
"""
check_precompile_asserts(code)

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# fails in get_contract due to code too large
request.node.add_marker(pytest.mark.xfail(strict=True))
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

c = get_contract(code)
array = [2] * 2500
assert c.foo(array) == array[0]

# get the minimum gas for the contract complete execution
gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed
with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(array, gas=gas_used)
58 changes: 58 additions & 0 deletions tests/functional/codegen/types/test_string.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import contextlib

import pytest

from tests.utils import check_precompile_asserts
from vyper.evm.opcodes import version_check


def test_string_return(get_contract):
code = """
Expand Down Expand Up @@ -359,3 +364,56 @@ def compare_var_storage_not_equal_false() -> bool:
assert c.compare_var_storage_equal_false() is False
assert c.compare_var_storage_not_equal_true() is True
assert c.compare_var_storage_not_equal_false() is False


def test_string_copy_oog(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000]) -> String[1000000]:
return x
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata = "a" * 1000000
assert c.foo(calldata) == calldata

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata, gas=gas_used)


def test_string_copy_oog2(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000]) -> uint256:
y: String[1000000] = x
return len(y)
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata = "a" * 1000000
assert c.foo(calldata) == len(calldata)

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata, gas=gas_used)
16 changes: 16 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

from vyper import ast as vy_ast
from vyper.compiler.phases import CompilerData
from vyper.semantics.analysis.constant_folding import constant_fold
from vyper.utils import DECIMAL_EPSILON, round_towards_zero

Expand All @@ -28,3 +29,18 @@ def parse_and_fold(source_code):
def decimal_to_int(*args):
s = decimal.Decimal(*args)
return round_towards_zero(s / DECIMAL_EPSILON)


def check_precompile_asserts(source_code):
# check deploy IR (which contains runtime IR)
ir_node = CompilerData(source_code).ir_nodes

def _check(ir_node, parent=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it would be better to assert that one of the nodes is staticcall and assert that it calls a precompile address

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and once we establish this is true, assert that the parent is assert

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm i guess this was motivated by making this general even for cancun. but still i'd change the code and handle this on callsite

if ir_node.value == "staticcall":
precompile_addr = ir_node.args[1]
if isinstance(precompile_addr.value, int) and precompile_addr.value < 10:
assert parent is not None and parent.value == "assert"
for arg in ir_node.args:
_check(arg, ir_node)

_check(ir_node)
2 changes: 1 addition & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def build_IR(self, expr, args, kwargs, context):
["mstore", add_ofst(input_buf, 32), args[1]],
["mstore", add_ofst(input_buf, 64), args[2]],
["mstore", add_ofst(input_buf, 96), args[3]],
["staticcall", "gas", 1, input_buf, 128, output_buf, 32],
["assert", ["staticcall", "gas", 1, input_buf, 128, output_buf, 32]],
["mload", output_buf],
],
typ=AddressT(),
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def copy_bytes(dst, src, length, length_bound):
copy_op = ["mcopy", dst, src, length]
gas_bound = _mcopy_gas_bound(length_bound)
else:
copy_op = ["staticcall", "gas", 4, src, length, dst, length]
copy_op = ["assert", ["staticcall", "gas", 4, src, length, dst, length]]
gas_bound = _identity_gas_bound(length_bound)
elif src.location == CALLDATA:
copy_op = ["calldatacopy", dst, src, length]
Expand Down
Loading