Skip to content

Commit

Permalink
feat: transport channel pool (#562)
Browse files Browse the repository at this point in the history
1. Set DEFAULT_TTL_MS: to 600 * 1000;
2. don't print whole payload on debug log
3. Use `ProviderRef` in `WASM` environment to avoid ownership issues (may cause undefined here)
4. Use request_internal instead of request to avoid multi-serde from JS and WASM
5. Implementation of round robin pool for data channel
  • Loading branch information
RyanKung authored Mar 28, 2024
1 parent 3ffcdd2 commit 05b1739
Show file tree
Hide file tree
Showing 14 changed files with 321 additions and 56 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rings-node = { path = "crates/node" }
rings-rpc = { path = "crates/rpc", default-features = false }
rings-snark = { path = "crates/snark", default-features = false }
rings-transport = { path = "crates/transport" }
serde-wasm-bindgen = "0.6.1"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-bindgen-macro-support = "0.2.84"
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/consts.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Constant variables.
///
/// default ttl in ms
pub const DEFAULT_TTL_MS: u64 = 300 * 1000;
pub const DEFAULT_TTL_MS: u64 = 600 * 1000;
pub const MAX_TTL_MS: u64 = DEFAULT_TTL_MS * 10;
pub const TS_OFFSET_TOLERANCE_MS: u128 = 3000;
pub const DEFAULT_SESSION_TTL_MS: u64 = 30 * 24 * 3600 * 1000;
Expand Down
17 changes: 11 additions & 6 deletions crates/node/src/backend/snark/browser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::backend::types::snark::SNARKProofTask;
use crate::backend::types::snark::SNARKVerifyTask;
use crate::backend::BackendMessageHandlerDynObj;
use crate::prelude::rings_core::utils::js_value;
use crate::provider::browser::ProviderRef;

/// We need this ref to pass Task ref to js_sys
#[wasm_bindgen]
Expand Down Expand Up @@ -131,15 +132,15 @@ impl SNARKBehaviour {
/// Handle js native message
pub fn handle_snark_task_message(
self,
provider: Provider,
provider: ProviderRef,
ctx: JsValue,
msg: JsValue,
) -> js_sys::Promise {
let ins = self.clone();
future_to_promise(async move {
let ctx = js_value::deserialize::<MessagePayload>(ctx)?;
let msg = js_value::deserialize::<SNARKTaskMessage>(msg)?;
ins.handle_message(provider.into(), &ctx, &msg)
ins.handle_message(provider.inner(), &ctx, &msg)
.await
.map_err(|e| Error::BackendError(e.to_string()))?;
Ok(JsValue::NULL)
Expand Down Expand Up @@ -168,14 +169,18 @@ impl SNARKBehaviour {
/// send proof task to did
pub fn send_proof_task_to(
&self,
provider: Provider,
provider: ProviderRef,
task: SNARKProofTaskRef,
did: String,
) -> js_sys::Promise {
let ins = self.clone();
future_to_promise(async move {
let ret = ins
.send_proof_task(provider.clone().into(), task.as_ref(), Did::from_str(&did)?)
.send_proof_task(
provider.inner().clone(),
task.as_ref(),
Did::from_str(&did)?,
)
.await
.map_err(JsError::from)?;
Ok(JsValue::from(ret))
Expand All @@ -185,14 +190,14 @@ impl SNARKBehaviour {
/// Generate a proof task and send it to did
pub fn gen_and_send_proof_task_to(
&self,
provider: Provider,
provider: ProviderRef,
circuits: Vec<Circuit>,
did: String,
) -> js_sys::Promise {
let ins = self.clone();
future_to_promise(async move {
let ret = ins
.gen_and_send_proof_task(provider.clone().into(), circuits, Did::from_str(&did)?)
.gen_and_send_proof_task(provider.inner().clone(), circuits, Did::from_str(&did)?)
.await
.map_err(JsError::from)?;
Ok(JsValue::from(ret))
Expand Down
18 changes: 6 additions & 12 deletions crates/node/src/backend/snark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,18 +939,12 @@ impl MessageHandler<SNARKTaskMessage> for SNARKBehaviour {
}
.into();
let params = resp.into_send_backend_message_request(verifier)?;
#[cfg(not(target_arch = "wasm32"))]
provider.request(Method::SendBackendMessage, params).await?;
#[cfg(target_arch = "wasm32")]
{
let req = rings_core::utils::js_value::serialize(&params)?;
let promise = provider.request(Method::SendBackendMessage.to_string(), req);
wasm_bindgen_futures::JsFuture::from(promise)
.await
.map_err(|e| {
Error::JsError(format!("Failed send backend message: {:?}", e))
})?;
}
provider
.request_internal(
Method::SendBackendMessage.to_string(),
serde_json::to_value(params)?,
)
.await?;
Ok(())
}
SNARKTask::SNARKVerify(t) => {
Expand Down
24 changes: 24 additions & 0 deletions crates/node/src/provider/browser/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ pub enum AddressType {
Ed25519,
}

/// A wrapper of Arc Ref of Provider
#[derive(Clone)]
#[wasm_export]
pub struct ProviderRef {
inner: Arc<Provider>,
}

impl ProviderRef {
/// get wrapped arc, this is useful for wasm case
pub fn inner(&self) -> Arc<Provider> {
self.inner.clone()
}
}

#[wasm_export]
impl Provider {
/// make provider as an As arc ref
pub fn as_ref(&self) -> ProviderRef {
ProviderRef {
inner: Arc::new(self.clone()),
}
}
}

#[wasm_export]
impl Provider {
/// Create new instance of Provider, return Promise
Expand Down
2 changes: 1 addition & 1 deletion crates/node/src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl Provider {
method: String,
params: serde_json::Value,
) -> Result<serde_json::Value> {
tracing::debug!("request {} params: {:?}", method, params);
tracing::debug!("request {}", method);
self.handler
.handle_request(self.processor.clone(), method, params)
.await
Expand Down
7 changes: 5 additions & 2 deletions crates/node/src/tests/wasm/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ async fn test_send_snark_backend_message() {
console_log!("wait for register");
js_utils::window_sleep(1000).await.unwrap();
console_log!("gen snark task and send");
let promise =
snark_behaviour.gen_and_send_proof_task_to(provider1, circuits, provider2.address());
let promise = snark_behaviour.gen_and_send_proof_task_to(
provider1.as_ref(),
circuits,
provider2.address(),
);
wasm_bindgen_futures::JsFuture::from(promise).await.unwrap();
}
3 changes: 3 additions & 0 deletions crates/transport/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ web-sys = { version = "0.3.64", optional = true, features = [
"RtcSessionDescription",
"RtcSessionDescriptionInit",
"RtcStatsReport",
"Window",
"WorkerGlobalScope",
"ServiceWorkerGlobalScope",
] }

# Common dependencies
Expand Down
67 changes: 50 additions & 17 deletions crates/transport/src/connections/native_webrtc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ use webrtc::peer_connection::RTCPeerConnection;
use crate::callback::InnerTransportCallback;
use crate::connection_ref::ConnectionRef;
use crate::core::callback::BoxedTransportCallback;
use crate::core::pool::MessageSenderPool;
use crate::core::pool::RoundRobin;
use crate::core::pool::RoundRobinPool;
use crate::core::pool::StatusPool;
use crate::core::transport::ConnectionInterface;
use crate::core::transport::TransportInterface;
use crate::core::transport::TransportMessage;
Expand All @@ -31,12 +35,35 @@ use crate::pool::Pool;

const WEBRTC_WAIT_FOR_DATA_CHANNEL_OPEN_TIMEOUT: u8 = 8; // seconds
const WEBRTC_GATHER_TIMEOUT: u8 = 60; // seconds
/// pool size of data channel
const DATA_CHANNEL_POOL_SIZE: u8 = 4;

#[cfg_attr(arch_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(arch_family = "wasm"), async_trait)]
impl MessageSenderPool<Arc<RTCDataChannel>> for RoundRobinPool<Arc<RTCDataChannel>> {
type Message = TransportMessage;
async fn send(&self, msg: TransportMessage) -> Result<()> {
let channel = self.select()?;
let data = bincode::serialize(&msg).map(Bytes::from)?;
if let Err(e) = channel.send(&data).await {
tracing::error!("{:?}, Data size: {:?}", e, data.len());
return Err(e.into());
}
Ok(())
}
}

impl StatusPool<Arc<RTCDataChannel>> for RoundRobinPool<Arc<RTCDataChannel>> {
fn all_ready(&self) -> Result<bool> {
self.all(|c| c.ready_state() == RTCDataChannelState::Open)
}
}

/// A connection that implemented by webrtc-rs library.
/// Used for native environment.
pub struct WebrtcConnection {
webrtc_conn: RTCPeerConnection,
webrtc_data_channel: Arc<RTCDataChannel>,
webrtc_data_channel: Arc<RoundRobinPool<Arc<RTCDataChannel>>>,
webrtc_data_channel_state_notifier: Notifier,
cancel_token: CancellationToken,
}
Expand All @@ -52,7 +79,7 @@ pub struct WebrtcTransport {
impl WebrtcConnection {
fn new(
webrtc_conn: RTCPeerConnection,
webrtc_data_channel: Arc<RTCDataChannel>,
webrtc_data_channel: Arc<RoundRobinPool<Arc<RTCDataChannel>>>,
webrtc_data_channel_state_notifier: Notifier,
) -> Self {
Self {
Expand Down Expand Up @@ -108,12 +135,7 @@ impl ConnectionInterface for WebrtcConnection {

async fn send_message(&self, msg: TransportMessage) -> Result<()> {
self.webrtc_wait_for_data_channel_open().await?;
let data = bincode::serialize(&msg).map(Bytes::from)?;
if let Err(e) = self.webrtc_data_channel.send(&data).await {
tracing::error!("{:?}, Data size: {:?}", e, data.len());
return Err(e.into());
}
Ok(())
self.webrtc_data_channel.send(msg).await
}

async fn get_stats(&self) -> Vec<String> {
Expand Down Expand Up @@ -171,17 +193,15 @@ impl ConnectionInterface for WebrtcConnection {
return Err(Error::DataChannelOpen("Connection unavailable".to_string()));
}

if self.webrtc_data_channel.ready_state() == RTCDataChannelState::Open {
if self.webrtc_data_channel.all_ready()? {
return Ok(());
}

self.webrtc_data_channel_state_notifier
.set_timeout(WEBRTC_WAIT_FOR_DATA_CHANNEL_OPEN_TIMEOUT);
self.webrtc_data_channel_state_notifier.clone().await;

dbg!(self.webrtc_data_channel.ready_state());

if self.webrtc_data_channel.ready_state() == RTCDataChannelState::Open {
if self.webrtc_data_channel.all_ready()? {
return Ok(());
} else {
return Err(Error::DataChannelOpen(format!(
Expand Down Expand Up @@ -239,7 +259,7 @@ impl TransportInterface for WebrtcTransport {
//
// Create webrtc connection
//
let webrtc_conn = webrtc_api.new_peer_connection(webrtc_config).await?;
let webrtc_conn: RTCPeerConnection = webrtc_api.new_peer_connection(webrtc_config).await?;

//
// Set callbacks
Expand All @@ -251,15 +271,23 @@ impl TransportInterface for WebrtcTransport {
webrtc_data_channel_state_notifier.clone(),
));

let channel_pool = Arc::new(RoundRobinPool::default());
let channel_pool_ref = channel_pool.clone();
let data_channel_inner_cb = inner_cb.clone();
webrtc_conn.on_data_channel(Box::new(move |d: Arc<RTCDataChannel>| {
let d_label = d.label();
let d_id = d.id();
tracing::debug!("New DataChannel {d_label} {d_id}");

let channel_pool = channel_pool_ref.clone();
let on_open_inner_cb = data_channel_inner_cb.clone();
d.on_open(Box::new(move || {
Box::pin(async move { on_open_inner_cb.on_data_channel_open().await })
Box::pin(async move {
// check all channels are ready
// trigger on_data_channel_open callback iff all channels ready (open)
if let Ok(true) = channel_pool.all_ready() {
on_open_inner_cb.on_data_channel_open().await
}
})
}));

let on_close_inner_cb = data_channel_inner_cb.clone();
Expand Down Expand Up @@ -300,14 +328,19 @@ impl TransportInterface for WebrtcTransport {
//
// Create data channel
//
let webrtc_data_channel = webrtc_conn.create_data_channel("rings", None).await?;
for i in 0..DATA_CHANNEL_POOL_SIZE {
let ch = webrtc_conn
.create_data_channel(&format!("rings_data_channel_{}", i), None)
.await?;
channel_pool.push(ch)?;
}

//
// Construct the Connection
//
let conn = WebrtcConnection::new(
webrtc_conn,
webrtc_data_channel,
channel_pool,
webrtc_data_channel_state_notifier,
);

Expand Down
Loading

0 comments on commit 05b1739

Please sign in to comment.