From df252d7f6e2b369d3fae6c40cbf4113fbeb8e56f Mon Sep 17 00:00:00 2001 From: Evgeny Danienko <6655321@bk.ru> Date: Fri, 4 Nov 2022 16:58:51 +0400 Subject: [PATCH 1/6] mutation tests --- .github/mut_blacklist | 6 ++++++ .github/mut_config.yml | 7 +++++++ .github/workflows/main.yml | 21 +++++++++++++++++++++ Makefile | 9 ++++++++- core/ibft_test.go | 28 ++++++++++++++++++++++++++++ deps/dummy.go | 5 +++++ 6 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 .github/mut_blacklist create mode 100644 .github/mut_config.yml create mode 100644 deps/dummy.go diff --git a/.github/mut_blacklist b/.github/mut_blacklist new file mode 100644 index 0000000..2fd62f8 --- /dev/null +++ b/.github/mut_blacklist @@ -0,0 +1,6 @@ +54f1ea08e7395cf6768c102d5677d764 +bc6bebd9df1b01a984c0ffab3c50b8de +d7ce4ca9bca24c65de535f680084fdeb +00a067ff2145e46f285ddf7f0f842788 +cebc341f1ea73b8dc8907be3bb97a1a0 +d18e7ef6485ad32dd1b4935777f433e0 diff --git a/.github/mut_config.yml b/.github/mut_config.yml new file mode 100644 index 0000000..242d572 --- /dev/null +++ b/.github/mut_config.yml @@ -0,0 +1,7 @@ +skip_without_test: false +skip_with_build_tags: false +json_output: true +silent_mode: false +min_msi: 0.89 # should be >0.95 +exclude_dirs: + - messages/proto \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 42cc55f..b95b1fb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -46,3 +46,24 @@ jobs: - name: Run Go Test run: go test -race -shuffle=on -timeout 28m ./... + + mutating: + name: Mutation tests + runs-on: ubuntu-latest + if: false # skipped, should be restored and fixed + steps: + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 1.18.x + + - name: Checkout code + uses: actions/checkout@v3 + with: + submodules: recursive + + - name: Install dependencies + run: make install-deps + + - name: Mutating testing + run: make mut diff --git a/Makefile b/Makefile index 4d0761e..34dbca6 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,13 @@ -.PHONY: lint +.PHONY: lint mut lint: golangci-lint run -E whitespace -E wsl -E wastedassign -E unconvert -E tparallel -E thelper -E stylecheck -E prealloc \ -E predeclared -E nlreturn -E misspell -E makezero -E lll -E importas -E ifshort -E gosec -E gofmt -E goconst \ -E forcetypeassert -E dogsled -E dupl -E errname -E errorlint -E nolintlint --timeout 2m +install-deps: + go get github.com/JekaMas/go-mutesting/cmd/go-mutesting@v1.1.1 + go install github.com/JekaMas/go-mutesting/... + +mut: + MUTATION_TEST=on go-mutesting --blacklist=".github/mut_blacklist" --config=".github/mut_config.yml" ./... + @echo MSI: `jq '.stats.msi' report.json` diff --git a/core/ibft_test.go b/core/ibft_test.go index 9d9dbf6..1ed2e9b 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "os" "sync" "testing" "time" @@ -2320,3 +2321,30 @@ func TestIBFT_ExtendRoundTimer(t *testing.T) { // Make sure the round timeout was extended assert.Equal(t, additionalTimeout, i.additionalTimeout) } + +// A dummy test that is needed only to run all tests with mutations +func TestDummyForMutations(t *testing.T) { + t.Parallel() + + if len(os.Getenv("MUTATION_TEST")) == 0 { + t.SkipNow() + } + + tests := map[string]func(t *testing.T){ + "TestConsensus_ValidFlow": TestConsensus_ValidFlow, + "TestConsensus_InvalidBlock": TestConsensus_InvalidBlock, + "TestDropMaxFaultyPlusOne": TestDropMaxFaultyPlusOne, + "TestDropMaxFaulty": TestDropMaxFaulty, + + // property bases tests + //"TestProperty_MajorityHonestNodes": TestProperty_MajorityHonestNodes, + } + + for name, test := range tests { + t.Parallel() + + test := test + + t.Run(name, test) + } +} diff --git a/deps/dummy.go b/deps/dummy.go new file mode 100644 index 0000000..e67dd3e --- /dev/null +++ b/deps/dummy.go @@ -0,0 +1,5 @@ +package deps + +import ( + _ "github.com/JekaMas/go-mutesting" +) From 96be8bf756063c9288e98b4cd12cecc86a801d18 Mon Sep 17 00:00:00 2001 From: Evgeny Danienko <6655321@bk.ru> Date: Fri, 4 Nov 2022 16:59:10 +0400 Subject: [PATCH 2/6] update deps --- go.mod | 4 ++++ go.sum | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c479761..f3d8d84 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/0xPolygon/go-ibft go 1.18 require ( + github.com/JekaMas/go-mutesting v1.1.1 github.com/google/uuid v1.3.0 github.com/stretchr/testify v1.8.1 go.uber.org/goleak v1.2.0 @@ -14,5 +15,8 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect + golang.org/x/tools v0.1.12 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9565209..948c206 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/JekaMas/go-mutesting v1.1.1 h1:jdgkCgMRxFXvluWGTrkSKjdBPeEczkYQoUM5P4MfaOw= +github.com/JekaMas/go-mutesting v1.1.1/go.mod h1:4MvW+K744lDEkfvAP0hazhPthWMcnsQur7QhLebcb7A= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -22,7 +24,12 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= -golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= From 78ca14effc6e1795d6209a98b915f7b77cc7e3f7 Mon Sep 17 00:00:00 2001 From: Evgeny Danienko <6655321@bk.ru> Date: Fri, 11 Nov 2022 13:06:30 +0400 Subject: [PATCH 3/6] mut config --- .github/mut_config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/mut_config.yml b/.github/mut_config.yml index 242d572..ac82b69 100644 --- a/.github/mut_config.yml +++ b/.github/mut_config.yml @@ -2,6 +2,6 @@ skip_without_test: false skip_with_build_tags: false json_output: true silent_mode: false -min_msi: 0.89 # should be >0.95 +min_msi: 0.90 # should be >0.95 exclude_dirs: - - messages/proto \ No newline at end of file + - messages/proto From 00affabfdcf78b94e2593d7e8c4a8022693c3058 Mon Sep 17 00:00:00 2001 From: Evgeny Danienko <6655321@bk.ru> Date: Fri, 11 Nov 2022 13:08:47 +0400 Subject: [PATCH 4/6] rearrange --- .github/workflows/main.yml | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ffde71c..304a76c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -47,8 +47,10 @@ jobs: - name: Run Go Test with race run: go test -race -shuffle=on -timeout 2m ./... - reproducible-builds: + mutating: + name: Mutation tests runs-on: ubuntu-latest + if: false # skipped, should be restored and fixed steps: - name: Install Go uses: actions/setup-go@v3 @@ -60,19 +62,14 @@ jobs: with: submodules: recursive - - name: Reproducible build test - run: | - make builds-dummy - shasum -a256 ./build/ibft1 | cut -d " " -f 1 > ibft1.sha256 - shasum -a256 ./build/ibft2 | cut -d " " -f 1 > ibft2.sha256 - if ! cmp ibft1.sha256 ibft2.sha256; then - echo >&2 "Reproducible build broken"; cat ibft1.sha256; cat ibft2.sha256; exit 1 - fi + - name: Install dependencies + run: make install-deps - mutating: - name: Mutation tests + - name: Mutating testing + run: make mut + + reproducible-builds: runs-on: ubuntu-latest - if: false # skipped, should be restored and fixed steps: - name: Install Go uses: actions/setup-go@v3 @@ -84,8 +81,11 @@ jobs: with: submodules: recursive - - name: Install dependencies - run: make install-deps - - - name: Mutating testing - run: make mut + - name: Reproducible build test + run: | + make builds-dummy + shasum -a256 ./build/ibft1 | cut -d " " -f 1 > ibft1.sha256 + shasum -a256 ./build/ibft2 | cut -d " " -f 1 > ibft2.sha256 + if ! cmp ibft1.sha256 ibft2.sha256; then + echo >&2 "Reproducible build broken"; cat ibft1.sha256; cat ibft2.sha256; exit 1 + fi From da4fb4f1a52624f31f6212597ed1c25877a20c3e Mon Sep 17 00:00:00 2001 From: Evgeny Danienko <6655321@bk.ru> Date: Fri, 11 Nov 2022 13:11:48 +0400 Subject: [PATCH 5/6] fix --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 9dfda06..1ab1163 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint-all: install-deps: go get github.com/JekaMas/go-mutesting/cmd/go-mutesting@v1.1.1 - go install github.com/JekaMas/go-mutesting/... + go install github.com/JekaMas/go-mutesting/... curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b ./build/bin v1.50.1 mut: From 428332ec5bdc056b1408beff36eaec2299af85fb Mon Sep 17 00:00:00 2001 From: kourin Date: Fri, 20 Jan 2023 19:45:15 +0900 Subject: [PATCH 6/6] Add test cases to improve mutation score indicator (#60) * Minor fixes in ibft and tests (#45) * Add rapid test with bad proposal coming from byzantine node (#44) * Implement event generator for rapid testing (#46) * EVM-220 TestClusterBlockSync/BLS fails in voting power branch (#48) * Added per round event-based setup in rapid tests (#47) * Remove redundant changeState (#49) * Fix Wrong Round Value in Validation of roundsAndPreparedBlockHashes (#51) Fix round in roundsAndP reparedBlockHashes * Audit improvements (#50) * Audit improvements * Add unit tests for EventManager to improve MSI * Add unit tests for message helper to improve MSI * Byzantine tests (#56) * Byzantine tests * Add unit tests for Messages to improve MSI * Fix lint error * fix lint errors only for the codes that changed in the PR * fixed some stuck test * Revert disabling function-length check in golangci * Revert "Revert disabling function-length check in golangci" This reverts commit 415633e8c011bf5481448ecdf9f6b31ee1118a62. Co-authored-by: Roman Behma <13855864+begmaroman@users.noreply.github.com> Co-authored-by: Igor Crevar Co-authored-by: Vuk Gavrilovic <114920311+trimixlover@users.noreply.github.com> --- .github/workflows/main.yml | 4 +- .golangci.yml | 15 +- core/byzantine_test.go | 459 +++++++++++++++++++++++++++++ core/consensus_test.go | 108 ++----- core/helpers_test.go | 13 +- core/ibft.go | 101 +++++-- core/ibft_test.go | 521 +++++++++++++++++++++++++++------ core/mock_test.go | 121 ++++++-- core/rapid_test.go | 420 +++++++++++++------------- messages/event_manager.go | 6 +- messages/event_manager_test.go | 224 +++++++++++++- messages/helpers.go | 24 +- messages/helpers_test.go | 203 +++++++++++-- messages/messages.go | 4 + messages/messages_test.go | 496 +++++++++++++++++++++++++++---- 15 files changed, 2171 insertions(+), 548 deletions(-) create mode 100644 core/byzantine_test.go diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 304a76c..5e7ad51 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: submodules: recursive - name: Go test - run: go test -covermode=atomic -shuffle=on -coverprofile coverage.out -timeout 2m ./... + run: go test -test.short -covermode=atomic -shuffle=on -coverprofile coverage.out -timeout 10m ./... - name: Upload coverage file to Codecov uses: codecov/codecov-action@v3 @@ -45,7 +45,7 @@ jobs: submodules: recursive - name: Run Go Test with race - run: go test -race -shuffle=on -timeout 2m ./... + run: go test -test.short -race -shuffle=on -timeout 10m ./... mutating: name: Mutation tests diff --git a/.golangci.yml b/.golangci.yml index 7f4683f..ad47552 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -58,12 +58,10 @@ linters: - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - durationcheck # check for two durations multiplied together - errchkjson - gochecknoglobals # check that no global variables exist - goerr113 # Golang linter to check the errors handling expressions - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - promlinter # Check Prometheus metrics naming via promlint @@ -167,7 +165,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#add-constant - name: add-constant severity: warning - disabled: false + disabled: true arguments: - maxLitCount: "3" allowStrs: '""' @@ -206,7 +204,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cognitive-complexity - name: cognitive-complexity severity: warning - disabled: false + disabled: true arguments: [ 7 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#confusing-naming - name: confusing-naming @@ -233,7 +231,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cyclomatic - name: cyclomatic severity: warning - disabled: false + disabled: true arguments: [ 3 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#datarace - name: datarace @@ -310,7 +308,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-length - name: function-length severity: warning - disabled: false + disabled: true arguments: [ 10, 0 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#get-return - name: get-return @@ -495,6 +493,11 @@ issues: - gosec - unparam - lll + - containedctx + - goerr113 + - revive + - gochecknoglobals + - exhaustive include: - EXC0012 # Exported (.+) should have comment( \(or a comment on this block\))? or be unexported - EXC0013 # Package comment should be of the form "(.+)... diff --git a/core/byzantine_test.go b/core/byzantine_test.go new file mode 100644 index 0000000..179b23f --- /dev/null +++ b/core/byzantine_test.go @@ -0,0 +1,459 @@ +package core + +import ( + "bytes" + "github.com/0xPolygon/go-ibft/messages/proto" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestByzantineBehaviour(t *testing.T) { + t.Parallel() + + t.Run("malicious hash in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadHashPrePrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(20*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious hash in prepare", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(c.isProposer) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrepareMessageFn(createBadHashPrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(10*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(10*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadRoundPrePrepareMessageFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + // Max tolerant byzantine + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(40*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadRoundPrePrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad hash in proposal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrePrepareMessageFn(createBadHashPrePrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad hash in prepare", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildPrepareMessageFn(createBadHashPrepareMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) + + t.Run("malicious +1 round in rcc and bad commit seal", func(t *testing.T) { + t.Parallel() + + cluster := newCluster( + 6, + func(c *cluster) { + for _, node := range c.nodes { + currentNode := node + + backendBuilder := mockBackendBuilder{} + backendBuilder.withProposerFn(createForcedRCProposerFn(c)) + backendBuilder.withIDFn(currentNode.addr) + backendBuilder.withBuildCommitMessageFn(createBadCommitMessageFn(currentNode)) + backendBuilder.withBuildRoundChangeMessageFn(createBadRoundRoundChangeFn(currentNode)) + backendBuilder.withHasQuorumFn(c.hasQuorumFn) + + node.core = NewIBFT( + mockLogger{}, + backendBuilder.build(currentNode), + &mockTransport{multicastFn: c.gossip}, + ) + } + }, + ) + + err := cluster.progressToHeight(20*time.Second, 1) + assert.NoError(t, err, "unable to reach height: %w", err) + assert.Equal(t, uint64(1), cluster.latestHeight) + + cluster.makeNByzantine(int(cluster.maxFaulty())) + assert.NoError(t, cluster.progressToHeight(30*time.Second, 2)) + assert.Equal(t, uint64(2), cluster.latestHeight) + }) +} + +func createBadRoundRoundChangeFn(node *node) buildRoundChangeMessageDelegate { + return func(proposal []byte, + rcc *proto.PreparedCertificate, + view *proto.View) *proto.Message { + if node.byzantine { + view.Round++ + } + + return buildBasicRoundChangeMessage( + proposal, + rcc, + view, + node.address, + ) + } +} + +func createBadRoundPrePrepareMessageFn(node *node) buildPrePrepareMessageDelegate { + return func( + proposal []byte, + certificate *proto.RoundChangeCertificate, + view *proto.View, + ) *proto.Message { + if node.byzantine { + view.Round++ + } + + return buildBasicPreprepareMessage( + proposal, + validProposalHash, + certificate, + node.address, + view, + ) + } +} + +func createBadHashPrePrepareMessageFn(node *node) buildPrePrepareMessageDelegate { + return func(proposal []byte, + rcc *proto.RoundChangeCertificate, + view *proto.View) *proto.Message { + proposalHash := validProposalHash + if node.byzantine { + proposalHash = []byte("invalid proposal hash") + } + + return buildBasicPreprepareMessage( + proposal, + proposalHash, + rcc, + node.address, + view, + ) + } +} + +func createBadHashPrepareMessageFn(node *node) buildPrepareMessageDelegate { + return func(_ []byte, view *proto.View) *proto.Message { + proposalHash := validProposalHash + if node.byzantine { + proposalHash = []byte("invalid proposal hash") + } + + return buildBasicPrepareMessage( + proposalHash, + node.address, + view, + ) + } +} + +func createForcedRCProposerFn(c *cluster) isProposerDelegate { + return func(from []byte, height uint64, round uint64) bool { + if round == 0 { + return false + } + + return bytes.Equal( + from, + c.addresses()[int(round)%len(c.addresses())], + ) + } +} + +func createBadCommitMessageFn(node *node) buildCommitMessageDelegate { + return func(_ []byte, view *proto.View) *proto.Message { + committedSeal := validCommittedSeal + if node.byzantine { + committedSeal = []byte("invalid committed seal") + } + + return buildBasicCommitMessage( + validProposalHash, + committedSeal, + node.address, + view, + ) + } +} + +type mockBackendBuilder struct { + isProposerFn isProposerDelegate + + idFn idDelegate + + buildPrePrepareMessageFn buildPrePrepareMessageDelegate + buildPrepareMessageFn buildPrepareMessageDelegate + buildCommitMessageFn buildCommitMessageDelegate + buildRoundChangeMessageFn buildRoundChangeMessageDelegate + + hasQuorumFn hasQuorumDelegate +} + +func (b *mockBackendBuilder) withProposerFn(f isProposerDelegate) { + b.isProposerFn = f +} + +func (b *mockBackendBuilder) withBuildPrePrepareMessageFn(f buildPrePrepareMessageDelegate) { + b.buildPrePrepareMessageFn = f +} + +func (b *mockBackendBuilder) withBuildPrepareMessageFn(f buildPrepareMessageDelegate) { + b.buildPrepareMessageFn = f +} + +func (b *mockBackendBuilder) withBuildCommitMessageFn(f buildCommitMessageDelegate) { + b.buildCommitMessageFn = f +} + +func (b *mockBackendBuilder) withBuildRoundChangeMessageFn(f buildRoundChangeMessageDelegate) { + b.buildRoundChangeMessageFn = f +} + +func (b *mockBackendBuilder) withIDFn(f idDelegate) { + b.idFn = f +} + +func (b *mockBackendBuilder) withHasQuorumFn(f hasQuorumDelegate) { + b.hasQuorumFn = f +} + +func (b *mockBackendBuilder) build(node *node) *mockBackend { + if b.buildPrePrepareMessageFn == nil { + b.buildPrePrepareMessageFn = node.buildPrePrepare + } + + if b.buildPrepareMessageFn == nil { + b.buildPrepareMessageFn = node.buildPrepare + } + + if b.buildCommitMessageFn == nil { + b.buildCommitMessageFn = node.buildCommit + } + + if b.buildRoundChangeMessageFn == nil { + b.buildRoundChangeMessageFn = node.buildRoundChange + } + + return &mockBackend{ + isValidBlockFn: isValidProposal, + isValidProposalHashFn: isValidProposalHash, + isValidSenderFn: nil, + isValidCommittedSealFn: nil, + isProposerFn: b.isProposerFn, + idFn: b.idFn, + + buildProposalFn: buildValidProposal, + buildPrePrepareMessageFn: b.buildPrePrepareMessageFn, + buildPrepareMessageFn: b.buildPrepareMessageFn, + buildCommitMessageFn: b.buildCommitMessageFn, + buildRoundChangeMessageFn: b.buildRoundChangeMessageFn, + insertBlockFn: nil, + hasQuorumFn: b.hasQuorumFn, + } +} diff --git a/core/consensus_test.go b/core/consensus_test.go index 62c55a5..c48bb89 100644 --- a/core/consensus_test.go +++ b/core/consensus_test.go @@ -147,18 +147,17 @@ func commonHasQuorumFn(numNodes uint64) func(blockNumber uint64, messages []*pro func TestConsensus_ValidFlow(t *testing.T) { t.Parallel() - var multicastFn func(message *proto.Message) + var ( + multicastFn func(message *proto.Message) - proposal := []byte("proposal") - proposalHash := []byte("proposal hash") - committedSeal := []byte("seal") - numNodes := uint64(4) - nodes := generateNodeAddresses(numNodes) - insertedBlocks := make([][]byte, numNodes) + numNodes = uint64(4) + nodes = generateNodeAddresses(numNodes) + insertedBlocks = make([][]byte, numNodes) + ) // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, _ int) { transport.multicastFn = func(message *proto.Message) { multicastFn(message) } @@ -182,12 +181,12 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) + return bytes.Equal(newProposal, correctRoundMessage.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) + return bytes.Equal(p, correctRoundMessage.proposal) && bytes.Equal(ph, correctRoundMessage.hash) } // Make sure the preprepare message is built correctly @@ -198,7 +197,7 @@ func TestConsensus_ValidFlow(t *testing.T) { ) *proto.Message { return buildBasicPreprepareMessage( proposal, - proposalHash, + correctRoundMessage.hash, certificate, nodes[nodeIndex], view) @@ -206,12 +205,12 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) + return buildBasicPrepareMessage(correctRoundMessage.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) + return buildBasicCommitMessage(correctRoundMessage.hash, correctRoundMessage.seal, nodes[nodeIndex], view) } // Make sure the round change message is built correctly @@ -227,44 +226,19 @@ func TestConsensus_ValidFlow(t *testing.T) { backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { insertedBlocks[nodeIndex] = proposal } - } - var ( - backendCallbackMap = map[int]backendConfigCallback{ - 0: func(backend *mockBackend) { - // Execute the common backend setup - commonBackendCallback(backend, 0) - - // Set the proposal creation method for node 0, since - // they are the proposer - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal - } - }, - 1: func(backend *mockBackend) { - commonBackendCallback(backend, 1) - }, - 2: func(backend *mockBackend) { - commonBackendCallback(backend, 2) - }, - 3: func(backend *mockBackend) { - commonBackendCallback(backend, 3) - }, - } - transportCallbackMap = map[int]transportConfigCallback{ - 0: commonTransportCallback, - 1: commonTransportCallback, - 2: commonTransportCallback, - 3: commonTransportCallback, + // Set the proposal creation method + backend.buildProposalFn = func(_ *proto.View) []byte { + return correctRoundMessage.proposal } - ) + } // Create the mock cluster cluster := newMockCluster( numNodes, - backendCallbackMap, + commonBackendCallback, nil, - transportCallbackMap, + commonTransportCallback, ) // Set the multicast callback to relay the message @@ -281,7 +255,7 @@ func TestConsensus_ValidFlow(t *testing.T) { // Make sure the inserted blocks match what node 0 proposed for _, block := range insertedBlocks { - assert.True(t, bytes.Equal(block, proposal)) + assert.True(t, bytes.Equal(block, correctRoundMessage.proposal)) } } @@ -316,7 +290,7 @@ func TestConsensus_InvalidBlock(t *testing.T) { // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, _ int) { transport.multicastFn = func(message *proto.Message) { multicastFn(message) } @@ -394,45 +368,19 @@ func TestConsensus_InvalidBlock(t *testing.T) { backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { insertedBlocks[nodeIndex] = proposal } - } - - var ( - backendCallbackMap = map[int]backendConfigCallback{ - 0: func(backend *mockBackend) { - commonBackendCallback(backend, 0) - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposals[0] - } - }, - 1: func(backend *mockBackend) { - commonBackendCallback(backend, 1) - - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposals[1] - } - }, - 2: func(backend *mockBackend) { - commonBackendCallback(backend, 2) - }, - 3: func(backend *mockBackend) { - commonBackendCallback(backend, 3) - }, - } - transportCallbackMap = map[int]transportConfigCallback{ - 0: commonTransportCallback, - 1: commonTransportCallback, - 2: commonTransportCallback, - 3: commonTransportCallback, + // Build proposal function + backend.buildProposalFn = func(_ *proto.View) []byte { + return proposals[nodeIndex] } - ) + } // Create the mock cluster cluster := newMockCluster( numNodes, - backendCallbackMap, + commonBackendCallback, nil, - transportCallbackMap, + commonTransportCallback, ) // Set the base timeout to be lower than usual @@ -440,9 +388,7 @@ func TestConsensus_InvalidBlock(t *testing.T) { // Set the multicast callback to relay the message // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } + multicastFn = cluster.pushMessage // Start the main run loops cluster.runSequence(1) diff --git a/core/helpers_test.go b/core/helpers_test.go index 14de1b3..1d4c6a0 100644 --- a/core/helpers_test.go +++ b/core/helpers_test.go @@ -37,9 +37,10 @@ func isValidProposalHash(proposal, proposalHash []byte) bool { } type node struct { - core *IBFT - address []byte - offline bool + core *IBFT + address []byte + offline bool + byzantine bool } func (n *node) addr() []byte { @@ -217,6 +218,12 @@ func (c *cluster) maxFaulty() uint64 { return (uint64(len(c.nodes)) - 1) / 3 } +func (c *cluster) makeNByzantine(num int) { + for i := 0; i < num; i++ { + c.nodes[i].byzantine = true + } +} + func (c *cluster) stopN(num int) { for i := 0; i < num; i++ { c.nodes[i].offline = true diff --git a/core/ibft.go b/core/ibft.go index 58cb90b..368573b 100644 --- a/core/ibft.go +++ b/core/ibft.go @@ -12,17 +12,21 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +// Logger represents the logger behaviour type Logger interface { Info(msg string, args ...interface{}) Debug(msg string, args ...interface{}) Error(msg string, args ...interface{}) } +// Messages represents the message managing behaviour type Messages interface { // Messages modifiers // AddMessage(message *proto.Message) PruneByHeight(height uint64) + SignalEvent(message *proto.Message) + // Messages fetchers // GetValidMessages( view *proto.View, @@ -36,10 +40,13 @@ type Messages interface { Unsubscribe(id messages.SubscriptionID) } +const ( + round0Timeout = 10 * time.Second + roundFactorBase = float64(2) +) + var ( errTimeoutExpired = errors.New("round timeout expired") - - round0Timeout = 10 * time.Second ) // IBFT represents a single instance of the IBFT state machine @@ -123,14 +130,10 @@ func NewIBFT( func (i *IBFT) startRoundTimer(ctx context.Context, round uint64) { defer i.wg.Done() - var ( - duration = int(i.baseRoundTimeout) - roundFactor = int(math.Pow(float64(2), float64(round))) - roundTimeout = time.Duration(duration * roundFactor) - ) + roundTimeout := getRoundTimeout(i.baseRoundTimeout, i.additionalTimeout, round) // Create a new timer instance - timer := time.NewTimer(roundTimeout + i.additionalTimeout) + timer := time.NewTimer(roundTimeout) select { case <-ctx.Done(): @@ -325,6 +328,7 @@ func (i *IBFT) RunSequence(ctx context.Context, h uint64) { i.moveToNewRound(ev.round) i.acceptProposal(ev.proposalMessage) i.state.setRoundStarted(true) + i.sendPrepareMessage(view) case round := <-i.roundCertificate: teardown() i.log.Info("received future RCC", "round", round) @@ -344,7 +348,7 @@ func (i *IBFT) RunSequence(ctx context.Context, h uint64) { teardown() return - case <-ctx.Done(): + case <-ctxRound.Done(): teardown() i.log.Debug("sequence cancelled") @@ -566,10 +570,8 @@ func (i *IBFT) runNewRound(ctx context.Context) error { continue } - // Accept the proposal since it's valid - i.acceptProposal(proposalMessage) - // Multicast the PREPARE message + i.state.setProposalMessage(proposalMessage) i.sendPrepareMessage(view) i.log.Debug("prepare message multicasted") @@ -663,12 +665,31 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { return false } + if !messages.HasUniqueSenders(certificate.RoundChangeMessages) { + return false + } + // Make sure all messages in the RCC are valid Round Change messages for _, rc := range certificate.RoundChangeMessages { // Make sure the message is a Round Change message if rc.Type != proto.MessageType_ROUND_CHANGE { return false } + + // Height of the message matches height of the proposal + if rc.View.Height != height { + return false + } + + // Round of the message matches round of the proposal + if rc.View.Round != round { + return false + } + + // Sender of RCC is valid + if !i.backend.IsValidSender(rc) { + return false + } } // Extract possible rounds and their corresponding @@ -681,14 +702,14 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { roundsAndPreparedBlockHashes := make([]roundHashTuple, 0) for _, rcMessage := range rcc.RoundChangeMessages { - certificate := messages.ExtractLatestPC(rcMessage) + cert := messages.ExtractLatestPC(rcMessage) // Check if there is a certificate, and if it's a valid PC - if certificate != nil && i.validPC(certificate, msg.View.Round, height) { - hash := messages.ExtractProposalHash(certificate.ProposalMessage) + if cert != nil && i.validPC(cert, msg.View.Round, height) { + hash := messages.ExtractProposalHash(cert.ProposalMessage) roundsAndPreparedBlockHashes = append(roundsAndPreparedBlockHashes, roundHashTuple{ - round: rcMessage.View.Round, + round: cert.ProposalMessage.View.Round, hash: hash, }) } @@ -700,12 +721,12 @@ func (i *IBFT) validateProposal(msg *proto.Message, view *proto.View) bool { // Find the max round var ( - maxRound uint64 = 0 - expectedHash []byte = nil + maxRound uint64 + expectedHash []byte ) for _, tuple := range roundsAndPreparedBlockHashes { - if tuple.round > maxRound { + if tuple.round >= maxRound { maxRound = tuple.round expectedHash = tuple.hash } @@ -998,6 +1019,14 @@ func (i *IBFT) AddMessage(message *proto.Message) { // Check if the message should even be considered if i.isAcceptableMessage(message) { i.messages.AddMessage(message) + + msgs := i.messages.GetValidMessages( + message.View, + message.Type, + func(_ *proto.Message) bool { return true }) + if i.backend.HasQuorum(message.View.Height, msgs, message.Type) { + i.messages.SignalEvent(message) + } } } @@ -1044,6 +1073,8 @@ func (i *IBFT) validPC( return false } + // Order of messages is important! + // Message with type of MessageType_PREPREPARE must be the first element of allMessages slice allMessages := append( []*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages..., @@ -1086,6 +1117,11 @@ func (i *IBFT) validPC( return false } + // Make sure all have the same round + if !messages.AllHaveSameRound(allMessages) { + return false + } + // Make sure the proposal message is sent by the proposer // for the round proposal := certificate.ProposalMessage @@ -1093,12 +1129,22 @@ func (i *IBFT) validPC( return false } + // Make sure that the proposal sender is valid + if !i.backend.IsValidSender(proposal) { + return false + } + // Make sure the Prepare messages are validators, apart from the proposer for _, message := range certificate.PrepareMessages { // Make sure the sender is part of the validator set if !i.backend.IsValidSender(message) { return false } + + // Make sure the current node is not the proposer + if i.backend.IsProposer(message.From, message.View.Height, message.View.Round) { + return false + } } return true @@ -1142,3 +1188,20 @@ func (i *IBFT) sendCommitMessage(view *proto.View) { ), ) } + +// getRoundTimeout creates a round timeout based on the base timeout and the current round. +// Exponentially increases timeout depending on the round number. +// For instance: +// - round 1: 1 sec +// - round 2: 2 sec +// - round 3: 4 sec +// - round 4: 8 sec +func getRoundTimeout(baseRoundTimeout, additionalTimeout time.Duration, round uint64) time.Duration { + var ( + duration = int(baseRoundTimeout) + roundFactor = int(math.Pow(roundFactorBase, float64(round))) + roundTimeout = time.Duration(duration * roundFactor) + ) + + return roundTimeout + additionalTimeout +} diff --git a/core/ibft_test.go b/core/ibft_test.go index 08c1cc8..dd435b3 100644 --- a/core/ibft_test.go +++ b/core/ibft_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "math/rand" "os" "sync" "testing" @@ -156,7 +157,7 @@ func generateFilledRCMessages( proposal, proposalHash []byte) []*proto.Message { // Generate random RC messages - roundChangeMessages := generateMessages(quorum, proto.MessageType_ROUND_CHANGE) + roundChangeMessages := generateMessagesWithUniqueSender(quorum, proto.MessageType_ROUND_CHANGE) prepareMessages := generateMessages(quorum-1, proto.MessageType_PREPARE) // Fill up the prepare message hashes @@ -307,8 +308,6 @@ func TestRunNewRound_Proposer(t *testing.T) { var ( multicastedPreprepare *proto.Message = nil multicastedPrepare *proto.Message = nil - proposalHash = []byte("proposal hash") - proposal = []byte("proposal") notifyCh = make(chan uint64, 1) log = mockLogger{} @@ -328,7 +327,7 @@ func TestRunNewRound_Proposer(t *testing.T) { }, hasQuorumFn: defaultHasQuorumFn(quorum), buildProposalFn: func(_ *proto.View) []byte { - return proposal + return correctRoundMessage.proposal }, buildPrepareMessageFn: func(_ []byte, view *proto.View) *proto.Message { return &proto.Message{ @@ -336,7 +335,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -351,8 +350,8 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -402,7 +401,7 @@ func TestRunNewRound_Proposer(t *testing.T) { assert.Equal(t, multicastedPreprepare, i.state.proposalMessage) // Make sure the correct proposal value was multicasted - assert.True(t, proposalMatches(proposal, multicastedPreprepare)) + assert.True(t, proposalMatches(correctRoundMessage.proposal, multicastedPreprepare)) // Make sure the prepare message was not multicasted assert.Nil(t, multicastedPrepare) @@ -415,7 +414,6 @@ func TestRunNewRound_Proposer(t *testing.T) { t.Parallel() lastPreparedProposedBlock := []byte("last prepared block") - proposalHash := []byte("proposal hash") quorum := uint64(4) ctx, cancelFn := context.WithCancel(context.Background()) @@ -426,7 +424,7 @@ func TestRunNewRound_Proposer(t *testing.T) { for index, message := range prepareMessages { message.Payload = &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, } @@ -451,7 +449,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ Proposal: lastPreparedProposedBlock, - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, Certificate: nil, }, }, @@ -460,6 +458,7 @@ func TestRunNewRound_Proposer(t *testing.T) { } var ( + proposerID = []byte("unique node") multicastedPreprepare *proto.Message = nil multicastedPrepare *proto.Message = nil proposal = []byte("proposal") @@ -476,9 +475,9 @@ func TestRunNewRound_Proposer(t *testing.T) { } }} backend = mockBackend{ - idFn: func() []byte { return nil }, - isProposerFn: func(_ []byte, _ uint64, _ uint64) bool { - return true + idFn: func() []byte { return proposerID }, + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposerID, proposer) }, hasQuorumFn: defaultHasQuorumFn(quorum), buildProposalFn: func(_ *proto.View) []byte { @@ -490,7 +489,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -506,7 +505,7 @@ func TestRunNewRound_Proposer(t *testing.T) { Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ Proposal: proposal, - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, Certificate: certificate, }, }, @@ -573,8 +572,6 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) var ( - proposal = []byte("new block") - proposalHash = []byte("proposal hash") proposer = []byte("proposer") multicastedPrepare *proto.Message = nil notifyCh = make(chan uint64, 1) @@ -600,7 +597,7 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -635,7 +632,7 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, + Proposal: correctRoundMessage.proposal, }, }, }, @@ -661,10 +658,10 @@ func TestRunNewRound_Validator_Zero(t *testing.T) { assert.Equal(t, prepare, i.state.name) // Make sure the accepted proposal is the one that was sent out - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the correct proposal hash was multicasted - assert.True(t, prepareHashMatches(proposalHash, multicastedPrepare)) + assert.True(t, prepareHashMatches(correctRoundMessage.hash, multicastedPrepare)) } // TestRunNewRound_Validator_NonZero validates the behavior @@ -673,8 +670,6 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { t.Parallel() quorum := uint64(4) - proposalHash := []byte("proposal hash") - proposal := []byte("new block") proposer := []byte("proposer") generateProposalWithNoPrevious := func() *proto.Message { @@ -690,8 +685,8 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: roundChangeMessages, }, @@ -710,10 +705,14 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ - RoundChangeMessages: generateFilledRCMessages(quorum, proposal, proposalHash), + RoundChangeMessages: generateFilledRCMessages( + quorum, + correctRoundMessage.proposal, + correctRoundMessage.hash, + ), }, }, }, @@ -768,7 +767,7 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -824,10 +823,10 @@ func TestRunNewRound_Validator_NonZero(t *testing.T) { assert.Equal(t, prepare, i.state.name) // Make sure the accepted proposal is the one that was sent out - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the correct proposal hash was multicasted - assert.True(t, prepareHashMatches(proposalHash, multicastedPrepare)) + assert.True(t, prepareHashMatches(correctRoundMessage.hash, multicastedPrepare)) }) } } @@ -845,8 +844,6 @@ func TestRunPrepare(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) var ( - proposal = []byte("block proposal") - proposalHash = []byte("proposal hash") multicastedCommit *proto.Message = nil notifyCh = make(chan uint64, 1) @@ -863,7 +860,7 @@ func TestRunPrepare(t *testing.T) { Type: proto.MessageType_COMMIT, Payload: &proto.Message_CommitData{ CommitData: &proto.CommitMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -872,7 +869,7 @@ func TestRunPrepare(t *testing.T) { return len(messages) >= 1 }, isValidProposalHashFn: func(_ []byte, hash []byte) bool { - return bytes.Equal(proposalHash, hash) + return bytes.Equal(correctRoundMessage.hash, hash) }, } messages = mockMessages{ @@ -897,7 +894,7 @@ func TestRunPrepare(t *testing.T) { Type: proto.MessageType_PREPARE, Payload: &proto.Message_PrepareData{ PrepareData: &proto.PrepareMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, }, }, }, @@ -914,8 +911,8 @@ func TestRunPrepare(t *testing.T) { i.state.proposalMessage = &proto.Message{ Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -933,10 +930,10 @@ func TestRunPrepare(t *testing.T) { assert.Equal(t, commit, i.state.name) // Make sure the proposal didn't change - assert.Equal(t, proposal, i.state.getProposal()) + assert.Equal(t, correctRoundMessage.proposal, i.state.getProposal()) // Make sure the proper proposal hash was multicasted - assert.True(t, commitHashMatches(proposalHash, multicastedCommit)) + assert.True(t, commitHashMatches(correctRoundMessage.hash, multicastedCommit)) }, ) } @@ -954,8 +951,6 @@ func TestRunCommit(t *testing.T) { var ( wg sync.WaitGroup - proposal = []byte("block proposal") - proposalHash = []byte("proposal hash") signer = []byte("signer") insertedProposal []byte = nil insertedCommittedSeals []*messages.CommittedSeal = nil @@ -979,7 +974,7 @@ func TestRunCommit(t *testing.T) { return len(messages) >= 1 }, isValidProposalHashFn: func(_ []byte, hash []byte) bool { - return bytes.Equal(proposalHash, hash) + return bytes.Equal(correctRoundMessage.hash, hash) }, } messages = mockMessages{ @@ -1001,7 +996,7 @@ func TestRunCommit(t *testing.T) { Type: proto.MessageType_COMMIT, Payload: &proto.Message_CommitData{ CommitData: &proto.CommitMessage{ - ProposalHash: proposalHash, + ProposalHash: correctRoundMessage.hash, CommittedSeal: committedSeals[0].Signature, }, }, @@ -1019,8 +1014,8 @@ func TestRunCommit(t *testing.T) { i.state.proposalMessage = &proto.Message{ Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, }, }, } @@ -1057,7 +1052,7 @@ func TestRunCommit(t *testing.T) { assert.Equal(t, fin, i.state.name) // Make sure the inserted proposal was the one present - assert.Equal(t, insertedProposal, proposal) + assert.Equal(t, insertedProposal, correctRoundMessage.proposal) // Make sure the inserted committed seals were correct assert.Equal(t, insertedCommittedSeals, committedSeals) @@ -1272,13 +1267,11 @@ func TestIBFT_FutureProposal(t *testing.T) { nodeID := []byte("node ID") proposer := []byte("proposer") - proposal := []byte("proposal") - proposalHash := []byte("proposal hash") quorum := uint64(4) - generateEmptyRCMessages := func(count uint64) []*proto.Message { + generateEmptyRCMessages := func(count uint64, round uint64) []*proto.Message { // Generate random RC messages - roundChangeMessages := generateMessages(count, proto.MessageType_ROUND_CHANGE) + roundChangeMessages := generateMessagesWithUniqueSender(count, proto.MessageType_ROUND_CHANGE) // Fill up their certificates for _, message := range roundChangeMessages { @@ -1288,11 +1281,20 @@ func TestIBFT_FutureProposal(t *testing.T) { LatestPreparedCertificate: nil, }, } + + message.View.Round = round } return roundChangeMessages } + generateFilledRCMessagesWithRound := func(quorum, round uint64) []*proto.Message { + messages := generateFilledRCMessages(quorum, correctRoundMessage.proposal, correctRoundMessage.hash) + setRoundForMessages(messages, round) + + return messages + } + testTable := []struct { name string proposalView *proto.View @@ -1305,7 +1307,7 @@ func TestIBFT_FutureProposal(t *testing.T) { Height: 0, Round: 1, }, - generateEmptyRCMessages(quorum), + generateEmptyRCMessages(quorum, 1), 1, }, { @@ -1314,7 +1316,7 @@ func TestIBFT_FutureProposal(t *testing.T) { Height: 0, Round: 2, }, - generateFilledRCMessages(quorum, proposal, proposalHash), + generateFilledRCMessagesWithRound(quorum, 2), 2, }, } @@ -1333,8 +1335,8 @@ func TestIBFT_FutureProposal(t *testing.T) { Type: proto.MessageType_PREPREPARE, Payload: &proto.Message_PreprepareData{ PreprepareData: &proto.PrePrepareMessage{ - Proposal: proposal, - ProposalHash: proposalHash, + Proposal: correctRoundMessage.proposal, + ProposalHash: correctRoundMessage.hash, Certificate: &proto.RoundChangeCertificate{ RoundChangeMessages: testCase.roundChangeMessages, }, @@ -1356,7 +1358,8 @@ func TestIBFT_FutureProposal(t *testing.T) { return nodeID }, isValidProposalHashFn: func(p []byte, hash []byte) bool { - return bytes.Equal(hash, proposalHash) && bytes.Equal(p, proposal) + return bytes.Equal(hash, correctRoundMessage.hash) && + bytes.Equal(p, correctRoundMessage.proposal) }, hasQuorumFn: defaultHasQuorumFn(quorum), } @@ -1414,7 +1417,7 @@ func TestIBFT_FutureProposal(t *testing.T) { } assert.Equal(t, testCase.notifyRound, receivedProposalEvent.round) - assert.Equal(t, proposal, messages.ExtractProposal(receivedProposalEvent.proposalMessage)) + assert.Equal(t, correctRoundMessage.proposal, messages.ExtractProposal(receivedProposalEvent.proposalMessage)) }) } } @@ -1603,10 +1606,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1628,7 +1630,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit+1) @@ -1640,10 +1642,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1671,10 +1672,53 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + + t.Run("rounds are not the same", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(2) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return !bytes.Equal(proposer, sender) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) + // Make sure the round is invalid for some random message + randomIndex := rand.Intn(len(certificate.PrepareMessages)) + randomPrepareMessage := certificate.PrepareMessages[randomIndex] + randomPrepareMessage.View.Round = 0 assert.False(t, i.validPC(certificate, rLimit, 0)) }) @@ -1683,10 +1727,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1711,7 +1754,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1723,10 +1766,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1755,7 +1797,89 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + + t.Run("proposal is from an invalid sender", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposer, sender) + }, + isValidSenderFn: func(message *proto.Message) bool { + // Proposer is invalid + return !bytes.Equal(message.From, sender) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, + ) + + setRoundForMessages(allMessages, rLimit-1) + + assert.False(t, i.validPC(certificate, rLimit, 0)) + }) + + t.Run("prepare from proposer", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + hasQuorumFn: defaultHasQuorumFn(quorum), + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return true + }, + } + ) + + i := NewIBFT(log, backend, transport) + + proposal := generateMessagesWithSender(1, proto.MessageType_PREPREPARE, sender)[0] + + certificate := &proto.PreparedCertificate{ + ProposalMessage: proposal, + PrepareMessages: generateMessagesWithUniqueSender(quorum-1, proto.MessageType_PREPARE), + } + + // Make sure they all have the same proposal hash + allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) + appendProposalHash( + allMessages, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1767,10 +1891,9 @@ func TestIBFT_ValidPC(t *testing.T) { t.Parallel() var ( - quorum = uint64(4) - rLimit = uint64(1) - sender = []byte("unique node") - proposalHash = []byte("proposal hash") + quorum = uint64(4) + rLimit = uint64(1) + sender = []byte("unique node") log = mockLogger{} transport = mockTransport{} @@ -1798,7 +1921,7 @@ func TestIBFT_ValidPC(t *testing.T) { allMessages := append([]*proto.Message{certificate.ProposalMessage}, certificate.PrepareMessages...) appendProposalHash( allMessages, - proposalHash, + correctRoundMessage.hash, ) setRoundForMessages(allMessages, rLimit-1) @@ -1938,6 +2061,53 @@ func TestIBFT_ValidateProposal(t *testing.T) { assert.False(t, i.validateProposal(proposal, baseView)) }) + t.Run("non unique senders", func(t *testing.T) { + t.Parallel() + + var ( + quorum = uint64(4) + self = []byte("node id") + + log = mockLogger{} + transport = mockTransport{} + backend = mockBackend{ + idFn: func() []byte { + return self + }, + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return !bytes.Equal(proposer, self) + }, + } + ) + + i := NewIBFT(log, backend, transport) + + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Make sure all rcc are from same node + messages := generateMessages(quorum, proto.MessageType_ROUND_CHANGE) + for _, msg := range messages { + msg.From = []byte("non unique node id") + } + + proposal := &proto.Message{ + View: baseView, + Type: proto.MessageType_PREPREPARE, + Payload: &proto.Message_PreprepareData{ + PreprepareData: &proto.PrePrepareMessage{ + Certificate: &proto.RoundChangeCertificate{ + RoundChangeMessages: messages, + }, + }, + }, + } + + assert.False(t, i.validateProposal(proposal, baseView)) + }) + t.Run("there are < quorum RC messages in the certificate", func(t *testing.T) { t.Parallel() @@ -2068,11 +2238,9 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { t.Parallel() quorum := uint64(4) - proposal := []byte("proposal") rccRound := uint64(10) - proposalHash := []byte("proposal hash") - roundChangeMessages := generateFilledRCMessages(quorum, proposal, proposalHash) + roundChangeMessages := generateFilledRCMessages(quorum, correctRoundMessage.proposal, correctRoundMessage.hash) setRoundForMessages(roundChangeMessages, rccRound) var ( @@ -2085,8 +2253,8 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { transport = mockTransport{} backend = mockBackend{ hasQuorumFn: defaultHasQuorumFn(quorum), - isProposerFn: func(_ []byte, _ uint64, _ uint64) bool { - return true + isProposerFn: func(proposer []byte, _ uint64, _ uint64) bool { + return bytes.Equal(proposer, []byte("unique node")) }, } messages = mockMessages{ @@ -2144,6 +2312,8 @@ func TestIBFT_WatchForFutureRCC(t *testing.T) { // TestState_String makes sure the string representation // of states is correct func TestState_String(t *testing.T) { + t.Parallel() + stringMap := map[stateType]string{ newRound: "new round", prepare: "prepare", @@ -2320,3 +2490,178 @@ func TestDummyForMutations(t *testing.T) { t.Run(name, test) } } + +func Test_getRoundTimeout(t *testing.T) { + t.Parallel() + + type args struct { + baseRoundTimeout time.Duration + additionalTimeout time.Duration + round uint64 + } + + tests := []struct { + name string + args args + want time.Duration + }{ + { + name: "first round duration", + args: args{ + baseRoundTimeout: time.Second, + additionalTimeout: time.Second, + round: 0, + }, + want: time.Second * 2, + }, + { + name: "zero round duration", + args: args{ + baseRoundTimeout: time.Second, + additionalTimeout: time.Second, + round: 1, + }, + want: time.Second * 3, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := getRoundTimeout(tt.args.baseRoundTimeout, tt.args.additionalTimeout, tt.args.round) + assert.Equalf(t, tt.want, got, "getRoundTimeout(%v, %v, %v)", tt.args.baseRoundTimeout, tt.args.additionalTimeout, tt.args.round) + }) + } +} + +func TestIBFT_AddMessage(t *testing.T) { + t.Parallel() + + const ( + validHeight = uint64(10) + validRound = uint64(7) + validMsgType = proto.MessageType_PREPREPARE + ) + + var validSender = []byte{1, 2, 3} + + executeTest := func( + msg *proto.Message, + hasQuorum, shouldAddMessageCalled, shouldHasQuorumCalled, shouldSignalEventCalled bool) { + var ( + hasQuorumCalled = false + signalEventCalled = false + addMessageCalled = false + log = mockLogger{} + backend = mockBackend{} + transport = mockTransport{} + messages = mockMessages{} + ) + + backend.isValidSenderFn = func(m *proto.Message) bool { + return bytes.Equal(m.From, validSender) + } + + backend.hasQuorumFn = func(height uint64, _ []*proto.Message, msgType proto.MessageType) bool { + hasQuorumCalled = true + + assert.Equal(t, validHeight, height) + assert.Equal(t, validMsgType, msgType) + + return hasQuorum + } + + messages.addMessageFn = func(m *proto.Message) { + addMessageCalled = true + + assert.Equal(t, msg, m) + } + + messages.signalEventFn = func(*proto.Message) { + signalEventCalled = true + } + + i := NewIBFT(log, backend, transport) + i.messages = messages + i.state.view = &proto.View{Height: validHeight, Round: validRound} + + i.AddMessage(msg) + + assert.Equal(t, shouldAddMessageCalled, addMessageCalled) + assert.Equal(t, shouldHasQuorumCalled, hasQuorumCalled) + assert.Equal(t, shouldSignalEventCalled, signalEventCalled) + } + + t.Run("nil message case", func(t *testing.T) { + t.Parallel() + + executeTest(nil, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid sender", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + View: &proto.View{Height: validHeight, Round: validRound}, + Type: validMsgType, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid view", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid height", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight - 1, Round: validRound}, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("!isAcceptableMessage - invalid round", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound - 1}, + } + executeTest(msg, true, false, false, false) + }) + + t.Run("correct - but quorum not reached", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound}, + } + executeTest(msg, false, true, true, false) + }) + + t.Run("correct - quorum reached", func(t *testing.T) { + t.Parallel() + + msg := &proto.Message{ + From: validSender, + Type: validMsgType, + View: &proto.View{Height: validHeight, Round: validRound}, + } + executeTest(msg, true, true, true, true) + }) +} diff --git a/core/mock_test.go b/core/mock_test.go index 2170a97..8e6b9e6 100644 --- a/core/mock_test.go +++ b/core/mock_test.go @@ -11,6 +11,24 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +const ( + testRoundTimeout = time.Second +) + +var ( + correctRoundMessage = roundMessage{ + proposal: []byte("proposal"), + hash: []byte("proposal hash"), + seal: []byte("seal"), + } + + badRoundMessage = roundMessage{ + proposal: []byte("bad proposal"), + hash: []byte("bad proposal hash"), + seal: []byte("bad seal"), + } +) + // Define delegation methods type isValidBlockDelegate func([]byte) bool type isValidSenderDelegate func(*proto.Message) bool @@ -216,6 +234,7 @@ func (l mockLogger) Error(msg string, args ...interface{}) { type mockMessages struct { addMessageFn func(message *proto.Message) pruneByHeightFn func(height uint64) + signalEventFn func(message *proto.Message) getValidMessagesFn func( view *proto.View, @@ -266,6 +285,12 @@ func (m mockMessages) PruneByHeight(height uint64) { } } +func (m mockMessages) SignalEvent(msg *proto.Message) { + if m.signalEventFn != nil { + m.signalEventFn(msg) + } +} + func (m mockMessages) GetMostRoundChangeMessages(round, height uint64) []*proto.Message { if m.getMostRoundChangeMessagesFn != nil { return m.getMostRoundChangeMessagesFn(round, height) @@ -281,14 +306,42 @@ type transportConfigCallback func(*mockTransport) // newMockCluster creates a new IBFT cluster func newMockCluster( numNodes uint64, - backendCallbackMap map[int]backendConfigCallback, - loggerCallbackMap map[int]loggerConfigCallback, - transportCallbackMap map[int]transportConfigCallback, + backendCallback func(*mockBackend, int), + loggerCallback func(*mockLogger, int), + transportCallback func(*mockTransport, int), ) *mockCluster { if numNodes < 1 { return nil } + // Initialize the backend and transport callbacks for + // each node in the arbitrary cluster + backendCallbackMap := make(map[int]backendConfigCallback) + loggerCallbackMap := make(map[int]loggerConfigCallback) + transportCallbackMap := make(map[int]transportConfigCallback) + + for i := 0; i < int(numNodes); i++ { + i := i + + if backendCallback != nil { + backendCallbackMap[i] = func(backend *mockBackend) { + backendCallback(backend, i) + } + } + + if transportCallback != nil { + transportCallbackMap[i] = func(backend *mockTransport) { + transportCallback(backend, i) + } + } + + if loggerCallback != nil { + loggerCallbackMap[i] = func(backend *mockLogger) { + loggerCallback(backend, i) + } + } + } + nodes := make([]*IBFT, numNodes) nodeCtxs := make([]mockNodeContext, numNodes) @@ -300,21 +353,21 @@ func newMockCluster( ) // Execute set callbacks, if any - if backendCallbackMap != nil { - if backendCallback, isSet := backendCallbackMap[index]; isSet { - backendCallback(backend) + if len(backendCallbackMap) > 0 { + if bc, isSet := backendCallbackMap[index]; isSet { + bc(backend) } } - if loggerCallbackMap != nil { - if loggerCallback, isSet := loggerCallbackMap[index]; isSet { - loggerCallback(logger) + if len(loggerCallbackMap) > 0 { + if lc, isSet := loggerCallbackMap[index]; isSet { + lc(logger) } } - if transportCallbackMap != nil { - if transportCallback, isSet := transportCallbackMap[index]; isSet { - transportCallback(transport) + if len(transportCallbackMap) > 0 { + if tc, isSet := transportCallbackMap[index]; isSet { + tc(transport) } } @@ -322,17 +375,19 @@ func newMockCluster( nodes[index] = NewIBFT(logger, backend, transport) // Instantiate context for the nodes - ctx, cancelFn := context.WithCancel(context.Background()) - nodeCtxs[index] = mockNodeContext{ - ctx: ctx, - cancelFn: cancelFn, - } + nodeCtxs[index] = newMockNodeContext() } - return &mockCluster{ + cr := &mockCluster{ nodes: nodes, ctxs: nodeCtxs, } + + // Set a small timeout, because of situations + // where the byzantine node is the proposer + cr.setBaseTimeout(testRoundTimeout) + + return cr } // mockNodeContext keeps track of the node runtime context @@ -341,6 +396,16 @@ type mockNodeContext struct { cancelFn context.CancelFunc } +// newMockNodeContext is the constructor of mockNodeContext +func newMockNodeContext() mockNodeContext { + ctx, cancelFn := context.WithCancel(context.Background()) + + return mockNodeContext{ + ctx: ctx, + cancelFn: cancelFn, + } +} + // mockNodeWg is the WaitGroup wrapper for the cluster nodes type mockNodeWg struct { sync.WaitGroup @@ -352,8 +417,8 @@ func (wg *mockNodeWg) Add(delta int) { } func (wg *mockNodeWg) Done() { - wg.WaitGroup.Done() atomic.AddInt64(&wg.count, 1) + wg.WaitGroup.Done() } func (wg *mockNodeWg) getDone() int64 { @@ -378,18 +443,12 @@ func (m *mockCluster) runSequence(height uint64) { for nodeIndex, node := range m.nodes { m.wg.Add(1) - go func( - ctx context.Context, - node *IBFT, - height uint64, - ) { - defer func() { - m.wg.Done() - }() - + go func(ctx context.Context, node *IBFT) { // Start the main run loop for the node node.RunSequence(ctx, height) - }(m.ctxs[nodeIndex].ctx, node, height) + + m.wg.Done() + }(m.ctxs[nodeIndex].ctx, node) } } @@ -405,8 +464,10 @@ func (m *mockCluster) awaitCompletion() { // in the cluster, and awaits their completion func (m *mockCluster) forceShutdown() { // Send a stop signal to all the nodes - for _, ctx := range m.ctxs { + for i, ctx := range m.ctxs { ctx.cancelFn() + + m.ctxs[i] = newMockNodeContext() } // Wait for all the nodes to finish diff --git a/core/rapid_test.go b/core/rapid_test.go index 0a5d24f..6adb2e0 100644 --- a/core/rapid_test.go +++ b/core/rapid_test.go @@ -6,7 +6,6 @@ import ( "fmt" "sync" "testing" - "time" "github.com/stretchr/testify/assert" "pgregory.net/rapid" @@ -15,6 +14,13 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +// roundMessage contains message data within consensus round +type roundMessage struct { + proposal []byte + seal []byte + hash []byte +} + // mockInsertedProposals keeps track of inserted proposals for a cluster // of nodes type mockInsertedProposals struct { @@ -45,215 +51,180 @@ func (m *mockInsertedProposals) insertProposal( proposal []byte, ) { m.Lock() - defer m.Unlock() - m.proposals[nodeIndex][m.currentProposals[nodeIndex]] = proposal m.currentProposals[nodeIndex]++ + m.Unlock() } -// TestProperty_AllHonestNodes is a property-based test -// that assures the cluster can reach consensus on any -// arbitrary number of valid nodes -func TestProperty_AllHonestNodes(t *testing.T) { - t.Parallel() - - rapid.Check(t, func(t *rapid.T) { - var multicastFn func(message *proto.Message) +// getProposer returns proposer index +func getProposer(height, round, nodes uint64) uint64 { + return (height + round) % nodes +} - var ( - proposal = []byte("proposal") - proposalHash = []byte("proposal hash") - committedSeal = []byte("seal") +// propertyTestEvent is the behaviour setup per specific round +type propertyTestEvent struct { + // silentByzantineNodes is the number of byzantine nodes + // that are going to be silent, i.e. do not respond + silentByzantineNodes uint64 - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") - desiredHeight = rapid.Uint64Range(10, 20).Draw(t, "minimum height to be reached") + // badByzantineNodes is the number of byzantine nodes + // that are going to send bad messages + badByzantineNodes uint64 +} - nodes = generateNodeAddresses(numNodes) - insertedProposals = newMockInsertedProposals(numNodes) - ) - // commonTransportCallback is the common method modification - // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { - transport.multicastFn = func(message *proto.Message) { - multicastFn(message) - } - } +func (e propertyTestEvent) badNodes() uint64 { + return e.silentByzantineNodes + e.badByzantineNodes +} - // commonBackendCallback is the common method modification required - // for the Backend, for all nodes - commonBackendCallback := func(backend *mockBackend, nodeIndex int) { - // Make sure the quorum function requires all nodes - backend.hasQuorumFn = commonHasQuorumFn(numNodes) +func (e propertyTestEvent) isSilent(nodeIndex int) bool { + return uint64(nodeIndex) < e.silentByzantineNodes +} - // Make sure the node ID is properly relayed - backend.idFn = func() []byte { - return nodes[nodeIndex] - } +// getMessage returns bad message for byzantine bad node, +// correct message for non-byzantine nodes, and nil for silent nodes +func (e propertyTestEvent) getMessage(nodeIndex int) *roundMessage { + message := correctRoundMessage + if uint64(nodeIndex) < e.badNodes() { + message = badRoundMessage + } - // Make sure the only proposer is picked using Round Robin - backend.isProposerFn = func(from []byte, height uint64, _ uint64) bool { - return bytes.Equal(from, nodes[height%numNodes]) - } + return &message +} - // Make sure the proposal is valid if it matches what node 0 proposed - backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) - } +// propertyTestSetup contains randomly-generated data for rapid testing +type propertyTestSetup struct { + sync.Mutex - // Make sure the proposal hash matches - backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) - } + // nodes is the total number of nodes + nodes uint64 - // Make sure the preprepare message is built correctly - backend.buildPrePrepareMessageFn = func( - proposal []byte, - certificate *proto.RoundChangeCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicPreprepareMessage( - proposal, - proposalHash, - certificate, - nodes[nodeIndex], - view) - } + // desiredHeight is the desired height number + desiredHeight uint64 - // Make sure the prepare message is built correctly - backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) - } + // events is the mapping between the current height and its rounds + events [][]propertyTestEvent - // Make sure the commit message is built correctly - backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) - } + currentHeight map[int]uint64 + currentRound map[int]uint64 +} - // Make sure the round change message is built correctly - backend.buildRoundChangeMessageFn = func( - proposal []byte, - certificate *proto.PreparedCertificate, - view *proto.View, - ) *proto.Message { - return buildBasicRoundChangeMessage(proposal, certificate, view, nodes[nodeIndex]) - } +func (s *propertyTestSetup) setRound(nodeIndex int, round uint64) { + s.Lock() + s.currentRound[nodeIndex] = round + s.Unlock() +} - // Make sure the inserted proposal is noted - backend.insertBlockFn = func(proposal []byte, _ []*messages.CommittedSeal) { - insertedProposals.insertProposal(nodeIndex, proposal) - } +func (s *propertyTestSetup) incHeight() { + s.Lock() - // Make sure the proposal can be built - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal - } - } + for nodeIndex := 0; uint64(nodeIndex) < s.nodes; nodeIndex++ { + s.currentHeight[nodeIndex]++ + s.currentRound[nodeIndex] = 0 + } - // Initialize the backend and transport callbacks for - // each node in the arbitrary cluster - backendCallbackMap := make(map[int]backendConfigCallback) - transportCallbackMap := make(map[int]transportConfigCallback) + s.Unlock() +} - for i := 0; i < int(numNodes); i++ { - i := i - backendCallbackMap[i] = func(backend *mockBackend) { - commonBackendCallback(backend, i) - } +func (s *propertyTestSetup) getEvent(nodeIndex int) propertyTestEvent { + s.Lock() - transportCallbackMap[i] = commonTransportCallback - } + var ( + height = int(s.currentHeight[nodeIndex]) + roundNumber = int(s.currentRound[nodeIndex]) + round propertyTestEvent + ) - // Create the mock cluster - cluster := newMockCluster( - numNodes, - backendCallbackMap, - nil, - transportCallbackMap, - ) + if roundNumber >= len(s.events[height]) { + round = s.events[height][len(s.events[height])-1] + } else { + round = s.events[height][roundNumber] + } - // Set the multicast callback to relay the message - // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } + s.Unlock() - // Run the sequence up until a certain height - for height := uint64(0); height < desiredHeight; height++ { - // Start the main run loops - cluster.runSequence(height) + return round +} - // Wait until the main run loops finish - cluster.awaitCompletion() - } +func (s *propertyTestSetup) lastRound(height uint64) propertyTestEvent { + return s.events[height][len(s.events[height])-1] +} - // Make sure that the inserted proposal is valid for each height - for _, proposalMap := range insertedProposals.proposals { - // Make sure the node has the adequate number of inserted proposals - assert.Len(t, proposalMap, int(desiredHeight)) +// generatePropertyTestEvent generates propertyTestEvent model +func generatePropertyTestEvent(t *rapid.T) *propertyTestSetup { + // Generate random setup of the nodes number, byzantine nodes number, and desired height + var ( + numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") + desiredHeight = rapid.Uint64Range(5, 20).Draw(t, "minimum height to be reached") + maxBadNodes = maxFaulty(numNodes) + ) + + setup := &propertyTestSetup{ + nodes: numNodes, + desiredHeight: desiredHeight, + events: make([][]propertyTestEvent, desiredHeight), + currentHeight: map[int]uint64{}, + currentRound: map[int]uint64{}, + } - for _, insertedProposal := range proposalMap { - assert.True(t, bytes.Equal(proposal, insertedProposal)) + // Go over the desired height and generate random number of rounds + // depending on the round result: success or fail. + for height := uint64(0); height < desiredHeight; height++ { + var round uint64 + + // Generate random rounds until we reach a state where to expect a successfully + // met consensus. Meaning >= 2/3 of all nodes would reach the consensus. + for { + numByzantineNodes := rapid. + Uint64Range(0, maxBadNodes). + Draw(t, fmt.Sprintf("number of byzantine nodes for height %d on round %d", height, round)) + silentByzantineNodes := rapid. + Uint64Range(0, numByzantineNodes). + Draw(t, fmt.Sprintf("number of silent byzantine nodes for height %d on round %d", height, round)) + proposerIdx := getProposer(height, round, numNodes) + + setup.events[height] = append(setup.events[height], propertyTestEvent{ + silentByzantineNodes: silentByzantineNodes, + badByzantineNodes: numByzantineNodes - silentByzantineNodes, + }) + + // If the proposer per the current round is not byzantine node, + // it is expected the consensus should be met, so the loop + // could be stopped for the running height. + if proposerIdx >= numByzantineNodes { + break } - } - }) -} -// getByzantineNodes returns a random subset of -// byzantine nodes -func getByzantineNodes( - numNodes uint64, - set [][]byte, -) map[string]struct{} { - gen := rapid.SampledFrom(set) - byzantineNodes := make(map[string]struct{}) - - for i := 0; i < int(numNodes); i++ { - byzantineNodes[string(gen.Example(i))] = struct{}{} + round++ + } } - return byzantineNodes + return setup } -// TestProperty_MajorityHonestNodes is a property-based test -// that assures the cluster can reach consensus on any -// arbitrary number of valid nodes and byzantine nodes -func TestProperty_MajorityHonestNodes(t *testing.T) { +// TestProperty is a property-based test +// that assures the cluster can handle rounds properly in any cases. +func TestProperty(t *testing.T) { t.Parallel() rapid.Check(t, func(t *rapid.T) { var multicastFn func(message *proto.Message) var ( - proposal = []byte("proposal") - proposalHash = []byte("proposal hash") - committedSeal = []byte("seal") - - numNodes = rapid.Uint64Range(4, 30).Draw(t, "number of cluster nodes") - numByzantineNodes = rapid.Uint64Range(1, maxFaulty(numNodes)).Draw(t, "number of byzantine nodes") - desiredHeight = rapid.Uint64Range(1, 5).Draw(t, "minimum height to be reached") - - nodes = generateNodeAddresses(numNodes) - insertedProposals = newMockInsertedProposals(numNodes) + setup = generatePropertyTestEvent(t) + nodes = generateNodeAddresses(setup.nodes) + insertedProposals = newMockInsertedProposals(setup.nodes) ) - // Initialize the byzantine nodes - byzantineNodes := getByzantineNodes( - numByzantineNodes, - nodes, - ) - - isByzantineNode := func(from []byte) bool { - _, exists := byzantineNodes[string(from)] - - return exists - } // commonTransportCallback is the common method modification // required for Transport, for all nodes - commonTransportCallback := func(transport *mockTransport) { + commonTransportCallback := func(transport *mockTransport, nodeIndex int) { transport.multicastFn = func(message *proto.Message) { - if isByzantineNode(message.From) { - // If the node is byzantine, mock - // not sending out the message + if message.Type == proto.MessageType_ROUND_CHANGE { + setup.setRound(nodeIndex, message.View.Round) + } + + // If node is silent, don't send a message + if setup.getEvent(nodeIndex).isSilent(nodeIndex) { return } @@ -265,7 +236,7 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // for the Backend, for all nodes commonBackendCallback := func(backend *mockBackend, nodeIndex int) { // Make sure the quorum function is Quorum optimal - backend.hasQuorumFn = commonHasQuorumFn(numNodes) + backend.hasQuorumFn = commonHasQuorumFn(setup.nodes) // Make sure the node ID is properly relayed backend.idFn = func() []byte { @@ -273,21 +244,25 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { } // Make sure the only proposer is picked using Round Robin - backend.isProposerFn = func(from []byte, height uint64, round uint64) bool { + backend.isProposerFn = func(from []byte, height, round uint64) bool { return bytes.Equal( from, - nodes[int(height+round)%len(nodes)], + nodes[getProposer(height, round, setup.nodes)], ) } // Make sure the proposal is valid if it matches what node 0 proposed backend.isValidBlockFn = func(newProposal []byte) bool { - return bytes.Equal(newProposal, proposal) + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + + return bytes.Equal(newProposal, message.proposal) } // Make sure the proposal hash matches backend.isValidProposalHashFn = func(p []byte, ph []byte) bool { - return bytes.Equal(p, proposal) && bytes.Equal(ph, proposalHash) + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + + return bytes.Equal(p, message.proposal) && bytes.Equal(ph, message.hash) } // Make sure the preprepare message is built correctly @@ -296,9 +271,11 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { certificate *proto.RoundChangeCertificate, view *proto.View, ) *proto.Message { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + return buildBasicPreprepareMessage( proposal, - proposalHash, + message.hash, certificate, nodes[nodeIndex], view, @@ -307,12 +284,16 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { // Make sure the prepare message is built correctly backend.buildPrepareMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicPrepareMessage(proposalHash, nodes[nodeIndex], view) + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + + return buildBasicPrepareMessage(message.hash, nodes[nodeIndex], view) } // Make sure the commit message is built correctly backend.buildCommitMessageFn = func(proposal []byte, view *proto.View) *proto.Message { - return buildBasicCommitMessage(proposalHash, committedSeal, nodes[nodeIndex], view) + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) + + return buildBasicCommitMessage(message.hash, message.seal, nodes[nodeIndex], view) } // Make sure the round change message is built correctly @@ -330,69 +311,74 @@ func TestProperty_MajorityHonestNodes(t *testing.T) { } // Make sure the proposal can be built - backend.buildProposalFn = func(_ *proto.View) []byte { - return proposal - } - } + backend.buildProposalFn = func(view *proto.View) []byte { + message := setup.getEvent(nodeIndex).getMessage(nodeIndex) - // Initialize the backend and transport callbacks for - // each node in the arbitrary cluster - backendCallbackMap := make(map[int]backendConfigCallback) - transportCallbackMap := make(map[int]transportConfigCallback) - - for i := 0; i < int(numNodes); i++ { - i := i - backendCallbackMap[i] = func(backend *mockBackend) { - commonBackendCallback(backend, i) + return message.proposal } - - transportCallbackMap[i] = commonTransportCallback } - // Create the mock cluster + // Create default cluster for rapid tests cluster := newMockCluster( - numNodes, - backendCallbackMap, + setup.nodes, + commonBackendCallback, nil, - transportCallbackMap, + commonTransportCallback, ) - // Set a small timeout, because of situations - // where the byzantine node is the proposer - cluster.setBaseTimeout(time.Second * 2) - // Set the multicast callback to relay the message // to the entire cluster - multicastFn = func(message *proto.Message) { - cluster.pushMessage(message) - } + multicastFn = cluster.pushMessage // Run the sequence up until a certain height - for height := uint64(0); height < desiredHeight; height++ { + for height := uint64(0); height < setup.desiredHeight; height++ { + // Create context timeout based on the bad nodes number + rounds := uint64(len(setup.events[height])) + ctxTimeout := getRoundTimeout(testRoundTimeout, testRoundTimeout, rounds*2) + // Start the main run loops cluster.runSequence(height) - // Wait until Quorum nodes finish their run loop - ctx, cancelFn := context.WithTimeout(context.Background(), time.Second*5) - if err := cluster.awaitNCompletions(ctx, int64(quorum(numNodes))); err != nil { - t.Fatalf( - fmt.Sprintf( - "unable to wait for nodes to complete, %v", - err, - ), - ) - } + ctx, cancelFn := context.WithTimeout(context.Background(), ctxTimeout) + err := cluster.awaitNCompletions(ctx, int64(quorum(setup.nodes))) + assert.NoError(t, err, "unable to wait for nodes to complete on height %d", height) + cancelFn() // Shutdown the remaining nodes that might be hanging cluster.forceShutdown() - cancelFn() - } - // Make sure that the inserted proposal is valid for each height - for _, proposalMap := range insertedProposals.proposals { - for _, insertedProposal := range proposalMap { - assert.True(t, bytes.Equal(proposal, insertedProposal)) + // Increment current height + setup.incHeight() + + // Make sure proposals map is not empty + assert.Len(t, insertedProposals.proposals, int(setup.nodes)) + + // Make sure bad nodes were out of the last round. + // Make sure we have inserted blocks >= quorum per round. + lastRound := setup.lastRound(height) + badNodes := lastRound.badNodes() + var proposalsNumber int + for nodeID, proposalMap := range insertedProposals.proposals { + if nodeID >= int(badNodes) { + // Only one inserted block per valid round + assert.LessOrEqual(t, len(proposalMap), 1) + proposalsNumber++ + + // Make sure inserted block value is correct + for _, val := range proposalMap { + assert.Equal(t, correctRoundMessage.proposal, val) + } + } else { + // There should not be inserted blocks in bad nodes + assert.Empty(t, proposalMap) + } } + + // Make sure the total number of inserted blocks >= quorum + assert.GreaterOrEqual(t, proposalsNumber, int(quorum(setup.nodes))) + + // Reset proposals map for the next height + insertedProposals = newMockInsertedProposals(setup.nodes) } }) } diff --git a/messages/event_manager.go b/messages/event_manager.go index 9edba60..7107c0f 100644 --- a/messages/event_manager.go +++ b/messages/event_manager.go @@ -4,8 +4,9 @@ import ( "sync" "sync/atomic" - "github.com/0xPolygon/go-ibft/messages/proto" "github.com/google/uuid" + + "github.com/0xPolygon/go-ibft/messages/proto" ) type eventManager struct { @@ -99,8 +100,9 @@ func (em *eventManager) close() { em.subscriptionsLock.Lock() defer em.subscriptionsLock.Unlock() - for _, subscription := range em.subscriptions { + for id, subscription := range em.subscriptions { subscription.close() + delete(em.subscriptions, id) } atomic.StoreInt64(&em.numSubscriptions, 0) diff --git a/messages/event_manager_test.go b/messages/event_manager_test.go index ce0d12a..909c326 100644 --- a/messages/event_manager_test.go +++ b/messages/event_manager_test.go @@ -1,18 +1,224 @@ package messages import ( + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/0xPolygon/go-ibft/messages/proto" ) +func TestEventManager_signalEvent(t *testing.T) { + t.Parallel() + + var ( + baseDetails = SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: &proto.View{ + Height: 0, + Round: 0, + }, + MinNumMessages: 1, + } + + baseEventType = baseDetails.MessageType + baseEventView = &proto.View{ + Height: baseDetails.View.Height, + Round: baseDetails.View.Round, + } + ) + + // setupEventManagerAndSubscription creates new eventManager and a subscription + setupEventManagerAndSubscription := func(t *testing.T) (*eventManager, *Subscription) { + t.Helper() + + em := newEventManager() + + t.Cleanup(func() { + em.close() + }) + + subscription := em.subscribe(baseDetails) + + t.Cleanup(func() { + em.cancelSubscription(subscription.ID) + }) + + return em, subscription + } + + // emitEvent sends a event to eventManager and close doneCh after signalEvent completes + emitEvent := func( + t *testing.T, + em *eventManager, + eventType proto.MessageType, + eventView *proto.View, + ) <-chan struct{} { + t.Helper() + + doneCh := make(chan struct{}) + + go func() { + t.Helper() + + defer close(doneCh) + + em.signalEvent( + eventType, + eventView, + ) + }() + + return doneCh + } + + // testSubscriptionData checks the data sent to subscription + testSubscriptionData := func( + t *testing.T, + sub *Subscription, + expectedSignals []uint64, + ) <-chan struct{} { + t.Helper() + + doneCh := make(chan struct{}) + + go func() { + t.Helper() + + defer close(doneCh) + + actualSignals := make([]uint64, 0) + for sig := range sub.SubCh { + actualSignals = append(actualSignals, sig) + } + + assert.Equal(t, expectedSignals, actualSignals) + }() + + return doneCh + } + + // closeSubscription closes subscription manually + // because cancelSubscription might be unable + // due to mutex locking during tests + closeSubscription := func( + t *testing.T, + em *eventManager, + sub *Subscription, + ) { + t.Helper() + + close(em.subscriptions[sub.ID].doneCh) + delete(em.subscriptions, sub.ID) + } + + t.Run("should exit before locking subscriptionsLock if numSubscriptions is zero", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // overwrite numSubscription + atomic.StoreInt64(&em.numSubscriptions, 0) + // shouldn't be locked by mutex thanks to early return + em.subscriptionsLock.Lock() + t.Cleanup(func() { + em.subscriptionsLock.Unlock() + }) + + doneEmitCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + // should exit by early return + select { + case <-doneEmitCh: + case <-time.After(5 * time.Second): + t.Errorf("signalEvent shouldn't be lock, but it was locked") + + return + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) + + t.Run("should be locked by other write lock", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // should be locked by other write lock + em.subscriptionsLock.Lock() + t.Cleanup(func() { + em.subscriptionsLock.Unlock() + }) + + doneCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + select { + case <-doneCh: + t.Errorf("signalEvent is not locked") + + return + case <-time.After(5 * time.Second): + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) + + t.Run("should not be locked by other read lock", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + // shouldn't be locked by mutex of read-lock + em.subscriptionsLock.RLock() + t.Cleanup(func() { + em.subscriptionsLock.RUnlock() + }) + + doneCh := emitEvent(t, em, baseEventType, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{0}) + + select { + case <-doneCh: + case <-time.After(5 * time.Second): + t.Errorf("signalEvent is locked") + + return + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) + + t.Run("should not notify if the event is different the one expected by subscription", func(t *testing.T) { + t.Parallel() + + em, sub := setupEventManagerAndSubscription(t) + + doneCh := emitEvent(t, em, proto.MessageType_COMMIT, baseEventView) + doneTestSubCh := testSubscriptionData(t, sub, []uint64{}) + + select { + case <-doneCh: + case <-time.After(5 * time.Second): + t.Errorf("signalEvent is locked") + + return + } + + closeSubscription(t, em, sub) + <-doneTestSubCh + }) +} + func TestEventManager_SubscribeCancel(t *testing.T) { t.Parallel() - numSubscriptions := 10 - subscriptions := make([]*Subscription, numSubscriptions) baseDetails := SubscriptionDetails{ MessageType: proto.MessageType_PREPARE, View: &proto.View{ @@ -22,7 +228,10 @@ func TestEventManager_SubscribeCancel(t *testing.T) { MinNumMessages: 1, } - IDMap := make(map[SubscriptionID]bool) + numSubscriptions := 10 + subscriptions := make([]*Subscription, numSubscriptions) + + idMap := make(map[SubscriptionID]bool) em := newEventManager() defer em.close() @@ -35,10 +244,10 @@ func TestEventManager_SubscribeCancel(t *testing.T) { assert.Equal(t, int64(i+1), em.numSubscriptions) // Check if a duplicate ID has been issued - if _, ok := IDMap[subscriptions[i].ID]; ok { + if _, ok := idMap[subscriptions[i].ID]; ok { t.Fatalf("Duplicate ID entry") } else { - IDMap[subscriptions[i].ID] = true + idMap[subscriptions[i].ID] = true } } @@ -71,8 +280,6 @@ func TestEventManager_SubscribeCancel(t *testing.T) { func TestEventManager_SubscribeClose(t *testing.T) { t.Parallel() - numSubscriptions := 10 - subscriptions := make([]*Subscription, numSubscriptions) baseDetails := SubscriptionDetails{ MessageType: proto.MessageType_PREPARE, View: &proto.View{ @@ -82,6 +289,9 @@ func TestEventManager_SubscribeClose(t *testing.T) { MinNumMessages: 1, } + numSubscriptions := 10 + subscriptions := make([]*Subscription, numSubscriptions) + em := newEventManager() // Create the subscriptions diff --git a/messages/helpers.go b/messages/helpers.go index cc11693..d30d621 100644 --- a/messages/helpers.go +++ b/messages/helpers.go @@ -6,6 +6,7 @@ import ( "github.com/0xPolygon/go-ibft/messages/proto" ) +// CommittedSeal Validator proof of signing a committed block type CommittedSeal struct { Signer []byte Signature []byte @@ -139,7 +140,7 @@ func HaveSameProposalHash(messages []*proto.Message) bool { return false } - var hash []byte = nil + var hash []byte for _, message := range messages { var extractedHash []byte @@ -149,6 +150,10 @@ func HaveSameProposalHash(messages []*proto.Message) bool { extractedHash = ExtractProposalHash(message) case proto.MessageType_PREPARE: extractedHash = ExtractPrepareHash(message) + case proto.MessageType_COMMIT: + return false + case proto.MessageType_ROUND_CHANGE: + return false default: return false } @@ -183,6 +188,23 @@ func AllHaveLowerRound(messages []*proto.Message, round uint64) bool { return true } +// AllHaveSameRound checks if all messages have the same round +func AllHaveSameRound(messages []*proto.Message) bool { + if len(messages) < 1 { + return false + } + + var round = messages[0].View.Round + + for _, message := range messages { + if message.View.Round != round { + return false + } + } + + return true +} + // AllHaveSameHeight checks if all messages have the same height func AllHaveSameHeight(messages []*proto.Message, height uint64) bool { if len(messages) < 1 { diff --git a/messages/helpers_test.go b/messages/helpers_test.go index 979cc1c..527df3e 100644 --- a/messages/helpers_test.go +++ b/messages/helpers_test.go @@ -11,37 +11,52 @@ import ( func TestMessages_ExtractCommittedSeals(t *testing.T) { t.Parallel() - signer := []byte("signer") - committedSeal := []byte("committed seal") - - commitMessage := &proto.Message{ - Type: proto.MessageType_COMMIT, - Payload: &proto.Message_CommitData{ - CommitData: &proto.CommitMessage{ - CommittedSeal: committedSeal, + newCommitMessage := func(from, committedSeal []byte) *proto.Message { + return &proto.Message{ + Type: proto.MessageType_COMMIT, + Payload: &proto.Message_CommitData{ + CommitData: &proto.CommitMessage{ + CommittedSeal: committedSeal, + }, }, - }, - From: signer, + From: from, + } } - invalidMessage := &proto.Message{ - Type: proto.MessageType_PREPARE, + + newInvalidMessage := func() *proto.Message { + return &proto.Message{ + Type: proto.MessageType_PREPARE, + } } - seals := ExtractCommittedSeals([]*proto.Message{ - commitMessage, - invalidMessage, - }) + var ( + signer1 = []byte("signer 1") + committedSeal1 = []byte("committed seal 1") - if len(seals) != 1 { - t.Fatalf("Seals not extracted") - } + signer2 = []byte("signer 2") + committedSeal2 = []byte("committed seal 2") + ) - expected := &CommittedSeal{ - Signer: signer, - Signature: committedSeal, - } + seals := ExtractCommittedSeals([]*proto.Message{ + newCommitMessage(signer1, committedSeal1), + newInvalidMessage(), + newCommitMessage(signer2, committedSeal2), + }) - assert.Equal(t, expected, seals[0]) + assert.Equal( + t, + []*CommittedSeal{ + { + Signer: signer1, + Signature: committedSeal1, + }, + { + Signer: signer2, + Signature: committedSeal2, + }, + }, + seals, + ) } func TestMessages_ExtractCommitHash(t *testing.T) { @@ -384,6 +399,15 @@ func TestMessages_HasUniqueSenders(t *testing.T) { nil, false, }, + { + "only one message", + []*proto.Message{ + { + From: []byte("node 1"), + }, + }, + true, + }, { "non unique senders", []*proto.Message{ @@ -440,9 +464,31 @@ func TestMessages_HaveSameProposalHash(t *testing.T) { nil, false, }, + { + "only one message", + []*proto.Message{ + { + Type: proto.MessageType_PREPARE, + Payload: &proto.Message_PrepareData{ + PrepareData: &proto.PrepareMessage{ + ProposalHash: []byte("hash"), + }, + }, + }, + }, + true, + }, { "invalid message type", []*proto.Message{ + { + Type: proto.MessageType_PREPARE, + Payload: &proto.Message_PrepareData{ + PrepareData: &proto.PrepareMessage{ + ProposalHash: []byte("differing hash"), + }, + }, + }, { Type: proto.MessageType_ROUND_CHANGE, }, @@ -471,6 +517,20 @@ func TestMessages_HaveSameProposalHash(t *testing.T) { }, false, }, + { + "only one message", + []*proto.Message{ + { + Type: proto.MessageType_PREPREPARE, + Payload: &proto.Message_PreprepareData{ + PreprepareData: &proto.PrePrepareMessage{ + ProposalHash: proposalHash, + }, + }, + }, + }, + true, + }, { "hash match", []*proto.Message{ @@ -528,7 +588,20 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { false, }, { - "not same lower round", + "true if message's round is less than threshold", + []*proto.Message{ + { + View: &proto.View{ + Height: 0, + Round: round - 1, + }, + }, + }, + round, + true, + }, + { + "false if message's round equals to threshold", []*proto.Message{ { View: &proto.View{ @@ -536,6 +609,13 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { Round: round, }, }, + }, + round, + false, + }, + { + "false if message's round is bigger than threshold", + []*proto.Message{ { View: &proto.View{ Height: 0, @@ -546,6 +626,25 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { round, false, }, + { + "some of messages are not higher round", + []*proto.Message{ + { + View: &proto.View{ + Height: 0, + Round: round + 1, + }, + }, + { + View: &proto.View{ + Height: 0, + Round: round, + }, + }, + }, + round, + false, + }, { "same higher round", []*proto.Message{ @@ -566,20 +665,33 @@ func TestMessages_AllHaveLowerRond(t *testing.T) { false, }, { - "lower round match", + "1 message is lower round", []*proto.Message{ { View: &proto.View{ - Height: 0, + Height: 1, Round: round, }, }, + }, + 2, + true, + }, + { + "all of messages is lower round", + []*proto.Message{ { View: &proto.View{ Height: 0, Round: round, }, }, + { + View: &proto.View{ + Height: 1, + Round: round, + }, + }, }, 2, true, @@ -620,7 +732,40 @@ func TestMessages_AllHaveSameHeight(t *testing.T) { false, }, { - "not same height", + "false if message's height is less than the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height - 1, + }, + }, + }, + false, + }, + { + "true if message's height equals to the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height, + }, + }, + }, + true, + }, + { + "false if message's height is bigger than the given height", + []*proto.Message{ + { + View: &proto.View{ + Height: height + 1, + }, + }, + }, + false, + }, + { + "some of messages' heights are not same to the given height", []*proto.Message{ { View: &proto.View{ @@ -636,7 +781,7 @@ func TestMessages_AllHaveSameHeight(t *testing.T) { false, }, { - "same height", + "all of messages' heights is same to the given height", []*proto.Message{ { View: &proto.View{ diff --git a/messages/messages.go b/messages/messages.go index 894834f..bf19ac1 100644 --- a/messages/messages.go +++ b/messages/messages.go @@ -72,7 +72,10 @@ func (ms *Messages) AddMessage(message *proto.Message) { // Append the message to the appropriate queue messages := heightMsgMap.getViewMessages(message.View) messages[string(message.From)] = message +} +// SignalEvent signals event +func (ms *Messages) SignalEvent(message *proto.Message) { ms.eventManager.signalEvent( message.Type, &proto.View{ @@ -82,6 +85,7 @@ func (ms *Messages) AddMessage(message *proto.Message) { ) } +// Close closes event manager func (ms *Messages) Close() { ms.eventManager.close() } diff --git a/messages/messages_test.go b/messages/messages_test.go index 54b6e85..49448af 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -131,20 +131,39 @@ func TestMessages_AddDuplicates(t *testing.T) { func TestMessages_Prune(t *testing.T) { t.Parallel() - numMessages := 5 - messageType := proto.MessageType_PREPARE + var ( + numMessages = 5 + messageType = proto.MessageType_PREPARE + + height uint64 = 2 + ) + messages := NewMessages() t.Cleanup(func() { messages.Close() }) - views := make([]*proto.View, 0) - for index := uint64(1); index <= 3; index++ { - views = append(views, &proto.View{ - Height: 1, - Round: index, - }) + views := []*proto.View{ + { + Height: height - 1, + Round: 1, + }, + { + Height: height, + Round: 2, + }, + { + Height: height + 1, + Round: 3, + }, + } + + // expected number of message for each view after pruning + expectedNumMessages := []int{ + 0, + numMessages, + numMessages, } // Append random message types @@ -165,19 +184,22 @@ func TestMessages_Prune(t *testing.T) { } // Prune out the messages from this view - messages.PruneByHeight(views[1].Height + 1) - - // Make sure the round 1 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[0], messageType)) - - // Make sure the round 2 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[1], messageType)) - - // Make sure the round 3 messages are pruned out - assert.Equal(t, 0, messages.numMessages(views[2], messageType)) + messages.PruneByHeight(height) + + // check numbers of messages + for idx, expected := range expectedNumMessages { + assert.Equal( + t, + expected, + messages.numMessages( + views[idx], + messageType, + ), + ) + } } -// TestMessages_GetMessage makes sure +// TestMessages_GetValidMessagesMessage_InvalidMessages makes sure // that messages are fetched correctly for the // corresponding message type func TestMessages_GetValidMessagesMessage(t *testing.T) { @@ -188,7 +210,9 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { Height: 1, Round: 0, } - numMessages = 5 + + numMessages = 10 + numValidMessages = 5 ) testTable := []struct { @@ -213,8 +237,14 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { }, } - alwaysInvalidFn := func(_ *proto.Message) bool { - return false + newIsValid := func(numValidMessages int) func(_ *proto.Message) bool { + calls := 0 + + return func(_ *proto.Message) bool { + calls++ + + return calls <= numValidMessages + } } for _, testCase := range testTable { @@ -247,20 +277,23 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { ) // Start fetching messages and making sure they're not cleared - switch testCase.messageType { - case proto.MessageType_PREPREPARE: - messages.GetValidMessages(defaultView, proto.MessageType_PREPREPARE, alwaysInvalidFn) - case proto.MessageType_PREPARE: - messages.GetValidMessages(defaultView, proto.MessageType_PREPARE, alwaysInvalidFn) - case proto.MessageType_COMMIT: - messages.GetValidMessages(defaultView, proto.MessageType_COMMIT, alwaysInvalidFn) - case proto.MessageType_ROUND_CHANGE: - messages.GetValidMessages(defaultView, proto.MessageType_ROUND_CHANGE, alwaysInvalidFn) - } + validMessages := messages.GetValidMessages( + defaultView, + testCase.messageType, + newIsValid(numValidMessages), + ) + + // make sure only valid messages are returned + assert.Len( + t, + validMessages, + numValidMessages, + ) + // make sure invalid messages are pruned assert.Equal( t, - 0, + numMessages-numValidMessages, messages.numMessages(defaultView, testCase.messageType), ) }) @@ -273,42 +306,129 @@ func TestMessages_GetValidMessagesMessage(t *testing.T) { func TestMessages_GetMostRoundChangeMessages(t *testing.T) { t.Parallel() - messages := NewMessages() - defer messages.Close() + tests := []struct { + name string + messages [][]*proto.Message + minRound uint64 + height uint64 + expectedNum int + expectedRound uint64 + }{ + { + name: "should return nil if not found", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, // smaller than minRound + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 0, + }, + { + name: "should return round change messages if messages' round is greater than/equal to minRound", + messages: [][]*proto.Message{ + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 2, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 1, + height: 0, + expectedNum: 1, + expectedRound: 2, + }, + { + name: "should return most round change messages (the round is equals to minRound)", + messages: [][]*proto.Message{ + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 2, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 2, + }, + { + name: "should return most round change messages (the round is bigger than minRound)", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 3, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(1, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 3, + }, + { + name: "should return the first of most round change messages", + messages: [][]*proto.Message{ + generateRandomMessages(3, &proto.View{ + Height: 0, + Round: 1, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 4, + }, proto.MessageType_ROUND_CHANGE), + generateRandomMessages(2, &proto.View{ + Height: 0, + Round: 3, + }, proto.MessageType_ROUND_CHANGE), + }, + minRound: 2, + height: 0, + expectedNum: 2, + expectedRound: 4, + }, + } - mostMessageCount := 3 - mostMessagesRound := uint64(2) + for _, test := range tests { + test := test - // Generate round messages - randomMessages := map[uint64][]*proto.Message{ - 0: generateRandomMessages(mostMessageCount-2, &proto.View{ - Height: 0, - Round: 0, - }, proto.MessageType_ROUND_CHANGE), - 1: generateRandomMessages(mostMessageCount-1, &proto.View{ - Height: 0, - Round: 1, - }, proto.MessageType_ROUND_CHANGE), - mostMessagesRound: generateRandomMessages(mostMessageCount, &proto.View{ - Height: 0, - Round: mostMessagesRound, - }, proto.MessageType_ROUND_CHANGE), - } + t.Run(test.name, func(t *testing.T) { + t.Parallel() - // Add the messages - for _, roundMessages := range randomMessages { - for _, message := range roundMessages { - messages.AddMessage(message) - } - } + messages := NewMessages() + defer messages.Close() - roundChangeMessages := messages.GetMostRoundChangeMessages(0, 0) + // Add the messages + for _, roundMessages := range test.messages { + for _, message := range roundMessages { + messages.AddMessage(message) + } + } - if len(roundChangeMessages) != mostMessageCount { - t.Fatalf("Invalid number of round change messages, %d", len(roundChangeMessages)) - } + roundChangeMessages := messages.GetMostRoundChangeMessages(test.minRound, test.height) - assert.Equal(t, mostMessagesRound, roundChangeMessages[0].View.Round) + if test.expectedNum == 0 { + assert.Nil(t, roundChangeMessages, "should be nil but not nil") + } else { + assert.Len(t, roundChangeMessages, test.expectedNum, "invalid number of round change messages") + } + + for _, msg := range roundChangeMessages { + assert.Equal(t, test.expectedRound, msg.View.Round) + } + }) + } } // TestMessages_EventManager checks that the event manager @@ -341,6 +461,7 @@ func TestMessages_EventManager(t *testing.T) { randomMessages := generateRandomMessages(numMessages, baseView, messageType) for _, message := range randomMessages { messages.AddMessage(message) + messages.SignalEvent(message) } // Wait for the subscription event to happen @@ -352,3 +473,252 @@ func TestMessages_EventManager(t *testing.T) { // Make sure the number of messages is actually accurate assert.Equal(t, numMessages, messages.numMessages(baseView, messageType)) } + +// TestMessages_Unsubscribe checks Messages calls eventManager.cancelSubscription +// in Unsubscribe method +func TestMessages_Unsubscribe(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + numMessages := 10 + messageType := proto.MessageType_PREPARE + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + assert.Equal(t, int64(1), messages.eventManager.numSubscriptions) + + messages.Unsubscribe(subscription.ID) + + assert.Equal(t, int64(0), messages.eventManager.numSubscriptions) +} + +// TestMessages_Unsubscribe checks Messages calls eventManager.close +// in Close method +func TestMessages_Close(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + numMessages := 10 + baseView := &proto.View{ + Height: 0, + Round: 0, + } + + // Create 2 subscriptions + _ = messages.Subscribe(SubscriptionDetails{ + MessageType: proto.MessageType_PREPARE, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + _ = messages.Subscribe(SubscriptionDetails{ + MessageType: proto.MessageType_COMMIT, + View: baseView, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + assert.Equal(t, int64(2), messages.eventManager.numSubscriptions) + + messages.Close() + + assert.Equal(t, int64(0), messages.eventManager.numSubscriptions) +} + +func TestMessages_getProtoMessage(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + var ( + numMessages = 10 + messageType = proto.MessageType_COMMIT + view = &proto.View{ + Height: 0, + Round: 0, + } + ) + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: view, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + defer messages.Unsubscribe(subscription.ID) + + // Push random messages + generatedMessages := generateRandomMessages(numMessages, view, messageType) + messageMap := map[string]*proto.Message{} + + for _, message := range generatedMessages { + messages.AddMessage(message) + messageMap[string(message.From)] = message + } + + // Wait for the subscription event to happen + select { + case <-subscription.SubCh: + case <-time.After(5 * time.Second): + } + + tests := []struct { + name string + view *proto.View + messageType proto.MessageType + expected protoMessages + }{ + { + name: "should return messages for same view and type", + view: view, + messageType: messageType, + expected: messageMap, + }, + { + name: "should return nil for different type", + view: view, + messageType: proto.MessageType_PREPARE, + expected: nil, + }, + { + name: "should return nil for same type and round but different height", + view: &proto.View{ + Height: view.Height + 1, + Round: view.Round, + }, + messageType: messageType, + expected: nil, + }, + { + name: "should return nil for same type and height but different round", + view: &proto.View{ + Height: view.Height, + Round: view.Round + 1, + }, + messageType: messageType, + expected: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + assert.Equal( + t, + test.expected, + messages.getProtoMessages(test.view, test.messageType), + ) + }) + } +} + +func TestMessages_numMessages(t *testing.T) { + t.Parallel() + + messages := NewMessages() + defer messages.Close() + + var ( + numMessages = 10 + messageType = proto.MessageType_COMMIT + view = &proto.View{ + Height: 3, + Round: 5, + } + ) + + // Create the subscription + subscription := messages.Subscribe(SubscriptionDetails{ + MessageType: messageType, + View: view, + HasQuorumFn: func(_ uint64, messages []*proto.Message, _ proto.MessageType) bool { + return len(messages) >= numMessages + }, + }) + + defer messages.Unsubscribe(subscription.ID) + + // Push random messages + for _, message := range generateRandomMessages(numMessages, view, messageType) { + messages.AddMessage(message) + } + + // Wait for the subscription event to happen + select { + case <-subscription.SubCh: + case <-time.After(5 * time.Second): + } + + tests := []struct { + name string + view *proto.View + messageType proto.MessageType + expected int + }{ + { + name: "should return number of messages", + view: view, + messageType: messageType, + expected: numMessages, + }, + { + name: "should return zero if message type is different", + view: view, + messageType: proto.MessageType_PREPARE, + expected: 0, + }, + { + name: "should return zero if height is different", + view: &proto.View{ + Height: 1, + Round: view.Round, + }, + messageType: messageType, + expected: 0, + }, + { + name: "should return zero if round is different", + view: &proto.View{ + Height: view.Height, + Round: 1, + }, + messageType: messageType, + expected: 0, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, test.expected, messages.numMessages(test.view, test.messageType)) + }) + } +}