Skip to content

Commit

Permalink
tx-submission: verify tx sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Sep 23, 2024
1 parent d900a38 commit 388cc69
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ import Ouroboros.Network.TxSubmission.Inbound.Types (TraceTxLogic,
TraceTxSubmissionInbound)
import Ouroboros.Network.TxSubmission.Outbound (txSubmissionOutbound)
import Test.Ouroboros.Network.Diffusion.Node.NodeKernel
import Test.Ouroboros.Network.TxSubmission.Types (Mempool, Tx, getMempoolReader,
import Test.Ouroboros.Network.TxSubmission.Types (Mempool, Tx (..), getMempoolReader,
getMempoolWriter, txSubmissionCodec2)


Expand Down Expand Up @@ -684,6 +684,7 @@ applications debugTracer txSubmissionInboundTracer txSubmissionInboundDebug node
txChannelsVar
sharedTxStateVar
(getMempoolReader mempool)
getTxSize
them $ \api -> do
let server = txSubmissionInboundV2
txSubmissionInboundTracer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ runTxSubmission tracer tracerTxLogic state txDecisionPolicy = do
txChannelsVar
sharedTxStateVar
(getMempoolReader inboundMempool)
getTxSize
addr $ \api -> do
let server = txSubmissionInboundV2 verboseTracer
(getMempoolWriter inboundMempool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ import Data.Monoid (Sum (..))
import Data.Sequence.Strict qualified as StrictSeq
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable

import NoThunks.Class

import Ouroboros.Network.Protocol.TxSubmission2.Type
import Ouroboros.Network.TxSubmission.Inbound.Decision
(SharedDecisionContext (..), TxDecision (..))
import Ouroboros.Network.TxSubmission.Inbound.Types qualified as TXS
import Ouroboros.Network.TxSubmission.Inbound.Decision qualified as TXS
import Ouroboros.Network.TxSubmission.Inbound.Policy
import Ouroboros.Network.TxSubmission.Inbound.State (PeerTxState (..),
Expand Down Expand Up @@ -305,7 +307,7 @@ mkArbPeerTxState mempoolHasTxFun txIdsInflight unacked txMaskMap =
where
mempoolHasTx = apply mempoolHasTxFun
availableTxIds = Map.fromList
[ (txid, getTxSize tx) | (txid, TxAvailable tx _) <- Map.assocs txMaskMap
[ (txid, getTxAdvSize tx) | (txid, TxAvailable tx _) <- Map.assocs txMaskMap
, not (mempoolHasTx txid)
]
unknownTxs = Set.fromList
Expand All @@ -314,7 +316,7 @@ mkArbPeerTxState mempoolHasTxFun txIdsInflight unacked txMaskMap =
]

requestedTxIdsInflight = fromIntegral txIdsInflight
requestedTxsInflightSize = foldMap getTxSize inflightMap
requestedTxsInflightSize = foldMap getTxAdvSize inflightMap
requestedTxsInflight = Map.keysSet inflightMap

-- exclude `txid`s which are already in the mempool, we never request such
Expand Down Expand Up @@ -758,12 +760,20 @@ instance Arbitrary ArbCollectTxs where

receivedTx <- sublistOf requestedTxIds'
>>= traverse (\txid -> do
-- real size, which might be different from
-- the advertised size
size <- frequency [ (9, pure (availableTxIds Map.! txid))
, (1, chooseEnum (0, maxTxSize))
]

valid <- frequency [(4, pure True), (1, pure False)]
pure $ Tx { getTxId = txid,
getTxSize = availableTxIds Map.! txid,
getTxValid = valid })
pure $ Tx { getTxId = txid,
getTxSize = size,
-- `availableTxIds` contains advertised sizes
getTxAdvSize = availableTxIds Map.! txid,
getTxValid = valid })

pure $ assert (foldMap getTxSize receivedTx <= requestedTxsInflightSize)
pure $ assert (foldMap getTxAdvSize receivedTx <= requestedTxsInflightSize)
$ ArbCollectTxs mempoolHasTxFun
(Set.fromList requestedTxIds')
(Map.fromList [ (getTxId tx, tx) | tx <- receivedTx ])
Expand Down Expand Up @@ -855,24 +865,49 @@ prop_collectTxsImpl (ArbCollectTxs _mempoolHasTxFun txidsRequested txsReceived p
label ("number of txids inflight " ++ labelInt 25 5 (Map.size $ inflightTxs st)) $
label ("number of txids requested " ++ labelInt 25 5 (Set.size txidsRequested)) $
label ("number of txids received " ++ labelInt 10 2 (Map.size txsReceived)) $

-- InboundState invariant
counterexample
( "InboundState invariant violation:\n" ++ show st' ++ "\n"
++ show ps'
)
(sharedTxStateInvariant st')

.&&.
-- `collectTxsImpl` doesn't modify unacknowledged TxId's
counterexample "acknowledged property violation"
( let unacked = toList $ unacknowledgedTxIds ps
unacked' = toList $ unacknowledgedTxIds ps'
in unacked === unacked'
)
label ("hasTxSizeError " ++ show hasTxSizeErr) $

case TXS.collectTxsImpl getTxSize peeraddr txidsRequested txsReceived st of
Right st' | not hasTxSizeErr ->
let ps' = peerTxStates st' Map.! peeraddr in
-- InboundState invariant
counterexample
( "InboundState invariant violation:\n" ++ show st' ++ "\n"
++ show ps'
)
(sharedTxStateInvariant st')

.&&.
-- `collectTxsImpl` doesn't modify unacknowledged TxId's
counterexample "acknowledged property violation"
( let unacked = toList $ unacknowledgedTxIds ps
unacked' = toList $ unacknowledgedTxIds ps'
in unacked === unacked'
)

Right _ ->
counterexample "collectTxsImpl should return Left"
. counterexample (show txsReceived)
$ False
Left _ | not hasTxSizeErr ->
counterexample "collectTxsImpl should return Right" False

Left (TXS.ProtocolErrorTxSizeError as) ->
counterexample (show as)
$ Set.fromList ((\(txid, _, _) -> coerceTxId txid) `map` as)
===
Map.keysSet (Map.filter (\tx -> getTxSize tx /= getTxAdvSize tx) txsReceived)
Left e ->
counterexample ("unexpected error: " ++ show e) False
where
st' = TXS.collectTxsImpl peeraddr txidsRequested txsReceived st
ps' = peerTxStates st' Map.! peeraddr
hasTxSizeErr = any (\tx -> getTxSize tx /= getTxAdvSize tx) txsReceived

-- The `ProtocolErrorTxSizeError` type is an existential type. We know that
-- the type of `txid` is `TxId`, we just don't have evidence for it.
coerceTxId :: Typeable txid => txid -> TxId
coerceTxId txid = case cast txid of
Just a -> a
Nothing -> error "impossible happened! Is the test still using `TxId` for `txid`?"


-- | Verify that `SharedTxState` returned by `collectTxsImpl` if evaluated to
Expand All @@ -882,11 +917,11 @@ prop_collectTxsImpl_nothunks
:: ArbCollectTxs
-> Property
prop_collectTxsImpl_nothunks (ArbCollectTxs _mempoolHasTxFun txidsRequested txsReceived peeraddr _ st) =
case unsafeNoThunks $! st' of
Nothing -> property True
Just ctx -> counterexample (show ctx) False
where
st' = TXS.collectTxsImpl peeraddr txidsRequested txsReceived st
case TXS.collectTxsImpl getTxSize peeraddr txidsRequested txsReceived st of
Right st' -> case unsafeNoThunks $! st' of
Nothing -> property True
Just ctx -> counterexample (show ctx) False
Left _ -> property True


newtype ArbTxDecisionPolicy = ArbTxDecisionPolicy TxDecisionPolicy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ import Text.Printf


data Tx txid = Tx {
getTxId :: !txid,
getTxSize :: !SizeInBytes,
getTxId :: !txid,
getTxSize :: !SizeInBytes,
getTxAdvSize :: !SizeInBytes,
-- | If false this means that when this tx will be submitted to a remote
-- mempool it will not be valid. The outbound mempool might contain
-- invalid tx's in this sense.
getTxValid :: !Bool
getTxValid :: !Bool
}
deriving (Eq, Ord, Show, Generic)

Expand All @@ -69,13 +70,17 @@ instance ShowProxy txid => ShowProxy (Tx txid) where
showProxy _ = "Tx " ++ showProxy (Proxy :: Proxy txid)

instance Arbitrary txid => Arbitrary (Tx txid) where
arbitrary =
arbitrary = do
-- note:
-- generating small tx sizes avoids overflow error when semigroup
-- instance of `SizeInBytes` is used (summing up all inflight tx
-- sizes).
(size, advSize) <- frequency [ (9, (\a -> (a,a)) <$> chooseEnum (0, maxTxSize))
, (1, (,) <$> chooseEnum (0, maxTxSize) <*> chooseEnum (0, maxTxSize))
]
Tx <$> arbitrary
<*> chooseEnum (0, maxTxSize)
-- note:
-- generating small tx sizes avoids overflow error when semigroup
-- instance of `SizeInBytes` is used (summing up all inflight tx
-- sizes).
<*> pure size
<*> pure advSize
<*> frequency [ (3, pure True)
, (1, pure False)
]
Expand Down Expand Up @@ -167,15 +172,17 @@ txSubmissionCodec2 =
codecTxSubmission2 CBOR.encodeInt CBOR.decodeInt
encodeTx decodeTx
where
encodeTx Tx {getTxId, getTxSize, getTxValid} =
CBOR.encodeListLen 3
encodeTx Tx {getTxId, getTxSize, getTxAdvSize, getTxValid} =
CBOR.encodeListLen 4
<> CBOR.encodeInt getTxId
<> CBOR.encodeWord32 (getSizeInBytes getTxSize)
<> CBOR.encodeWord32 (getSizeInBytes getTxAdvSize)
<> CBOR.encodeBool getTxValid

decodeTx = do
_ <- CBOR.decodeListLen
Tx <$> CBOR.decodeInt
<*> (SizeInBytes <$> CBOR.decodeWord32)
<*> (SizeInBytes <$> CBOR.decodeWord32)
<*> CBOR.decodeBool

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Foldable (traverse_
, foldl'
#endif
)
import Data.Typeable (Typeable)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import Data.Maybe (fromMaybe)
Expand Down Expand Up @@ -75,7 +76,7 @@ data PeerTxAPI m txid tx = PeerTxAPI {
-- ^ requested txids
-> Map txid tx
-- ^ received txs
-> m ()
-> m (Maybe TxSubmissionProtocolError)
-- ^ handle received txs
}

Expand All @@ -90,13 +91,16 @@ withPeer
, MonadMVar m
, MonadSTM m
, Ord txid
, Typeable txid
, Show txid
, Ord peeraddr
, Show peeraddr
)
=> Tracer m (TraceTxLogic peeraddr txid tx)
-> TxChannelsVar m peeraddr txid tx
-> SharedTxStateVar m peeraddr txid tx
-> TxSubmissionMempoolReader txid tx idx m
-> (tx -> SizeInBytes)
-> peeraddr
-- ^ new peer
-> (PeerTxAPI m txid tx -> m a)
Expand All @@ -106,6 +110,7 @@ withPeer tracer
channelsVar
sharedStateVar
TxSubmissionMempoolReader { mempoolGetSnapshot }
txSize
peeraddr io =
bracket
(do -- create a communication channel
Expand Down Expand Up @@ -209,9 +214,9 @@ withPeer tracer
-- ^ requested txids
-> Map txid tx
-- ^ received txs
-> m ()
-> m (Maybe TxSubmissionProtocolError)
handleReceivedTxs txids txs =
collectTxs tracer sharedStateVar peeraddr txids txs
collectTxs tracer txSize sharedStateVar peeraddr txids txs


decisionLogicThread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ txSubmissionInboundV2

unless (Map.keysSet received `Set.isSubsetOf` requested) $
throwIO ProtocolErrorTxNotRequested
-- TODO: all sizes of txs which were announced earlier with
-- `MsgReplyTxIds` must be verified.

handleReceivedTxs requested received
k
mbe <- handleReceivedTxs requested received
case mbe of
-- one of `tx`s had a wrong size
Just e -> throwIO e
Nothing -> k
Loading

0 comments on commit 388cc69

Please sign in to comment.