Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
todo: tests
  • Loading branch information
crocodile-dentist committed Sep 5, 2024
1 parent 4b47e98 commit b61e8b1
Show file tree
Hide file tree
Showing 23 changed files with 202 additions and 264 deletions.
3 changes: 1 addition & 2 deletions cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ package ouroboros-network
source-repository-package
type: git
location: https://github.com/input-output-hk/typed-protocols
tag: d0c0668048be5b9878917180d7a0641861216bec
tag: 0d9a6db086d5e7c08204a5b3a236b7126969470c
subdir: typed-protocols
typed-protocols-cborg
allow-newer: typed-protocols:io-classes

13 changes: 12 additions & 1 deletion ouroboros-network-api/src/Ouroboros/Network/SizeInBytes.hs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}

module Ouroboros.Network.SizeInBytes (SizeInBytes (..)) where
module Ouroboros.Network.SizeInBytes (
SizeInBytes (..)
, WithBytes (..)) where

import Control.DeepSeq (NFData (..))
import Data.ByteString.Short (ShortByteString)
import Data.Monoid (Sum (..))
import Data.Word (Word32)
import GHC.Generics
Expand All @@ -14,6 +18,13 @@ import Data.Measure qualified as Measure
import NoThunks.Class (NoThunks (..))
import Quiet (Quiet (..))

data WithBytes a = WithBytes { wbValue :: !a,
unannotate :: ShortByteString }
deriving (Eq, Show)

instance NFData a => NFData (WithBytes a) where
rnf (WithBytes a b) = rnf a `seq` rnf b

newtype SizeInBytes = SizeInBytes { getSizeInBytes :: Word32 }
deriving (Eq, Ord)
deriving Show via Quiet SizeInBytes
Expand Down
84 changes: 49 additions & 35 deletions ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,30 @@ import Ouroboros.Network.Protocol.Limits
import Ouroboros.Network.Util.ShowProxy


driverWithLimits :: forall ps failure bytes m.
driverWithLimits :: forall ps failure bytes m f annotator.
( MonadThrow m
, Show failure
, ShowProxy ps
, forall (st' :: ps). Show (ClientHasAgency st')
, forall (st' :: ps). Show (ServerHasAgency st')
, Monoid bytes
)
=> Tracer m (TraceSendRecv ps)
-> TimeoutFn m
-> Codec ps failure m bytes
-> ProtocolSizeLimits ps bytes
-> ProtocolTimeLimits ps
-> Channel m bytes
-> Driver ps (Maybe bytes) m
=> Tracer m (TraceSendRecv ps)
-> TimeoutFn m
-> Codec ps failure m annotator bytes
-> ProtocolSizeLimits ps bytes
-> ProtocolTimeLimits ps
-> Channel m bytes
-> (forall st. annotator st -> bytes -> SomeMessage st)
-- ^ project out the byte consuming function which produces
-- the decoded message
-> Driver ps (Maybe bytes) m

driverWithLimits tracer timeoutFn
Codec{encode, decode}
ProtocolSizeLimits{sizeLimitForState, dataSize}
ProtocolTimeLimits{timeLimitForState}
channel@Channel{send} =
channel@Channel{send} runAnnotator =
Driver { sendMessage, recvMessage, startDState = Nothing }
where
sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps).
Expand All @@ -84,8 +89,8 @@ driverWithLimits tracer timeoutFn
let sizeLimit = sizeLimitForState stok
timeLimit = fromMaybe (-1) (timeLimitForState stok)
result <- timeoutFn timeLimit $
runDecoderWithLimit sizeLimit dataSize
channel trailing decoder
runDecoderWithLimit sizeLimit dataSize channel
runAnnotator trailing decoder
case result of
Just (Right x@(SomeMessage msg, _trailing')) -> do
traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg))
Expand All @@ -95,17 +100,19 @@ driverWithLimits tracer timeoutFn
Nothing -> throwIO (ExceededTimeLimit stok)

runDecoderWithLimit
:: forall m bytes failure a. Monad m
:: forall m bytes failure annotator st.
(Monad m, Monoid bytes)
=> Word
-- ^ message size limit
-> (bytes -> Word)
-- ^ byte size
-> Channel m bytes
-> (forall st. annotator st -> bytes -> SomeMessage st)
-> Maybe bytes
-> DecodeStep bytes failure m a
-> m (Either (Maybe failure) (a, Maybe bytes))
runDecoderWithLimit limit size Channel{recv} =
go 0
-> DecodeStep bytes failure m (annotator st)
-> m (Either (Maybe failure) (SomeMessage st, Maybe bytes))
runDecoderWithLimit limit size Channel{recv} runAnnotator =
go 0 <*> fromMaybe mempty
where
-- Our strategy here is as follows...
--
Expand All @@ -121,29 +128,30 @@ runDecoderWithLimit limit size Channel{recv} =
-- final chunk, we must check if it consumed too much of the final chunk.
--
go :: Word -- ^ size of consumed input so far
-> Maybe bytes -- ^ any trailing data
-> DecodeStep bytes failure m a
-> m (Either (Maybe failure) (a, Maybe bytes))
-> Maybe bytes -- ^ any queued data
-> bytes -- ^ consumed so far
-> DecodeStep bytes failure m (annotator st)
-> m (Either (Maybe failure) (SomeMessage st, Maybe bytes))

go !sz _ (DecodeDone x trailing)
go !sz _ consumed (DecodeDone annotator trailing)
| let sz' = sz - maybe 0 size trailing
, sz' > limit = return (Left Nothing)
| otherwise = return (Right (x, trailing))
| otherwise = return (Right (runAnnotator annotator consumed, trailing))

go !_ _ (DecodeFail failure) = return (Left (Just failure))
go !_ _ _ (DecodeFail failure) = return (Left (Just failure))

go !sz trailing (DecodePartial k)
go !sz queued consumed (DecodePartial k)
| sz > limit = return (Left Nothing)
| otherwise = case trailing of
| otherwise = case queued of
Nothing -> do mbs <- recv
let !sz' = sz + maybe 0 size mbs
go sz' Nothing =<< k mbs
Just bs -> do let sz' = sz + size bs
go sz' Nothing =<< k (Just bs)
go sz' Nothing (consumed <> fromMaybe mempty mbs) =<< k mbs
Just queued' -> do let sz' = sz + size queued'
go sz' Nothing (consumed <> queued') =<< k queued


runPeerWithLimits
:: forall ps (st :: ps) pr failure bytes m a .
:: forall ps (st :: ps) pr failure bytes annotator m a.
( MonadAsync m
, MonadFork m
, MonadMask m
Expand All @@ -153,17 +161,20 @@ runPeerWithLimits
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
, Show failure
, Monoid bytes
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Codec ps failure m annotator bytes
-> ProtocolSizeLimits ps bytes
-> ProtocolTimeLimits ps
-> Channel m bytes
-> (forall st. annotator st -> bytes -> SomeMessage st)
-> Peer ps pr st m a
-> m (a, Maybe bytes)
runPeerWithLimits tracer codec slimits tlimits channel peer =
runPeerWithLimits tracer codec slimits tlimits channel runAnnotator peer =
withTimeoutSerial $ \timeoutFn ->
let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel
let driver = driverWithLimits tracer timeoutFn codec slimits
tlimits channel runAnnotator
in runPeerWithDriver driver peer (startDState driver)


Expand All @@ -175,7 +186,7 @@ runPeerWithLimits tracer codec slimits tlimits channel peer =
-- 'MonadAsync' constraint.
--
runPipelinedPeerWithLimits
:: forall ps (st :: ps) pr failure bytes m a.
:: forall ps (st :: ps) pr failure bytes annotator m a.
( MonadAsync m
, MonadFork m
, MonadMask m
Expand All @@ -185,15 +196,18 @@ runPipelinedPeerWithLimits
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
, Show failure
, Monoid bytes
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Codec ps failure m annotator bytes
-> ProtocolSizeLimits ps bytes
-> ProtocolTimeLimits ps
-> Channel m bytes
-> (forall st. annotator st -> bytes -> SomeMessage st)
-> PeerPipelined ps pr st m a
-> m (a, Maybe bytes)
runPipelinedPeerWithLimits tracer codec slimits tlimits channel peer =
runPipelinedPeerWithLimits tracer codec slimits tlimits channel runAnnotator peer =
withTimeoutSerial $ \timeoutFn ->
let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel
let driver = driverWithLimits tracer timeoutFn codec
slimits tlimits channel runAnnotator
in runPipelinedPeerWithDriver driver peer (startDState driver)
Loading

0 comments on commit b61e8b1

Please sign in to comment.