Skip to content

Commit

Permalink
feat(network): evaluate propagate policy for gossip messages (#1647)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ja7ad authored Dec 22, 2024
1 parent abf3836 commit 1f080fe
Show file tree
Hide file tree
Showing 32 changed files with 400 additions and 145 deletions.
1 change: 1 addition & 0 deletions .github/workflows/semantic-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
other
daemon
cmd
firewall
gtk
shell
wallet-cmd
Expand Down
2 changes: 1 addition & 1 deletion config/example_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
[sync.firewall.rate_limit]

# `block_topic` specifies the rate limit for the block topic.
block_topic = 0
block_topic = 1

# `transaction_topic` specifies the rate limit for the transaction topic.
transaction_topic = 5
Expand Down
5 changes: 5 additions & 0 deletions network/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
lp2pdht "github.com/libp2p/go-libp2p-kad-dht"
lp2pcore "github.com/libp2p/go-libp2p/core"
lp2phost "github.com/libp2p/go-libp2p/core/host"
"github.com/multiformats/go-multiaddr"
"github.com/pactus-project/pactus/util/logger"
)

Expand All @@ -19,6 +20,10 @@ type dhtService struct {
func newDHTService(ctx context.Context, host lp2phost.Host, protocolID lp2pcore.ProtocolID,
conf *Config, log *logger.SubLogger,
) *dhtService {
// A dirty code in LibP2P!!!
// prevent apply default bootstrap node of libp2p
lp2pdht.DefaultBootstrapPeers = []multiaddr.Multiaddr{}

mode := lp2pdht.ModeAuto
if conf.IsBootstrapper {
mode = lp2pdht.ModeServer
Expand Down
51 changes: 22 additions & 29 deletions network/gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func newGossipService(ctx context.Context, host lp2phost.Host, conf *Config,
lp2pps.WithMessageSignaturePolicy(lp2pps.StrictNoSign),
lp2pps.WithNoAuthor(),
lp2pps.WithMessageIdFn(MessageIDFunc),
lp2pps.WithSeenMessagesTTL(60 * time.Second),
}

if conf.IsBootstrapper {
Expand Down Expand Up @@ -114,19 +115,13 @@ func (g *gossipService) publish(msg []byte, topic *lp2pps.Topic) error {
}

// JoinTopic joins to the topic with the given name and subscribes to receive topic messages.
func (g *gossipService) JoinTopic(topicID TopicID, shouldPropagate ShouldPropagate) error {
func (g *gossipService) JoinTopic(topicID TopicID, evaluator PropagationEvaluator) error {
switch topicID {
case TopicIDUnspecified:
return InvalidTopicError{TopicID: topicID}

case TopicIDBlock:
if g.topicBlock != nil {
g.logger.Warn("already subscribed to block topic")

return nil
}

topic, err := g.joinTopic(topicID, shouldPropagate)
topic, err := g.joinTopic(topicID, evaluator)
if err != nil {
return err
}
Expand All @@ -135,13 +130,7 @@ func (g *gossipService) JoinTopic(topicID TopicID, shouldPropagate ShouldPropaga
return nil

case TopicIDTransaction:
if g.topicTransaction != nil {
g.logger.Warn("already subscribed to transaction topic")

return nil
}

topic, err := g.joinTopic(topicID, shouldPropagate)
topic, err := g.joinTopic(topicID, evaluator)
if err != nil {
return err
}
Expand All @@ -150,13 +139,7 @@ func (g *gossipService) JoinTopic(topicID TopicID, shouldPropagate ShouldPropaga
return nil

case TopicIDConsensus:
if g.topicConsensus != nil {
g.logger.Warn("already subscribed to consensus topic")

return nil
}

topic, err := g.joinTopic(topicID, shouldPropagate)
topic, err := g.joinTopic(topicID, evaluator)
if err != nil {
return err
}
Expand All @@ -175,7 +158,7 @@ func (g *gossipService) TopicName(topicID TopicID) string {

// joinTopic joins a given topic and registers a validator for it.
// If successful, it returns the topic and subscribes to it.
func (g *gossipService) joinTopic(topicID TopicID, shouldPropagate ShouldPropagate) (*lp2pps.Topic, error) {
func (g *gossipService) joinTopic(topicID TopicID, evaluator PropagationEvaluator) (*lp2pps.Topic, error) {
topicName := g.TopicName(topicID)
topic, err := g.pubsub.Join(topicName)
if err != nil {
Expand All @@ -188,7 +171,7 @@ func (g *gossipService) joinTopic(topicID TopicID, shouldPropagate ShouldPropaga
}

err = g.pubsub.RegisterTopicValidator(
topicName, g.createValidator(topicID, shouldPropagate))
topicName, g.createValidator(topicID, evaluator))
if err != nil {
return nil, LibP2PError{Err: err}
}
Expand Down Expand Up @@ -221,7 +204,7 @@ func (g *gossipService) joinTopic(topicID TopicID, shouldPropagate ShouldPropaga
return topic, nil
}

func (g *gossipService) createValidator(topicID TopicID, shouldPropagate ShouldPropagate,
func (g *gossipService) createValidator(topicID TopicID, evaluator PropagationEvaluator,
) func(context.Context, lp2pcore.PeerID, *lp2pps.Message) lp2pps.ValidationResult {
return func(_ context.Context, peerId lp2pcore.PeerID, lp2pMsg *lp2pps.Message) lp2pps.ValidationResult {
msg := &GossipMessage{
Expand All @@ -235,16 +218,26 @@ func (g *gossipService) createValidator(topicID TopicID, shouldPropagate ShouldP
return lp2pps.ValidationAccept
}

if !shouldPropagate(msg) {
g.logger.Debug("message ignored", "from", peerId, "topic", topicID)
switch evaluator(msg) {
case Drop:
g.logger.Debug("message dropped", "from", peerId, "topic", topicID)

return lp2pps.ValidationIgnore

case DropButConsume:
g.logger.Debug("message dropped but consumed", "from", peerId, "topic", topicID)

// Consume the message first
g.onReceiveMessage(msg)

return lp2pps.ValidationIgnore
}

return lp2pps.ValidationAccept
case Propagate:
return lp2pps.ValidationAccept

default:
panic("unreachable")
}
}
}

Expand Down
64 changes: 57 additions & 7 deletions network/gossip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ import (
"github.com/stretchr/testify/assert"
)

func TestJoinBlockTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

msg := []byte("test-block-topic")

assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDBlock),
NotSubscribedError{
TopicID: TopicIDBlock,
})
assert.NoError(t, net.JoinTopic(TopicIDBlock, alwaysPropagate))
assert.NoError(t, net.gossip.Broadcast(msg, TopicIDBlock))

assert.Error(t, net.JoinTopic(TopicIDBlock, alwaysPropagate), "already joined")
}

func TestJoinConsensusTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

Expand All @@ -21,6 +36,23 @@ func TestJoinConsensusTopic(t *testing.T) {
})
assert.NoError(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate))
assert.NoError(t, net.gossip.Broadcast(msg, TopicIDConsensus))

assert.Error(t, net.JoinTopic(TopicIDConsensus, alwaysPropagate), "already joined")
}

func TestJoinTransactionTopic(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

msg := []byte("test-transaction-topic")

assert.ErrorIs(t, net.gossip.Broadcast(msg, TopicIDTransaction),
NotSubscribedError{
TopicID: TopicIDTransaction,
})
assert.NoError(t, net.JoinTopic(TopicIDTransaction, alwaysPropagate))
assert.NoError(t, net.gossip.Broadcast(msg, TopicIDTransaction))

assert.Error(t, net.JoinTopic(TopicIDTransaction, alwaysPropagate), "already joined")
}

func TestJoinInvalidTopic(t *testing.T) {
Expand Down Expand Up @@ -57,31 +89,49 @@ func TestTopicValidator(t *testing.T) {
net := makeTestNetwork(t, testConfig(), nil)

selfID := net.host.ID()
propagate := false
propagate := Drop
validator := net.gossip.createValidator(TopicIDConsensus,
func(_ *GossipMessage) bool { return propagate })
func(_ *GossipMessage) PropagationPolicy { return propagate })

tests := []struct {
name string
peerID lp2pcore.PeerID
propagate bool
policy PropagationPolicy
expectedResult lp2pps.ValidationResult
}{
{
name: "Message from self",
propagate: false,
policy: Drop,
peerID: selfID,
expectedResult: lp2pps.ValidationAccept,
},
{
name: "Message from self",
policy: DropButConsume,
peerID: selfID,
expectedResult: lp2pps.ValidationAccept,
},
{
name: "Message from self",
policy: propagate,
peerID: selfID,
expectedResult: lp2pps.ValidationAccept,
},
{
name: "Message from other peer, should not propagate",
propagate: false,
policy: Drop,
peerID: "other-peerID",
expectedResult: lp2pps.ValidationIgnore,
},
{
name: "Message from other peer, should not propagate",
policy: DropButConsume,
peerID: "other-peerID",
expectedResult: lp2pps.ValidationIgnore,
},
{
name: "Message from other peer, should propagate",
propagate: true,
policy: Propagate,
peerID: "other-peerID",
expectedResult: lp2pps.ValidationAccept,
},
Expand All @@ -94,7 +144,7 @@ func TestTopicValidator(t *testing.T) {
Data: []byte("some-data"),
},
}
propagate = tt.propagate
propagate = tt.policy
result := validator(context.Background(), tt.peerID, msg)
assert.Equal(t, result, tt.expectedResult)
})
Expand Down
19 changes: 15 additions & 4 deletions network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,20 @@ func (*ProtocolsEvents) Type() EventType {
return EventTypeProtocols
}

// ShouldPropagate determines whether a message should be disregarded:
// it will be neither delivered to the application nor forwarded to the network.
type ShouldPropagate func(*GossipMessage) bool
// PropagationPolicy defines the possible actions for how a gossip message should propagate.
type PropagationPolicy int

const (
// Propagate means the message should be forwarded to other peers in the network.
Propagate = PropagationPolicy(0)
// DropButConsume means the message should not be forwarded but should be processed locally.
DropButConsume = PropagationPolicy(1)
// Drop means the message should be discarded without any further processing.
Drop = PropagationPolicy(2)
)

// PropagationEvaluator is a function that evaluates how a gossip message should propagate.
type PropagationEvaluator func(*GossipMessage) PropagationPolicy

type Network interface {
Start() error
Expand All @@ -132,7 +143,7 @@ type Network interface {
EventChannel() <-chan Event
Broadcast([]byte, TopicID)
SendTo([]byte, lp2pcore.PeerID)
JoinTopic(TopicID, ShouldPropagate) error
JoinTopic(TopicID, PropagationEvaluator) error
CloseConnection(lp2pcore.PeerID)
SelfID() lp2pcore.PeerID
NumConnectedPeers() int
Expand Down
2 changes: 1 addition & 1 deletion network/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (mock *MockNetwork) EventChannel() <-chan Event {
return mock.EventCh
}

func (*MockNetwork) JoinTopic(_ TopicID, _ ShouldPropagate) error {
func (*MockNetwork) JoinTopic(_ TopicID, _ PropagationEvaluator) error {
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ func (n *network) Broadcast(msg []byte, topicID TopicID) {
}()
}

func (n *network) JoinTopic(topicID TopicID, sp ShouldPropagate) error {
return n.gossip.JoinTopic(topicID, sp)
func (n *network) JoinTopic(topicID TopicID, evaluator PropagationEvaluator) error {
return n.gossip.JoinTopic(topicID, evaluator)
}

func (n *network) CloseConnection(pid lp2ppeer.ID) {
Expand Down
26 changes: 22 additions & 4 deletions network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (
"github.com/stretchr/testify/require"
)

func alwaysPropagate(_ *GossipMessage) bool {
return true
func alwaysPropagate(_ *GossipMessage) PropagationPolicy {
return Propagate
}

func makeTestNetwork(t *testing.T, conf *Config, opts []lp2p.Option) *network {
Expand Down Expand Up @@ -416,6 +416,24 @@ func TestNetwork(t *testing.T) {
})
}

func TestHostAddrs(t *testing.T) {
conf := testConfig()
net, err := NewNetwork(conf)
assert.NoError(t, err)

addrs := net.HostAddrs()
assert.Contains(t, addrs, fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", conf.DefaultPort))
assert.Contains(t, addrs, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", conf.DefaultPort))
}

func TestNetworkName(t *testing.T) {
conf := testConfig()
net, err := NewNetwork(conf)
assert.NoError(t, err)

assert.Equal(t, conf.NetworkName, net.Name())
}

func TestConnections(t *testing.T) {
t.Parallel() // run the tests in parallel

Expand Down Expand Up @@ -454,12 +472,12 @@ func TestConnections(t *testing.T) {
no, bootstrapAddr, tt.peerAddr), func(t *testing.T) {
t.Parallel() // run the tests in parallel

testConnection(t, networkP, networkB)
checkConnection(t, networkP, networkB)
})
}
}

func testConnection(t *testing.T, networkP, networkB *network) {
func checkConnection(t *testing.T, networkP, networkB *network) {
t.Helper()

assert.EventuallyWithT(t, func(c *assert.CollectT) {
Expand Down
7 changes: 3 additions & 4 deletions state/mock.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package state

import (
"fmt"
"sync"
"time"

Expand Down Expand Up @@ -82,6 +81,9 @@ func (m *MockState) LastBlockHash() hash.Hash {
}

func (m *MockState) LastBlockTime() time.Time {
m.lk.RLock()
defer m.lk.RUnlock()

if len(m.TestStore.Blocks) > 0 {
return m.TestStore.Blocks[m.TestStore.LastHeight].Header().Time()
}
Expand All @@ -104,9 +106,6 @@ func (m *MockState) CommitBlock(blk *block.Block, cert *certificate.BlockCertifi
m.lk.Lock()
defer m.lk.Unlock()

if cert.Height() != m.TestStore.LastHeight+1 {
return fmt.Errorf("invalid height")
}
m.TestStore.SaveBlock(blk, cert)

return nil
Expand Down
Loading

0 comments on commit 1f080fe

Please sign in to comment.