diff --git a/Cargo.lock b/Cargo.lock index 6c0a6fd2..1f97d578 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -918,6 +918,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "input_buffer" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" +dependencies = [ + "bytes", +] + [[package]] name = "instant" version = "0.1.9" @@ -1229,6 +1238,7 @@ dependencies = [ "thiserror", "tokio", "tokio-stream", + "tokio-tungstenite", "tokio-util", "url", "uuid", @@ -1911,6 +1921,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + [[package]] name = "rodio" version = "0.14.0" @@ -1945,6 +1970,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustls" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" +dependencies = [ + "base64", + "log", + "ring", + "sct", + "webpki", +] + [[package]] name = "ryu" version = "1.0.5" @@ -1966,6 +2004,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "sdl2" version = "0.34.5" @@ -2103,6 +2151,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "stdweb" version = "0.1.3" @@ -2275,6 +2329,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + [[package]] name = "tokio-stream" version = "0.1.5" @@ -2286,6 +2351,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e96bb520beab540ab664bd5a9cfeaa1fcd846fa68c830b42e2c8963071251d2" +dependencies = [ + "futures-util", + "log", + "pin-project", + "rustls", + "tokio", + "tokio-rustls", + "tungstenite", + "webpki", + "webpki-roots", +] + [[package]] name = "tokio-util" version = "0.6.6" @@ -2341,6 +2423,29 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "tungstenite" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "input_buffer", + "log", + "rand", + "rustls", + "sha-1", + "thiserror", + "url", + "utf-8", + "webpki", + "webpki-roots", +] + [[package]] name = "typenum" version = "1.13.0" 
@@ -2389,6 +2494,12 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
 
+[[package]]
+name = "untrusted"
+version = "0.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
+
 [[package]]
 name = "url"
 version = "2.2.2"
@@ -2401,6 +2512,12 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
 [[package]]
 name = "uuid"
 version = "0.8.2"
@@ -2561,6 +2678,25 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "webpki"
+version = "0.21.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
+dependencies = [
+ "ring",
+ "untrusted",
+]
+
+[[package]]
+name = "webpki-roots"
+version = "0.21.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940"
+dependencies = [
+ "webpki",
+]
+
 [[package]]
 name = "winapi"
 version = "0.3.9"
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 80db5687..8ed21273 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -39,8 +39,9 @@ serde_json = "1.0"
 sha-1 = "0.9"
 shannon = "0.2.0"
 thiserror = "1.0.7"
-tokio = { version = "1.0", features = ["io-util", "net", "rt", "sync"] }
+tokio = { version = "1.5", features = ["io-util", "macros", "net", "rt", "time", "sync"] }
 tokio-stream = "0.1.1"
+tokio-tungstenite = { version = "0.14", default-features = false, features = ["rustls-tls"] }
 tokio-util = { version = "0.6", features = ["codec"] }
 url = "2.1"
 uuid = { version = "0.8", default-features = false, features = ["v4"] }
diff --git a/core/src/apresolve.rs b/core/src/apresolve.rs
index b11e275f..8dced22d 100644
--- a/core/src/apresolve.rs
+++ b/core/src/apresolve.rs
@@ -1,12 +1,12 @@
 use std::error::Error;
 
 use hyper::client::HttpConnector;
-use hyper::{Body, Client, Method, Request, Uri};
+use hyper::{Body, Client, Method, Request};
 use hyper_proxy::{Intercept, Proxy, ProxyConnector};
 use serde::Deserialize;
 use url::Url;
 
-use super::AP_FALLBACK;
+use super::ap_fallback;
 
 const APRESOLVE_ENDPOINT: &str = "http://apresolve.spotify.com:80";
 
@@ -18,7 +18,7 @@ struct ApResolveData {
 async fn try_apresolve(
     proxy: Option<&Url>,
     ap_port: Option<u16>,
-) -> Result<String, Box<dyn Error>> {
+) -> Result<(String, u16), Box<dyn Error>> {
     let port = ap_port.unwrap_or(443);
 
     let mut req = Request::new(Body::empty());
@@ -43,27 +43,29 @@ async fn try_apresolve(
     let body = hyper::body::to_bytes(response.into_body()).await?;
     let data: ApResolveData = serde_json::from_slice(body.as_ref())?;
 
+    let mut aps = data.ap_list.into_iter().filter_map(|ap| {
+        let mut split = ap.rsplitn(2, ':');
+        let port = split
+            .next()
+            .expect("rsplitn should not return empty iterator");
+        let host = split.next()?.to_owned();
+        let port: u16 = port.parse().ok()?;
+        Some((host, port))
+    });
     let ap = if ap_port.is_some() || proxy.is_some() {
-        data.ap_list.into_iter().find_map(|ap| {
-            if ap.parse::<Uri>().ok()?.port()? == port {
-                Some(ap)
-            } else {
-                None
-            }
-        })
+        aps.find(|(_, p)| *p == port)
     } else {
-        data.ap_list.into_iter().next()
+        aps.next()
     }
-    .ok_or("empty AP List")?;
+    .ok_or("no valid AP in list")?;
 
     Ok(ap)
 }
 
-pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> String {
+pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> (String, u16) {
     try_apresolve(proxy, ap_port).await.unwrap_or_else(|e| {
-        warn!("Failed to resolve Access Point: {}", e);
-        warn!("Using fallback \"{}\"", AP_FALLBACK);
-        AP_FALLBACK.into()
+        warn!("Failed to resolve Access Point: {}, using fallback.", e);
+        ap_fallback()
     })
 }
diff --git a/core/src/connection/mod.rs b/core/src/connection/mod.rs
index 58d3e83a..bacdc653 100644
--- a/core/src/connection/mod.rs
+++ b/core/src/connection/mod.rs
@@ -5,7 +5,6 @@ pub use self::codec::ApCodec;
 pub use self::handshake::handshake;
 
 use std::io::{self, ErrorKind};
-use std::net::ToSocketAddrs;
 
 use futures_util::{SinkExt, StreamExt};
 use protobuf::{self, Message, ProtobufError};
@@ -16,7 +15,6 @@ use url::Url;
 
 use crate::authentication::Credentials;
 use crate::protocol::keyexchange::{APLoginFailed, ErrorCode};
-use crate::proxytunnel;
 use crate::version;
 
 pub type Transport = Framed<TcpStream, ApCodec>;
@@ -58,50 +56,8 @@ impl From<APLoginFailed> for AuthenticationError {
     }
 }
 
-pub async fn connect(addr: String, proxy: Option<&Url>) -> io::Result<Transport> {
-    let socket = if let Some(proxy_url) = proxy {
-        info!("Using proxy \"{}\"", proxy_url);
-
-        let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| {
-            addrs.into_iter().next().ok_or_else(|| {
-                io::Error::new(
-                    io::ErrorKind::NotFound,
-                    "Can't resolve proxy server address",
-                )
-            })
-        })?;
-        let socket = TcpStream::connect(&socket_addr).await?;
-
-        let uri = addr.parse::<Uri>().map_err(|_| {
-            io::Error::new(
-                io::ErrorKind::InvalidData,
-                "Can't parse access point address",
-            )
-        })?;
-        let host = uri.host().ok_or_else(|| {
-            io::Error::new(
-                io::ErrorKind::InvalidInput,
-                "The access point address contains no hostname",
-            )
-        })?;
-        let port = uri.port().ok_or_else(|| {
-            io::Error::new(
-                io::ErrorKind::InvalidInput,
-                "The access point address contains no port",
-            )
-        })?;
-
-        proxytunnel::proxy_connect(socket, host, port.as_str()).await?
-    } else {
-        let socket_addr = addr.to_socket_addrs()?.next().ok_or_else(|| {
-            io::Error::new(
-                io::ErrorKind::NotFound,
-                "Can't resolve access point address",
-            )
-        })?;
-
-        TcpStream::connect(&socket_addr).await?
-    };
+pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result<Transport> {
+    let socket = crate::socket::connect(host, port, proxy).await?;
 
     handshake(socket).await
 }
diff --git a/core/src/dealer/maps.rs b/core/src/dealer/maps.rs
new file mode 100644
index 00000000..38916e40
--- /dev/null
+++ b/core/src/dealer/maps.rs
@@ -0,0 +1,117 @@
+use std::collections::HashMap;
+
+#[derive(Debug)]
+pub struct AlreadyHandledError(());
+
+pub enum HandlerMap<T> {
+    Leaf(T),
+    Branch(HashMap<String, HandlerMap<T>>),
+}
+
+impl<T> Default for HandlerMap<T> {
+    fn default() -> Self {
+        Self::Branch(HashMap::new())
+    }
+}
+
+impl<T> HandlerMap<T> {
+    pub fn insert<'a>(
+        &mut self,
+        mut path: impl Iterator<Item = &'a str>,
+        handler: T,
+    ) -> Result<(), AlreadyHandledError> {
+        match self {
+            Self::Leaf(_) => Err(AlreadyHandledError(())),
+            Self::Branch(children) => {
+                if let Some(component) = path.next() {
+                    let node = children.entry(component.to_owned()).or_default();
+                    node.insert(path, handler)
+                } else if children.is_empty() {
+                    *self = Self::Leaf(handler);
+                    Ok(())
+                } else {
+                    Err(AlreadyHandledError(()))
+                }
+            }
+        }
+    }
+
+    pub fn get<'a>(&self, mut path: impl Iterator<Item = &'a str>) -> Option<&T> {
+        match self {
+            Self::Leaf(t) => Some(t),
+            Self::Branch(m) => {
+                let component = path.next()?;
+                m.get(component)?.get(path)
+            }
+        }
+    }
+
+    pub fn remove<'a>(&mut self, mut path: impl Iterator<Item = &'a str>) -> Option<T> {
+        match self {
+            Self::Leaf(_) => match std::mem::take(self) {
+                Self::Leaf(t) => Some(t),
+                _ => unreachable!(),
+            },
+            Self::Branch(map) => {
+                let component = path.next()?;
+                let next = map.get_mut(component)?;
+                let result = next.remove(path);
+                match &*next {
+                    Self::Branch(b) if b.is_empty() => {
+                        map.remove(component);
+                    }
+                    _ => (),
+                }
+                result
+            }
+        }
+    }
+}
+
+pub struct SubscriberMap<T> {
+    subscribed: Vec<T>,
+    children: HashMap<String, SubscriberMap<T>>,
+}
+
+impl<T> Default for SubscriberMap<T> {
+    fn default() -> Self {
+        Self {
+            subscribed: Vec::new(),
+            children: HashMap::new(),
+        }
+    }
+}
+
+impl<T> SubscriberMap<T> {
+    pub fn insert<'a>(&mut self, mut path: impl Iterator<Item = &'a str>, handler: T) {
+        if let Some(component) = path.next() {
+            self.children
+                .entry(component.to_owned())
+                .or_default()
+                .insert(path, handler);
+        } else {
+            self.subscribed.push(handler);
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.children.is_empty() && self.subscribed.is_empty()
+    }
+
+    pub fn retain<'a>(
+        &mut self,
+        mut path: impl Iterator<Item = &'a str>,
+        fun: &mut impl FnMut(&T) -> bool,
+    ) {
+        self.subscribed.retain(|x| fun(x));
+
+        if let Some(next) = path.next() {
+            if let Some(y) = self.children.get_mut(next) {
+                y.retain(path, fun);
+                if y.is_empty() {
+                    self.children.remove(next);
+                }
+            }
+        }
+    }
+}
diff --git a/core/src/dealer/mod.rs b/core/src/dealer/mod.rs
new file mode 100644
index 00000000..53cddba0
--- /dev/null
+++ b/core/src/dealer/mod.rs
@@ -0,0 +1,586 @@
+mod maps;
+mod protocol;
+
+use std::iter;
+use std::pin::Pin;
+use std::sync::atomic::AtomicBool;
+use std::sync::{atomic, Arc, Mutex};
+use std::task::Poll;
+use std::time::Duration;
+
+use futures_core::{Future, Stream};
+use futures_util::future::join_all;
+use futures_util::{SinkExt, StreamExt};
+use tokio::select;
+use tokio::sync::mpsc::{self, UnboundedReceiver};
+use tokio::sync::Semaphore;
+use tokio::task::JoinHandle;
+use tokio_tungstenite::tungstenite;
+use tungstenite::error::UrlError;
+use url::Url;
+
+use self::maps::*;
+use self::protocol::*;
+pub use self::protocol::{Message, Request};
+use crate::socket;
+use crate::util::{keep_flushing, CancelOnDrop, TimeoutOnDrop};
+
+type WsMessage = tungstenite::Message;
+type WsError = tungstenite::Error;
+type WsResult<T> = Result<T, WsError>;
+
+pub struct Response {
+    pub success: bool,
+}
+
+pub struct Responder {
+    key: String,
+    tx: mpsc::UnboundedSender<WsMessage>,
+    sent: bool,
+}
+
+impl Responder {
+    fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
+        Self {
+            key,
+            tx,
+            sent: false,
+        }
+    }
+
+    // Should only be called once
+    fn send_internal(&mut self, response: Response) {
+        let response = serde_json::json!({
+            "type": "reply",
+            "key": &self.key,
+            "payload": {
+                "success": response.success,
+            }
+        })
+        .to_string();
+
+        if let Err(e) = self.tx.send(WsMessage::Text(response)) {
+            warn!("Wasn't able to reply to dealer request: {}", e);
+        }
+    }
+
+    pub fn send(mut self, success: Response) {
+        self.send_internal(success);
+        self.sent = true;
+    }
+
+    pub fn force_unanswered(mut self) {
+        self.sent = true;
+    }
+}
+
+impl Drop for Responder {
+    fn drop(&mut self) {
+        if !self.sent {
+            self.send_internal(Response { success: false });
+        }
+    }
+}
+
+pub trait IntoResponse {
+    fn respond(self, responder: Responder);
+}
+
+impl IntoResponse for Response {
+    fn respond(self, responder: Responder) {
+        responder.send(self)
+    }
+}
+
+impl<F> IntoResponse for F
+where
+    F: Future<Output = Response> + Send + 'static,
+{
+    fn respond(self, responder: Responder) {
+        tokio::spawn(async move {
+            responder.send(self.await);
+        });
+    }
+}
+
+impl<F, R> RequestHandler for F
+where
+    F: (Fn(Request) -> R) + Send + Sync + 'static,
+    R: IntoResponse,
+{
+    fn handle_request(&self, request: Request, responder: Responder) {
+        self(request).respond(responder);
+    }
+}
+
+pub trait RequestHandler: Send + Sync + 'static {
+    fn handle_request(&self, request: Request, responder: Responder);
+}
+
+type MessageHandler = mpsc::UnboundedSender<Message>;
+
+// TODO: Maybe it's possible to unregister subscriptions directly when they
+// are dropped instead of on the next failed send attempt.
+pub struct Subscription(UnboundedReceiver<Message>);
+
+impl Stream for Subscription {
+    type Item = Message;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.0.poll_recv(cx)
+    }
+}
+
+fn split_uri(s: &str) -> Option<impl Iterator<Item = &str>> {
+    let (scheme, sep, rest) = if let Some(rest) = s.strip_prefix("hm://") {
+        ("hm", '/', rest)
+    } else if let Some(rest) = s.strip_prefix("spotify:") {
+        ("spotify", ':', rest)
+    } else {
+        return None;
+    };
+
+    let rest = rest.trim_end_matches(sep);
+    let mut split = rest.split(sep);
+
+    if rest.is_empty() {
+        assert_eq!(split.next(), Some(""));
+    }
+
+    Some(iter::once(scheme).chain(split))
+}
+
+#[derive(Debug, Clone)]
+pub enum AddHandlerError {
+    AlreadyHandled,
+    InvalidUri,
+}
+
+#[derive(Debug, Clone)]
+pub enum SubscriptionError {
+    InvalidUri,
+}
+
+fn add_handler<H>(
+    map: &mut HandlerMap<Box<dyn RequestHandler>>,
+    uri: &str,
+    handler: H,
+) -> Result<(), AddHandlerError>
+where
+    H: RequestHandler,
+{
+    let split = split_uri(uri).ok_or(AddHandlerError::InvalidUri)?;
+    map.insert(split, Box::new(handler))
+        .map_err(|_| AddHandlerError::AlreadyHandled)
+}
+
+fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
+    map.remove(split_uri(uri)?)
+}
+
+fn subscribe(
+    map: &mut SubscriberMap<MessageHandler>,
+    uris: &[&str],
+) -> Result<Subscription, SubscriptionError> {
+    let (tx, rx) = mpsc::unbounded_channel();
+
+    for &uri in uris {
+        let split = split_uri(uri).ok_or(SubscriptionError::InvalidUri)?;
+        map.insert(split, tx.clone());
+    }
+
+    Ok(Subscription(rx))
+}
+
+#[derive(Default)]
+pub struct Builder {
+    message_handlers: SubscriberMap<MessageHandler>,
+    request_handlers: HandlerMap<Box<dyn RequestHandler>>,
+}
+
+macro_rules! create_dealer {
+    ($builder:expr, $shared:ident -> $body:expr) => {
+        match $builder {
+            builder => {
+                let shared = Arc::new(DealerShared {
+                    message_handlers: Mutex::new(builder.message_handlers),
+                    request_handlers: Mutex::new(builder.request_handlers),
+                    notify_drop: Semaphore::new(0),
+                });
+
+                let handle = {
+                    let $shared = Arc::clone(&shared);
+                    tokio::spawn($body)
+                };
+
+                Dealer {
+                    shared,
+                    handle: TimeoutOnDrop::new(handle, Duration::from_secs(3)),
+                }
+            }
+        }
+    };
+}
+
+impl Builder {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add_handler(
+        &mut self,
+        uri: &str,
+        handler: impl RequestHandler,
+    ) -> Result<(), AddHandlerError> {
+        add_handler(&mut self.request_handlers, uri, handler)
+    }
+
+    pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
+        subscribe(&mut self.message_handlers, uris)
+    }
+
+    pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
+    where
+        Fut: Future<Output = Url> + Send + 'static,
+        F: (FnMut() -> Fut) + Send + 'static,
+    {
+        create_dealer!(self, shared -> run(shared, None, get_url, proxy))
+    }
+
+    pub async fn launch<Fut, F>(self, mut get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
+    where
+        Fut: Future<Output = Url> + Send + 'static,
+        F: (FnMut() -> Fut) + Send + 'static,
+    {
+        let dealer = create_dealer!(self, shared -> {
+            // Try to connect.
+            let url = get_url().await;
+            let tasks = connect(&url, proxy.as_ref(), &shared).await?;
+
+            // If a connection is established, continue in a background task.
+            run(shared, Some(tasks), get_url, proxy)
+        });
+
+        Ok(dealer)
+    }
+}
+
+struct DealerShared {
+    message_handlers: Mutex<SubscriberMap<MessageHandler>>,
+    request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
+
+    // Semaphore with 0 permits. By closing this semaphore, we indicate
+    // that the actual Dealer struct has been dropped.
+    notify_drop: Semaphore,
+}
+
+impl DealerShared {
+    fn dispatch_message(&self, msg: Message) {
+        if let Some(split) = split_uri(&msg.uri) {
+            self.message_handlers
+                .lock()
+                .unwrap()
+                .retain(split, &mut |tx| tx.send(msg.clone()).is_ok());
+        }
+    }
+
+    fn dispatch_request(
+        &self,
+        request: Request,
+        send_tx: &mpsc::UnboundedSender<WsMessage>,
+    ) {
+        // The Responder will automatically send "success: false" if it is dropped without an answer.
+        let responder = Responder::new(request.key.clone(), send_tx.clone());
+
+        let split = if let Some(split) = split_uri(&request.message_ident) {
+            split
+        } else {
+            warn!(
+                "Dealer request with invalid message_ident: {}",
+                &request.message_ident
+            );
+            return;
+        };
+
+        {
+            let handler_map = self.request_handlers.lock().unwrap();
+
+            if let Some(handler) = handler_map.get(split) {
+                handler.handle_request(request, responder);
+                return;
+            }
+        }
+
+        warn!("No handler for message_ident: {}", &request.message_ident);
+    }
+
+    fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
+        match m {
+            MessageOrRequest::Message(m) => self.dispatch_message(m),
+            MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
+        }
+    }
+
+    async fn closed(&self) {
+        self.notify_drop.acquire().await.unwrap_err();
+    }
+
+    fn is_closed(&self) -> bool {
+        self.notify_drop.is_closed()
+    }
+}
+
+pub struct Dealer {
+    shared: Arc<DealerShared>,
+    handle: TimeoutOnDrop<()>,
+}
+
+impl Dealer {
+    pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), AddHandlerError>
+    where
+        H: RequestHandler,
+    {
+        add_handler(
+            &mut self.shared.request_handlers.lock().unwrap(),
+            uri,
+            handler,
+        )
+    }
+
+    pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
+        remove_handler(&mut self.shared.request_handlers.lock().unwrap(), uri)
+    }
+
+    pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
+        subscribe(&mut self.shared.message_handlers.lock().unwrap(), uris)
+    }
+
+    pub async fn close(mut self) {
+        debug!("closing dealer");
+
+        self.shared.notify_drop.close();
+
+        if let Some(handle) = self.handle.take() {
+            CancelOnDrop(handle).await.unwrap();
+        }
+    }
+}
+
+/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
+async fn connect(
+    address: &Url,
+    proxy: Option<&Url>,
+    shared: &Arc<DealerShared>,
+) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
+    let host = address
+        .host_str()
+        .ok_or(WsError::Url(UrlError::NoHostName))?;
+
+    let default_port = match address.scheme() {
+        "ws" => 80,
+        "wss" => 443,
+        _ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme)),
+    };
+
+    let port = address.port().unwrap_or(default_port);
+
+    let stream = socket::connect(host, port, proxy).await?;
+
+    let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address, stream)
+        .await?
+        .0
+        .split();
+
+    let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
+
+    // Spawn a task that will forward messages from the channel to the websocket.
+    let send_task = {
+        let shared = Arc::clone(&shared);
+
+        tokio::spawn(async move {
+            let result = loop {
+                select! {
+                    biased;
+                    () = shared.closed() => {
+                        break Ok(None);
+                    }
+                    msg = send_rx.recv() => {
+                        if let Some(msg) = msg {
+                            // New message arrived through channel
+                            if let WsMessage::Close(close_frame) = msg {
+                                break Ok(close_frame);
+                            }
+
+                            if let Err(e) = ws_tx.feed(msg).await {
+                                break Err(e);
+                            }
+                        } else {
+                            break Ok(None);
+                        }
+                    },
+                    e = keep_flushing(&mut ws_tx) => {
+                        break Err(e)
+                    }
+                }
+            };
+
+            send_rx.close();
+
+            // I don't trust in tokio_tungstenite's implementation of Sink::close.
+            let result = match result {
+                Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
+                Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
+                Err(e) => {
+                    warn!("Dealer finished with an error: {}", e);
+                    ws_tx.send(WsMessage::Close(None)).await
+                }
+            };
+
+            if let Err(e) = result {
+                warn!("Error while closing websocket: {}", e);
+            }
+
+            debug!("Dropping send task");
+        })
+    };
+
+    let shared = Arc::clone(&shared);
+
+    // A task that receives messages from the web socket.
+    let receive_task = tokio::spawn(async {
+        let pong_received = AtomicBool::new(true);
+        let send_tx = send_tx;
+        let shared = shared;
+
+        let receive_task = async {
+            let mut ws_rx = ws_rx;
+
+            loop {
+                match ws_rx.next().await {
+                    Some(Ok(msg)) => match msg {
+                        WsMessage::Text(t) => match serde_json::from_str(&t) {
+                            Ok(m) => shared.dispatch(m, &send_tx),
+                            Err(e) => info!("Received invalid message: {}", e),
+                        },
+                        WsMessage::Binary(_) => {
+                            info!("Received invalid binary message");
+                        }
+                        WsMessage::Pong(_) => {
+                            debug!("Received pong");
+                            pong_received.store(true, atomic::Ordering::Relaxed);
+                        }
+                        _ => (), // tungstenite handles Close and Ping automatically
+                    },
+                    Some(Err(e)) => {
+                        warn!("Websocket connection failed: {}", e);
+                        break;
+                    }
+                    None => {
+                        debug!("Websocket connection closed.");
+                        break;
+                    }
+                }
+            }
+        };
+
+        // Sends pings and checks whether a pong comes back.
+        let ping_task = async {
+            use tokio::time::{interval, sleep};
+
+            let mut timer = interval(Duration::from_secs(30));
+
+            loop {
+                timer.tick().await;
+
+                pong_received.store(false, atomic::Ordering::Relaxed);
+                if send_tx.send(WsMessage::Ping(vec![])).is_err() {
+                    // The sender is closed.
+                    break;
+                }
+
+                debug!("Sent ping");
+
+                sleep(Duration::from_secs(3)).await;
+
+                if !pong_received.load(atomic::Ordering::SeqCst) {
+                    // No response
+                    warn!("Websocket peer does not respond.");
+                    break;
+                }
+            }
+        };
+
+        // Exit this task as soon as one of our subtasks fails.
+        // In both cases the connection is probably lost.
+        select! {
+            () = ping_task => (),
+            () = receive_task => ()
+        }
+
+        // Try to take send_task down with us, in case it's still alive.
+        let _ = send_tx.send(WsMessage::Close(None));
+
+        debug!("Dropping receive task");
+    });
+
+    Ok((send_task, receive_task))
+}
+
+/// The main background task for `Dealer`, which coordinates reconnecting.
+async fn run<Fut, F>(
+    shared: Arc<DealerShared>,
+    initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
+    mut get_url: F,
+    proxy: Option<Url>,
+) where
+    Fut: Future<Output = Url> + Send + 'static,
+    F: (FnMut() -> Fut) + Send + 'static,
+{
+    let init_task = |t| Some(TimeoutOnDrop::new(t, Duration::from_secs(3)));
+
+    let mut tasks = if let Some((s, r)) = initial_tasks {
+        (init_task(s), init_task(r))
+    } else {
+        (None, None)
+    };
+
+    while !shared.is_closed() {
+        match &mut tasks {
+            (Some(t0), Some(t1)) => {
+                select! {
+                    () = shared.closed() => break,
+                    r = t0 => {
+                        r.unwrap(); // Whatever has gone wrong (probably panicked), we can't handle it, so let's panic too.
+                        tasks.0.take();
+                    },
+                    r = t1 => {
+                        r.unwrap();
+                        tasks.1.take();
+                    }
+                }
+            }
+            _ => {
+                let url = select! {
+                    () = shared.closed() => {
+                        break
+                    },
+                    e = get_url() => e
+                };
+
+                match connect(&url, proxy.as_ref(), &shared).await {
+                    Ok((s, r)) => tasks = (init_task(s), init_task(r)),
+                    Err(e) => {
+                        warn!("Error while connecting: {}", e);
+                    }
+                }
+            }
+        }
+    }
+
+    let tasks = tasks.0.into_iter().chain(tasks.1);
+
+    let _ = join_all(tasks).await;
+}
diff --git a/core/src/dealer/protocol.rs b/core/src/dealer/protocol.rs
new file mode 100644
index 00000000..cb0a1835
--- /dev/null
+++ b/core/src/dealer/protocol.rs
@@ -0,0 +1,39 @@
+use std::collections::HashMap;
+
+use serde::Deserialize;
+
+pub type JsonValue = serde_json::Value;
+pub type JsonObject = serde_json::Map<String, JsonValue>;
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Request<P = Payload> {
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
+    pub message_ident: String,
+    pub key: String,
+    pub payload: P,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Payload {
+    pub message_id: i32,
+    pub sent_by_device_id: String,
+    pub command: JsonObject,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Message<P = JsonValue> {
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
+    pub method: Option<String>,
+    #[serde(default)]
+    pub payloads: Vec<P>,
, + pub uri: String, +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessageOrRequest { + Message(Message), + Request(Request), +} diff --git a/core/src/lib.rs b/core/src/lib.rs index bb3e21d5..c6f6e190 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -14,25 +14,30 @@ pub mod cache; pub mod channel; pub mod config; mod connection; +#[allow(dead_code)] +mod dealer; #[doc(hidden)] pub mod diffie_hellman; pub mod keymaster; pub mod mercury; mod proxytunnel; pub mod session; +mod socket; pub mod spotify_id; #[doc(hidden)] pub mod util; pub mod version; -const AP_FALLBACK: &str = "ap.spotify.com:443"; +fn ap_fallback() -> (String, u16) { + (String::from("ap.spotify.com"), 443) +} #[cfg(feature = "apresolve")] mod apresolve; #[cfg(not(feature = "apresolve"))] mod apresolve { - pub async fn apresolve(_: Option<&url::Url>, _: Option) -> String { - return super::AP_FALLBACK.into(); + pub async fn apresolve(_: Option<&url::Url>, _: Option) -> (String, u16) { + super::ap_fallback() } } diff --git a/core/src/session.rs b/core/src/session.rs index 6c4abc54..f43a4cc0 100644 --- a/core/src/session.rs +++ b/core/src/session.rs @@ -69,8 +69,8 @@ impl Session { ) -> Result { let ap = apresolve(config.proxy.as_ref(), config.ap_port).await; - info!("Connecting to AP \"{}\"", ap); - let mut conn = connection::connect(ap, config.proxy.as_ref()).await?; + info!("Connecting to AP \"{}:{}\"", ap.0, ap.1); + let mut conn = connection::connect(&ap.0, ap.1, config.proxy.as_ref()).await?; let reusable_credentials = connection::authenticate(&mut conn, credentials, &config.device_id).await?; diff --git a/core/src/socket.rs b/core/src/socket.rs new file mode 100644 index 00000000..92274cc6 --- /dev/null +++ b/core/src/socket.rs @@ -0,0 +1,35 @@ +use std::io; +use std::net::ToSocketAddrs; + +use tokio::net::TcpStream; +use url::Url; + +use crate::proxytunnel; + +pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result { + let socket = if let Some(proxy_url) = proxy { + info!("Using proxy \"{}\"", proxy_url); + + let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| { + addrs.into_iter().next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "Can't resolve proxy server address", + ) + }) + })?; + let socket = TcpStream::connect(&socket_addr).await?; + + proxytunnel::proxy_connect(socket, host, &port.to_string()).await? + } else { + let socket_addr = (host, port).to_socket_addrs()?.next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "Can't resolve access point address", + ) + })?; + + TcpStream::connect(&socket_addr).await? + }; + Ok(socket) +} diff --git a/core/src/util.rs b/core/src/util.rs index df9ea714..4f78c467 100644 --- a/core/src/util.rs +++ b/core/src/util.rs @@ -1,4 +1,99 @@ +use std::future::Future; use std::mem; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use futures_core::ready; +use futures_util::FutureExt; +use futures_util::Sink; +use futures_util::{future, SinkExt}; +use tokio::task::JoinHandle; +use tokio::time::timeout; + +/// Returns a future that will flush the sink, even if flushing is temporarily completed. +/// Finishes only if the sink throws an error. 
+pub(crate) fn keep_flushing<'a, T, S: Sink<T> + Unpin + 'a>(
+    mut s: S,
+) -> impl Future<Output = S::Error> + 'a {
+    future::poll_fn(move |cx| match s.poll_flush_unpin(cx) {
+        Poll::Ready(Err(e)) => Poll::Ready(e),
+        _ => Poll::Pending,
+    })
+}
+
+pub struct CancelOnDrop<T>(pub JoinHandle<T>);
+
+impl<T> Future for CancelOnDrop<T> {
+    type Output = <JoinHandle<T> as Future>::Output;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.0.poll_unpin(cx)
+    }
+}
+
+impl<T> Drop for CancelOnDrop<T> {
+    fn drop(&mut self) {
+        self.0.abort();
+    }
+}
+
+pub struct TimeoutOnDrop<T: Send + 'static> {
+    handle: Option<JoinHandle<T>>,
+    timeout: tokio::time::Duration,
+}
+
+impl<T: Send + 'static> TimeoutOnDrop<T> {
+    pub fn new(handle: JoinHandle<T>, timeout: tokio::time::Duration) -> Self {
+        Self {
+            handle: Some(handle),
+            timeout,
+        }
+    }
+
+    pub fn take(&mut self) -> Option<JoinHandle<T>> {
+        self.handle.take()
+    }
+}
+
+impl<T: Send + 'static> Future for TimeoutOnDrop<T> {
+    type Output = <JoinHandle<T> as Future>::Output;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let r = ready!(self
+            .handle
+            .as_mut()
+            .expect("Polled after ready")
+            .poll_unpin(cx));
+        self.handle = None;
+        Poll::Ready(r)
+    }
+}
+
+impl<T: Send + 'static> Drop for TimeoutOnDrop<T> {
+    fn drop(&mut self) {
+        let mut handle = if let Some(handle) = self.handle.take() {
+            handle
+        } else {
+            return;
+        };
+
+        if (&mut handle).now_or_never().is_some() {
+            // Already finished
+            return;
+        }
+
+        match tokio::runtime::Handle::try_current() {
+            Ok(h) => {
+                h.spawn(timeout(self.timeout, CancelOnDrop(handle)));
+            }
+            Err(_) => {
+                // Not in tokio context, can't spawn
+                handle.abort();
+            }
+        }
+    }
+}
 
 pub trait Seq {
     fn next(&self) -> Self;