diff --git a/core/src/dealer/mod.rs b/core/src/dealer/mod.rs index 53cddba0..bca1ec20 100644 --- a/core/src/dealer/mod.rs +++ b/core/src/dealer/mod.rs @@ -1,5 +1,5 @@ mod maps; -mod protocol; +pub mod protocol; use std::iter; use std::pin::Pin; @@ -11,6 +11,7 @@ use std::time::Duration; use futures_core::{Future, Stream}; use futures_util::future::join_all; use futures_util::{SinkExt, StreamExt}; +use thiserror::Error; use tokio::select; use tokio::sync::mpsc::{self, UnboundedReceiver}; use tokio::sync::Semaphore; @@ -21,7 +22,6 @@ use url::Url; use self::maps::*; use self::protocol::*; -pub use self::protocol::{Message, Request}; use crate::socket; use crate::util::{keep_flushing, CancelOnDrop, TimeoutOnDrop}; @@ -29,6 +29,13 @@ type WsMessage = tungstenite::Message; type WsError = tungstenite::Error; type WsResult = Result; +const WEBSOCKET_CLOSE_TIMEOUT: Duration = Duration::from_secs(3); + +const PING_INTERVAL: Duration = Duration::from_secs(30); +const PING_TIMEOUT: Duration = Duration::from_secs(3); + +const RECONNECT_INTERVAL: Duration = Duration::from_secs(10); + pub struct Response { pub success: bool, } @@ -64,8 +71,8 @@ impl Responder { } } - pub fn send(mut self, success: Response) { - self.send_internal(success); + pub fn send(mut self, response: Response) { + self.send_internal(response); self.sent = true; } @@ -105,26 +112,26 @@ where impl RequestHandler for F where - F: (Fn(Request) -> R) + Send + Sync + 'static, + F: (Fn(Request) -> R) + Send + 'static, R: IntoResponse, { - fn handle_request(&self, request: Request, responder: Responder) { + fn handle_request(&self, request: Request, responder: Responder) { self(request).respond(responder); } } -pub trait RequestHandler: Send + Sync + 'static { - fn handle_request(&self, request: Request, responder: Responder); +pub trait RequestHandler: Send + 'static { + fn handle_request(&self, request: Request, responder: Responder); } -type MessageHandler = mpsc::UnboundedSender>; +type MessageHandler = mpsc::UnboundedSender; // TODO: Maybe it's possible to unregister subscription directly when they // are dropped instead of on next failed attempt. -pub struct Subscription(UnboundedReceiver>); +pub struct Subscription(UnboundedReceiver); impl Stream for Subscription { - type Item = Message; + type Item = Message; fn poll_next( mut self: Pin<&mut Self>, @@ -153,25 +160,25 @@ fn split_uri(s: &str) -> Option> { Some(iter::once(scheme).chain(split)) } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Error)] pub enum AddHandlerError { + #[error("There is already a handler for the given uri")] AlreadyHandled, + #[error("The specified uri is invalid")] InvalidUri, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Error)] pub enum SubscriptionError { + #[error("The specified uri is invalid")] InvalidUri, } -fn add_handler( +fn add_handler( map: &mut HandlerMap>, uri: &str, - handler: H, -) -> Result<(), AddHandlerError> -where - H: RequestHandler, -{ + handler: impl RequestHandler, +) -> Result<(), AddHandlerError> { let split = split_uri(uri).ok_or(AddHandlerError::InvalidUri)?; map.insert(split, Box::new(handler)) .map_err(|_| AddHandlerError::AlreadyHandled) @@ -218,7 +225,7 @@ macro_rules! create_dealer { Dealer { shared, - handle: TimeoutOnDrop::new(handle, Duration::from_secs(3)), + handle: TimeoutOnDrop::new(handle, WEBSOCKET_CLOSE_TIMEOUT), } } } @@ -278,7 +285,7 @@ struct DealerShared { } impl DealerShared { - fn dispatch_message(&self, msg: Message) { + fn dispatch_message(&self, msg: Message) { if let Some(split) = split_uri(&msg.uri) { self.message_handlers .lock() @@ -287,11 +294,7 @@ impl DealerShared { } } - fn dispatch_request( - &self, - request: Request, - send_tx: &mpsc::UnboundedSender, - ) { + fn dispatch_request(&self, request: Request, send_tx: &mpsc::UnboundedSender) { // ResponseSender will automatically send "success: false" if it is dropped without an answer. let responder = Responder::new(request.key.clone(), send_tx.clone()); @@ -490,7 +493,7 @@ async fn connect( let ping_task = async { use tokio::time::{interval, sleep}; - let mut timer = interval(Duration::from_secs(30)); + let mut timer = interval(PING_INTERVAL); loop { timer.tick().await; @@ -503,7 +506,7 @@ async fn connect( debug!("Sent ping"); - sleep(Duration::from_secs(3)).await; + sleep(PING_TIMEOUT).await; if !pong_received.load(atomic::Ordering::SeqCst) { // No response @@ -539,7 +542,7 @@ async fn run( Fut: Future + Send + 'static, F: (FnMut() -> Fut) + Send + 'static, { - let init_task = |t| Some(TimeoutOnDrop::new(t, Duration::from_secs(3))); + let init_task = |t| Some(TimeoutOnDrop::new(t, WEBSOCKET_CLOSE_TIMEOUT)); let mut tasks = if let Some((s, r)) = initial_tasks { (init_task(s), init_task(r)) @@ -574,6 +577,7 @@ async fn run( Ok((s, r)) => tasks = (init_task(s), init_task(r)), Err(e) => { warn!("Error while connecting: {}", e); + tokio::time::sleep(RECONNECT_INTERVAL).await; } } } diff --git a/core/src/dealer/protocol.rs b/core/src/dealer/protocol.rs index cb0a1835..9e62a2e5 100644 --- a/core/src/dealer/protocol.rs +++ b/core/src/dealer/protocol.rs @@ -5,15 +5,6 @@ use serde::Deserialize; pub type JsonValue = serde_json::Value; pub type JsonObject = serde_json::Map; -#[derive(Clone, Debug, Deserialize)] -pub struct Request

{ - #[serde(default)] - pub headers: HashMap, - pub message_ident: String, - pub key: String, - pub payload: P, -} - #[derive(Clone, Debug, Deserialize)] pub struct Payload { pub message_id: i32, @@ -22,18 +13,27 @@ pub struct Payload { } #[derive(Clone, Debug, Deserialize)] -pub struct Message

{ +pub struct Request { + #[serde(default)] + pub headers: HashMap, + pub message_ident: String, + pub key: String, + pub payload: Payload, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Message { #[serde(default)] pub headers: HashMap, pub method: Option, #[serde(default)] - pub payloads: Vec

, + pub payloads: Vec, pub uri: String, } #[derive(Clone, Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] -pub enum MessageOrRequest { - Message(Message), - Request(Request), +pub(super) enum MessageOrRequest { + Message(Message), + Request(Request), }