Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] custom packet serializer implementation #421

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions engineio/src/asynchronous/async_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use tokio::{runtime::Handle, sync::Mutex, time::Instant};
use crate::{
asynchronous::{callback::OptionalCallback, transport::AsyncTransportType},
error::Result,
packet::{HandshakePacket, Payload},
packet::{HandshakePacket, PacketSerializer},
Error, Packet, PacketId,
};

Expand All @@ -24,6 +24,7 @@ pub struct Socket {
handle: Handle,
transport: Arc<Mutex<AsyncTransportType>>,
transport_raw: AsyncTransportType,
serializer: Arc<PacketSerializer>,
on_close: OptionalCallback<()>,
on_data: OptionalCallback<Bytes>,
on_error: OptionalCallback<String>,
Expand All @@ -39,6 +40,7 @@ pub struct Socket {
impl Socket {
pub(crate) fn new(
transport: AsyncTransportType,
serializer: Arc<PacketSerializer>,
handshake: HandshakePacket,
on_close: OptionalCallback<()>,
on_data: OptionalCallback<Bytes>,
Expand All @@ -57,6 +59,7 @@ impl Socket {
on_packet,
transport: Arc::new(Mutex::new(transport.clone())),
transport_raw: transport,
serializer,
connected: Arc::new(AtomicBool::default()),
last_ping: Arc::new(Mutex::new(Instant::now())),
last_pong: Arc::new(Mutex::new(Instant::now())),
Expand Down Expand Up @@ -117,9 +120,13 @@ impl Socket {
}

/// Helper method that parses bytes and returns an iterator over the elements.
fn parse_payload(bytes: Bytes) -> impl Stream<Item = Result<Packet>> {
fn parse_payload(
bytes: Bytes,
serializer: Arc<PacketSerializer>,
) -> impl Stream<Item = Result<Packet>> {
try_stream! {
let payload = Payload::try_from(bytes);
// let payload = Payload::try_from(bytes);
let payload = serializer.decode_payload(bytes);

for elem in payload?.into_iter() {
yield elem;
Expand All @@ -131,12 +138,13 @@ impl Socket {
/// underlying transport types.
fn stream(
mut transport: AsyncTransportType,
serialzer: Arc<PacketSerializer>,
) -> Pin<Box<impl Stream<Item = Result<Packet>> + 'static + Send>> {
// map the byte stream of the underlying transport
// to a packet stream
Box::pin(try_stream! {
for await payload in transport.as_pin_box() {
for await packet in Self::parse_payload(payload?) {
for await packet in Self::parse_payload(payload?, serialzer.clone()) {
yield packet?;
}
}
Expand Down Expand Up @@ -172,7 +180,8 @@ impl Socket {
let data: Bytes = if is_binary {
packet.data
} else {
packet.into()
// packet.into()
self.serializer.encode(packet)
};

let lock = self.transport.lock().await;
Expand Down Expand Up @@ -249,7 +258,7 @@ impl Socket {
&'a self,
) -> Pin<Box<dyn Stream<Item = Result<Packet>> + Send + 'a>> {
stream::unfold(
Self::stream(self.transport_raw.clone()),
Self::stream(self.transport_raw.clone(), self.serializer.clone()),
|mut stream| async {
// Wait for the next payload or until we should have received the next ping.
match tokio::time::timeout(
Expand Down
18 changes: 2 additions & 16 deletions engineio/src/asynchronous/async_transports/polling.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use adler32::adler32;
use async_stream::try_stream;
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use bytes::{BufMut, Bytes, BytesMut};
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use http::HeaderMap;
use native_tls::TlsConnector;
Expand Down Expand Up @@ -102,23 +101,10 @@ impl Stream for PollingTransport {
#[async_trait]
impl AsyncTransport for PollingTransport {
async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
let data_to_send = if is_binary_att {
// the binary attachment gets `base64` encoded
let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
packet_bytes.put_u8(b'b');

let encoded_data = general_purpose::STANDARD.encode(data);
packet_bytes.put(encoded_data.as_bytes());

packet_bytes.freeze()
} else {
data
};

let status = self
.client
.post(self.address().await?)
.body(data_to_send)
.body(data)
.send()
.await?
.status()
Expand Down
16 changes: 13 additions & 3 deletions engineio/src/asynchronous/async_transports/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use crate::asynchronous::transport::AsyncTransport;
use crate::error::Result;
use crate::PacketSerializer;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::StreamExt;
Expand All @@ -27,7 +28,11 @@ pub struct WebsocketTransport {

impl WebsocketTransport {
/// Creates a new instance over a request that might hold additional headers and an URL.
pub async fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
pub async fn new(
base_url: Url,
headers: Option<HeaderMap>,
serializer: Arc<PacketSerializer>,
) -> Result<Self> {
let mut url = base_url;
url.query_pairs_mut().append_pair("transport", "websocket");
url.set_scheme("ws").unwrap();
Expand All @@ -41,7 +46,7 @@ impl WebsocketTransport {
let (ws_stream, _) = connect_async(req).await?;
let (sen, rec) = ws_stream.split();

let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
let inner = AsyncWebsocketGeneralTransport::new(sen, rec, serializer).await;
Ok(WebsocketTransport {
inner,
base_url: Arc::new(RwLock::new(url)),
Expand Down Expand Up @@ -118,7 +123,12 @@ mod test {
let url = crate::test::engine_io_server()?.to_string()
+ "engine.io/?EIO="
+ &ENGINE_IO_VERSION.to_string();
WebsocketTransport::new(Url::from_str(&url[..])?, None).await
WebsocketTransport::new(
Url::from_str(&url[..])?,
None,
PacketSerializer::default_arc(),
)
.await
}

#[tokio::test]
Expand Down
24 changes: 16 additions & 8 deletions engineio/src/asynchronous/async_transports/websocket_general.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, str::from_utf8, sync::Arc, task::Poll};

use crate::{error::Result, Error, Packet, PacketId};
use crate::{error::Result, Error, Packet, PacketId, PacketSerializer};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{
ready,
Expand All @@ -22,16 +22,19 @@ type AsyncWebsocketReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStre
pub(crate) struct AsyncWebsocketGeneralTransport {
sender: Arc<Mutex<AsyncWebsocketSender>>,
receiver: Arc<Mutex<AsyncWebsocketReceiver>>,
serializer: Arc<PacketSerializer>,
}

impl AsyncWebsocketGeneralTransport {
pub(crate) async fn new(
sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
serializer: Arc<PacketSerializer>,
) -> Self {
AsyncWebsocketGeneralTransport {
sender: Arc::new(Mutex::new(sender)),
receiver: Arc::new(Mutex::new(receiver)),
serializer,
}
}

Expand All @@ -41,25 +44,30 @@ impl AsyncWebsocketGeneralTransport {
let mut receiver = self.receiver.lock().await;
let mut sender = self.sender.lock().await;

let ping_packet = Packet::new(PacketId::Ping, Bytes::from("probe"));
let ping_packet = self.serializer.encode(ping_packet);

sender
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Ping, Bytes::from("probe")),
))?)))
.send(Message::text(Cow::Borrowed(from_utf8(&ping_packet)?)))
.await?;

let msg = receiver
.next()
.await
.ok_or(Error::IllegalWebsocketUpgrade())??;

if msg.into_data() != Bytes::from(Packet::new(PacketId::Pong, Bytes::from("probe"))) {
let pong_packet = Packet::new(PacketId::Pong, Bytes::from("probe"));
let pong_packet = self.serializer.encode(pong_packet);

if msg.into_data() != pong_packet {
return Err(Error::InvalidPacket());
}

let upgrade_packet = Packet::new(PacketId::Upgrade, Bytes::from(""));
let upgrade_packet = self.serializer.encode(upgrade_packet);

sender
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Upgrade, Bytes::from("")),
))?)))
.send(Message::text(Cow::Borrowed(from_utf8(&upgrade_packet)?)))
.await?;

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use crate::asynchronous::transport::AsyncTransport;
use crate::error::Result;
use crate::PacketSerializer;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::Stream;
Expand Down Expand Up @@ -34,6 +35,7 @@ impl WebsocketSecureTransport {
base_url: Url,
tls_config: Option<TlsConnector>,
headers: Option<HeaderMap>,
serializer: Arc<PacketSerializer>,
) -> Result<Self> {
let mut url = base_url;
url.query_pairs_mut().append_pair("transport", "websocket");
Expand Down Expand Up @@ -61,7 +63,7 @@ impl WebsocketSecureTransport {
.await?;

let (sen, rec) = ws_stream.split();
let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
let inner = AsyncWebsocketGeneralTransport::new(sen, rec, serializer).await;

Ok(WebsocketSecureTransport {
inner,
Expand Down Expand Up @@ -143,6 +145,7 @@ mod test {
Url::from_str(&url[..])?,
Some(crate::test::tls_connector()?),
None,
PacketSerializer::default_arc(),
)
.await
}
Expand Down
28 changes: 23 additions & 5 deletions engineio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::{
asynchronous::{
async_socket::Socket as InnerSocket,
Expand All @@ -7,7 +9,7 @@ use crate::{
},
error::Result,
header::HeaderMap,
packet::HandshakePacket,
packet::{HandshakePacket, PacketSerializer},
Error, Packet, ENGINE_IO_VERSION,
};
use bytes::Bytes;
Expand All @@ -22,6 +24,7 @@ pub struct ClientBuilder {
url: Url,
tls_config: Option<TlsConnector>,
headers: Option<HeaderMap>,
serializer: Arc<PacketSerializer>,
handshake: Option<HandshakePacket>,
on_error: OptionalCallback<String>,
on_open: OptionalCallback<()>,
Expand All @@ -45,6 +48,7 @@ impl ClientBuilder {
headers: None,
tls_config: None,
handshake: None,
serializer: Arc::new(PacketSerializer::default()),
on_close: OptionalCallback::default(),
on_data: OptionalCallback::default(),
on_error: OptionalCallback::default(),
Expand All @@ -53,6 +57,13 @@ impl ClientBuilder {
}
}

/// Specify Packet Serializer
pub fn packet_serializer(mut self, packet_serializer: Arc<PacketSerializer>) -> Self {
self.serializer = packet_serializer;

self
}

/// Specify transport's tls config
pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
self.tls_config = Some(tls_config);
Expand Down Expand Up @@ -127,9 +138,10 @@ impl ClientBuilder {

let mut url = self.url.clone();

let handshake: HandshakePacket =
Packet::try_from(transport.next().await.ok_or(Error::IncompletePacket())??)?
.try_into()?;
let handshake: HandshakePacket = self
.serializer
.decode(transport.next().await.ok_or(Error::IncompletePacket())??)?
.try_into()?;

// update the base_url with the new sid
url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
Expand Down Expand Up @@ -184,6 +196,7 @@ impl ClientBuilder {
// SAFETY: handshake function called previously.
Ok(Client::new(InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand Down Expand Up @@ -214,7 +227,9 @@ impl ClientBuilder {

match self.url.scheme() {
"http" | "ws" => {
let mut transport = WebsocketTransport::new(self.url.clone(), headers).await?;
let mut transport =
WebsocketTransport::new(self.url.clone(), headers, self.serializer.clone())
.await?;

if self.handshake.is_some() {
transport.upgrade().await?;
Expand All @@ -225,6 +240,7 @@ impl ClientBuilder {
// SAFETY: handshake function called previously.
Ok(Client::new(InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand All @@ -238,6 +254,7 @@ impl ClientBuilder {
self.url.clone(),
self.tls_config.clone(),
headers,
self.serializer.clone(),
)
.await?;

Expand All @@ -250,6 +267,7 @@ impl ClientBuilder {
// SAFETY: handshake function called previously.
Ok(Client::new(InnerSocket::new(
transport.into(),
self.serializer,
self.handshake.unwrap(),
self.on_close,
self.on_data,
Expand Down
Loading
Loading