Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Annotated decoder #4934

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ package network-mux
package ouroboros-network
flags: +asserts +cddl


source-repository-package
type: git
location: https://github.com/input-output-hk/typed-protocols
tag: 9a0acda4cd34e37b53e53986e7a71a76bba2ca8c
subdir: typed-protocols
typed-protocols-cborg
allow-newer: typed-protocols:io-classes

154 changes: 136 additions & 18 deletions ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
-- @UndecidableInstances@ extensions is required for defining @Show@ instance
-- of @'TraceSendRecv'@.
Expand All @@ -19,10 +18,12 @@ module Ouroboros.Network.Driver.Simple
-- $intro
-- * Normal peers
runPeer
, runAnnotatedPeer
, TraceSendRecv (..)
, DecoderFailure (..)
-- * Pipelined peers
, runPipelinedPeer
, runPipelinedAnnotatedPeer
-- * Connected peers
-- TODO: move these to a test lib
, Role (..)
Expand All @@ -43,6 +44,9 @@ import Ouroboros.Network.Channel
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Tracer (Tracer (..), contramap, traceWith)
import Data.Maybe (fromMaybe)
import Data.Functor.Identity (Identity)
import Control.Monad.Identity (Identity(..))


-- $intro
Expand Down Expand Up @@ -107,18 +111,31 @@ instance Show DecoderFailure where
instance Exception DecoderFailure where


driverSimple :: forall ps failure bytes m.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
driverSimple tracer Codec{encode, decode} channel@Channel{send} =
mkSimpleDriver :: forall ps failure bytes m f annotator.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> (forall a.
Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m (f a)
-> m (Either failure (a, Maybe bytes))
)
-- ^ run incremental decoder against a channel

-> (forall st. annotator st -> f (SomeMessage st))
-- ^ transform annotator to a container holding the decoded
-- message

-> Tracer m (TraceSendRecv ps)
-> Codec' ps failure m annotator bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m

mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} =
Driver { sendMessage, recvMessage, startDState = Nothing }
where
sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps).
Expand All @@ -135,7 +152,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
-> m (SomeMessage st, Maybe bytes)
recvMessage stok trailing = do
decoder <- decode stok
result <- runDecoderWithChannel channel trailing decoder
result <- runDecodeSteps channel trailing (nat <$> decoder)
case result of
Right x@(SomeMessage msg, _trailing') -> do
traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg))
Expand All @@ -144,6 +161,36 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
throwIO (DecoderFailure stok failure)


simpleDriver :: forall ps failure bytes m.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
simpleDriver = mkSimpleDriver runDecoderWithChannel Identity


annotatedSimpleDriver
:: forall ps failure bytes m.
( MonadThrow m
, Monoid bytes
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator


-- | Run a peer with the given channel via the given codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
Expand All @@ -164,7 +211,31 @@ runPeer
runPeer tracer codec channel peer =
runPeerWithDriver driver peer (startDState driver)
where
driver = driverSimple tracer codec channel
driver = simpleDriver tracer codec channel


-- | Run a peer with the given channel via the given annotated codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
--
runAnnotatedPeer
:: forall ps (st :: ps) pr failure bytes m a .
( MonadThrow m
, Monoid bytes
, Show failure
, forall (st' :: ps). Show (ClientHasAgency st')
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> Peer ps pr st m a
-> m (a, Maybe bytes)
runAnnotatedPeer tracer codec channel peer =
runPeerWithDriver driver peer (startDState driver)
where
driver = annotatedSimpleDriver tracer codec channel


-- | Run a pipelined peer with the given channel via the given codec.
Expand All @@ -191,7 +262,35 @@ runPipelinedPeer
runPipelinedPeer tracer codec channel peer =
runPipelinedPeerWithDriver driver peer (startDState driver)
where
driver = driverSimple tracer codec channel
driver = simpleDriver tracer codec channel


-- | Run a pipelined peer with the given channel via the given annotated codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
--
-- Unlike normal peers, running pipelined peers rely on concurrency, hence the
-- 'MonadAsync' constraint.
--
runPipelinedAnnotatedPeer
:: forall ps (st :: ps) pr failure bytes m a.
( MonadAsync m
, MonadThrow m
, Monoid bytes
, Show failure
, forall (st' :: ps). Show (ClientHasAgency st')
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> PeerPipelined ps pr st m a
-> m (a, Maybe bytes)
runPipelinedAnnotatedPeer tracer codec channel peer =
runPipelinedPeerWithDriver driver peer (startDState driver)
where
driver = annotatedSimpleDriver tracer codec channel


--
Expand All @@ -204,17 +303,36 @@ runPipelinedPeer tracer codec channel peer =
runDecoderWithChannel :: Monad m
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m a
-> DecodeStep bytes failure m (Identity a)
-> m (Either failure (a, Maybe bytes))

runDecoderWithChannel Channel{recv} = go
where
go _ (DecodeDone x trailing) = return (Right (x, trailing))
go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing))
go _ (DecodeFail failure) = return (Left failure)
go Nothing (DecodePartial k) = recv >>= k >>= go Nothing
go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing


runAnnotatedDecoderWithChannel
:: forall m bytes failure a.
( Monad m
, Monoid bytes
)
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m (bytes -> a)
-> m (Either failure (a, Maybe bytes))

runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0
where
go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes))
go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing)
go _bytes _ (DecodeFail failure) = return (Left failure)
go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing
go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing


data Role = Client | Server

-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,30 @@ timeLimitsTxSubmission2 = ProtocolTimeLimits stateToLimit


codecTxSubmission2
:: forall txid tx m.
:: forall txid tx annotator m.
MonadST m
=> (txid -> CBOR.Encoding)
-> (forall s . CBOR.Decoder s txid)
-> (tx -> CBOR.Encoding)
-> (forall s . CBOR.Decoder s tx)
-> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString
-- the codec is polymorphic in annotator. The primary use case is an
-- `Identity` functor or `Annotator LBS.ByteString`.
-> (forall st. SomeMessage st -> annotator st)
-> Codec' (TxSubmission2 txid tx) CBOR.DeserialiseFailure m annotator ByteString
codecTxSubmission2 encodeTxId decodeTxId
encodeTx decodeTx =
encodeTx decodeTx
annotate =
mkCodecCborLazyBS
(encodeTxSubmission2 encodeTxId encodeTx)
decode
where
decode :: forall (pr :: PeerRole) (st :: TxSubmission2 txid tx).
PeerHasAgency pr st
-> forall s. CBOR.Decoder s (SomeMessage st)
-> forall s. CBOR.Decoder s (annotator st)
decode stok = do
len <- CBOR.decodeListLen
key <- CBOR.decodeWord
decodeTxSubmission2 decodeTxId decodeTx stok len key
decodeTxSubmission2 decodeTxId decodeTx annotate stok len key

encodeTxSubmission2
:: forall txid tx.
Expand Down Expand Up @@ -149,30 +153,31 @@ encodeTxSubmission2 encodeTxId encodeTx = encode


decodeTxSubmission2
:: forall txid tx.
:: forall txid tx annotator.
(forall s . CBOR.Decoder s txid)
-> (forall s . CBOR.Decoder s tx)
-> (forall st. SomeMessage st -> annotator st)
-> (forall (pr :: PeerRole) (st :: TxSubmission2 txid tx) s.
PeerHasAgency pr st
-> Int
-> Word
-> CBOR.Decoder s (SomeMessage st))
decodeTxSubmission2 decodeTxId decodeTx = decode
-> CBOR.Decoder s (annotator st))
decodeTxSubmission2 decodeTxId decodeTx annotate = decode
where
decode :: forall (pr :: PeerRole) s (st :: TxSubmission2 txid tx).
PeerHasAgency pr st
-> Int
-> Word
-> CBOR.Decoder s (SomeMessage st)
-> CBOR.Decoder s (annotator st)
decode stok len key = do
case (stok, len, key) of
(ClientAgency TokInit, 1, 6) ->
return (SomeMessage MsgInit)
return (annotate $ SomeMessage MsgInit)
(ServerAgency TokIdle, 4, 0) -> do
blocking <- CBOR.decodeBool
ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16
reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16
return $!
return $! annotate $
if blocking
then SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo)
else SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo)
Expand All @@ -187,11 +192,11 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
return (txid, SizeInBytes sz))
case (b, txids) of
(TokBlocking, t:ts) ->
return $
return $ annotate $
SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts)))

(TokNonBlocking, ts) ->
return $
return $ annotate $
SomeMessage (MsgReplyTxIds (NonBlockingReply ts))

(TokBlocking, []) ->
Expand All @@ -201,15 +206,26 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
(ServerAgency TokIdle, 2, 2) -> do
CBOR.decodeListLenIndef
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId
return (SomeMessage (MsgRequestTxs txids))
return (annotate $ SomeMessage (MsgRequestTxs txids))

(ClientAgency TokTxs, 2, 3) -> do
CBOR.decodeListLenIndef
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx
return (SomeMessage (MsgReplyTxs txids))
-- ^ TODO: `txids -> txs` :grin:

-- TODO: here we have access to bytes from which the message was decoded.
-- we can use `Codec.CBOR.Decoding.decodeWithByteSpan`
-- around each `tx` and wrap each `tx` in `WithBytes`.
--
-- `decodeTxSubmission2` can be polymorphic by adding an
-- extra argument of type
-- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a`
-- this way we could wrap `tx` in `WithBytes` or just
-- return `tx`.
return (annotate $ SomeMessage (MsgReplyTxs txids))

(ClientAgency (TokTxIds TokBlocking), 1, 4) ->
return (SomeMessage MsgDone)
return (annotate $ SomeMessage MsgDone)

--
-- failures per protocol state
Expand Down
Loading