diff --git a/audio/src/fetch.rs b/audio/src/fetch.rs index 5fdf9e74..da343096 100644 --- a/audio/src/fetch.rs +++ b/audio/src/fetch.rs @@ -1,11 +1,8 @@ use std::cmp::{max, min}; use std::fs; -use std::future::Future; use std::io::{self, Read, Seek, SeekFrom, Write}; -use std::pin::Pin; use std::sync::atomic::{self, AtomicUsize}; use std::sync::{Arc, Condvar, Mutex}; -use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; @@ -236,7 +233,7 @@ struct AudioFileDownloadStatus { downloaded: RangeSet, } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq)] enum DownloadStrategy { RandomAccess(), Streaming(), @@ -249,7 +246,6 @@ struct AudioFileShared { cond: Condvar, download_status: Mutex, download_strategy: Mutex, - number_of_open_requests: AtomicUsize, ping_time_ms: AtomicUsize, read_position: AtomicUsize, } @@ -358,7 +354,6 @@ impl AudioFileStreaming { downloaded: RangeSet::new(), }), download_strategy: Mutex::new(DownloadStrategy::RandomAccess()), // start with random access mode until someone tells us otherwise - number_of_open_requests: AtomicUsize::new(0), ping_time_ms: AtomicUsize::new(0), read_position: AtomicUsize::new(0), }); @@ -373,7 +368,7 @@ impl AudioFileStreaming { let (stream_loader_command_tx, stream_loader_command_rx) = mpsc::unbounded_channel::(); - let fetcher = AudioFileFetch::new( + session.spawn(audio_file_fetch( session.clone(), shared.clone(), initial_data_rx, @@ -382,9 +377,8 @@ impl AudioFileStreaming { write_file, stream_loader_command_rx, complete_tx, - ); + )); - session.spawn(fetcher); Ok(AudioFileStreaming { read_file, position: 0, @@ -442,17 +436,11 @@ async fn audio_file_fetch_receive_data( initial_data_offset: usize, initial_request_length: usize, request_sent_time: Instant, + mut measure_ping_time: bool, + finish_tx: mpsc::UnboundedSender<()>, ) { let mut data_offset = initial_data_offset; let mut request_length = initial_request_length; - let mut measure_ping_time = shared - .number_of_open_requests - .load(atomic::Ordering::SeqCst) - == 0; - - shared - .number_of_open_requests - .fetch_add(1, atomic::Ordering::SeqCst); let result = loop { let data = match data_rx.next().await { @@ -501,9 +489,7 @@ async fn audio_file_fetch_receive_data( shared.cond.notify_all(); } - shared - .number_of_open_requests - .fetch_sub(1, atomic::Ordering::SeqCst); + let _ = finish_tx.send(()); if result.is_err() { warn!( @@ -517,162 +503,6 @@ async fn audio_file_fetch_receive_data( ); } } -/* -async fn audio_file_fetch( - session: Session, - shared: Arc, - initial_data_rx: ChannelData, - initial_request_sent_time: Instant, - initial_data_length: usize, - - output: NamedTempFile, - stream_loader_command_rx: mpsc::UnboundedReceiver, - complete_tx: oneshot::Sender, -) { - let (file_data_tx, file_data_rx) = unbounded::(); - - let requested_range = Range::new(0, initial_data_length); - let mut download_status = shared.download_status.lock().unwrap(); - download_status.requested.add_range(&requested_range); - - session.spawn(audio_file_fetch_receive_data( - shared.clone(), - file_data_tx.clone(), - initial_data_rx, - 0, - initial_data_length, - initial_request_sent_time, - )); - - let mut network_response_times_ms: Vec::new(); - - let f1 = file_data_rx.map(|x| Ok::<_, ()>(x)).try_for_each(|x| { - match x { - ReceivedData::ResponseTimeMs(response_time_ms) => { - trace!("Ping time estimated as: {} ms.", response_time_ms); - - // record the response time - network_response_times_ms.push(response_time_ms); - - // prune old response times. Keep at most three. - while network_response_times_ms.len() > 3 { - network_response_times_ms.remove(0); - } - - // stats::median is experimental. So we calculate the median of up to three ourselves. - let ping_time_ms: usize = match network_response_times_ms.len() { - 1 => network_response_times_ms[0] as usize, - 2 => { - ((network_response_times_ms[0] + network_response_times_ms[1]) / 2) as usize - } - 3 => { - let mut times = network_response_times_ms.clone(); - times.sort(); - times[1] - } - _ => unreachable!(), - }; - - // store our new estimate for everyone to see - shared - .ping_time_ms - .store(ping_time_ms, atomic::Ordering::Relaxed); - } - ReceivedData::Data(data) => { - output - .as_mut() - .unwrap() - .seek(SeekFrom::Start(data.offset as u64)) - .unwrap(); - output - .as_mut() - .unwrap() - .write_all(data.data.as_ref()) - .unwrap(); - - let mut full = false; - - { - let mut download_status = shared.download_status.lock().unwrap(); - - let received_range = Range::new(data.offset, data.data.len()); - download_status.downloaded.add_range(&received_range); - shared.cond.notify_all(); - - if download_status.downloaded.contained_length_from_value(0) - >= shared.file_size - { - full = true; - } - - drop(download_status); - } - - if full { - self.finish(); - return future::ready(Err(())); - } - } - } - future::ready(Ok(())) - }); - - let f2 = stream_loader_command_rx.map(Ok::<_, ()>).try_for_each(|x| { - match cmd { - StreamLoaderCommand::Fetch(request) => { - self.download_range(request.start, request.length); - } - StreamLoaderCommand::RandomAccessMode() => { - *(shared.download_strategy.lock().unwrap()) = DownloadStrategy::RandomAccess(); - } - StreamLoaderCommand::StreamMode() => { - *(shared.download_strategy.lock().unwrap()) = DownloadStrategy::Streaming(); - } - StreamLoaderCommand::Close() => return future::ready(Err(())), - } - Ok(()) - }); - - let f3 = future::poll_fn(|_| { - if let DownloadStrategy::Streaming() = self.get_download_strategy() { - let number_of_open_requests = shared - .number_of_open_requests - .load(atomic::Ordering::SeqCst); - let max_requests_to_send = - MAX_PREFETCH_REQUESTS - min(MAX_PREFETCH_REQUESTS, number_of_open_requests); - - if max_requests_to_send > 0 { - let bytes_pending: usize = { - let download_status = shared.download_status.lock().unwrap(); - download_status - .requested - .minus(&download_status.downloaded) - .len() - }; - - let ping_time_seconds = - 0.001 * shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64; - let download_rate = session.channel().get_download_rate_estimate(); - - let desired_pending_bytes = max( - (PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * shared.stream_data_rate as f64) - as usize, - (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64) - as usize, - ); - - if bytes_pending < desired_pending_bytes { - self.pre_fetch_more_data( - desired_pending_bytes - bytes_pending, - max_requests_to_send, - ); - } - } - } - Poll::Pending - }); - future::select_all(vec![f1, f2, f3]).await -}*/ struct AudioFileFetch { session: Session, @@ -680,54 +510,21 @@ struct AudioFileFetch { output: Option, file_data_tx: mpsc::UnboundedSender, - file_data_rx: mpsc::UnboundedReceiver, - - stream_loader_command_rx: mpsc::UnboundedReceiver, complete_tx: Option>, network_response_times_ms: Vec, + number_of_open_requests: usize, + + download_finish_tx: mpsc::UnboundedSender<()>, +} + +// Might be replaced by enum from std once stable +#[derive(PartialEq, Eq)] +enum ControlFlow { + Break, + Continue, } impl AudioFileFetch { - fn new( - session: Session, - shared: Arc, - initial_data_rx: ChannelData, - initial_request_sent_time: Instant, - initial_data_length: usize, - - output: NamedTempFile, - stream_loader_command_rx: mpsc::UnboundedReceiver, - complete_tx: oneshot::Sender, - ) -> AudioFileFetch { - let (file_data_tx, file_data_rx) = mpsc::unbounded_channel::(); - - { - let requested_range = Range::new(0, initial_data_length); - let mut download_status = shared.download_status.lock().unwrap(); - download_status.requested.add_range(&requested_range); - } - - session.spawn(audio_file_fetch_receive_data( - shared.clone(), - file_data_tx.clone(), - initial_data_rx, - 0, - initial_data_length, - initial_request_sent_time, - )); - - AudioFileFetch { - session, - shared, - output: Some(output), - file_data_tx, - file_data_rx, - stream_loader_command_rx, - complete_tx: Some(complete_tx), - network_response_times_ms: Vec::new(), - } - } - fn get_download_strategy(&mut self) -> DownloadStrategy { *(self.shared.download_strategy.lock().unwrap()) } @@ -785,7 +582,11 @@ impl AudioFileFetch { range.start, range.length, Instant::now(), + self.number_of_open_requests == 0, + self.download_finish_tx.clone(), )); + + self.number_of_open_requests += 1; } } @@ -833,103 +634,86 @@ impl AudioFileFetch { } } - fn poll_file_data_rx(&mut self, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.file_data_rx.poll_recv(cx) { - Poll::Ready(None) => return Poll::Ready(()), - Poll::Ready(Some(ReceivedData::ResponseTimeMs(response_time_ms))) => { - trace!("Ping time estimated as: {} ms.", response_time_ms); + fn handle_file_data(&mut self, data: ReceivedData) -> ControlFlow { + match data { + ReceivedData::ResponseTimeMs(response_time_ms) => { + trace!("Ping time estimated as: {} ms.", response_time_ms); - // record the response time - self.network_response_times_ms.push(response_time_ms); + // record the response time + self.network_response_times_ms.push(response_time_ms); - // prune old response times. Keep at most three. - while self.network_response_times_ms.len() > 3 { - self.network_response_times_ms.remove(0); - } - - // stats::median is experimental. So we calculate the median of up to three ourselves. - let ping_time_ms: usize = match self.network_response_times_ms.len() { - 1 => self.network_response_times_ms[0] as usize, - 2 => { - ((self.network_response_times_ms[0] - + self.network_response_times_ms[1]) - / 2) as usize - } - 3 => { - let mut times = self.network_response_times_ms.clone(); - times.sort_unstable(); - times[1] - } - _ => unreachable!(), - }; - - // store our new estimate for everyone to see - self.shared - .ping_time_ms - .store(ping_time_ms, atomic::Ordering::Relaxed); + // prune old response times. Keep at most three. + while self.network_response_times_ms.len() > 3 { + self.network_response_times_ms.remove(0); } - Poll::Ready(Some(ReceivedData::Data(data))) => { - self.output - .as_mut() - .unwrap() - .seek(SeekFrom::Start(data.offset as u64)) - .unwrap(); - self.output - .as_mut() - .unwrap() - .write_all(data.data.as_ref()) - .unwrap(); - let mut full = false; - - { - let mut download_status = self.shared.download_status.lock().unwrap(); - - let received_range = Range::new(data.offset, data.data.len()); - download_status.downloaded.add_range(&received_range); - self.shared.cond.notify_all(); - - if download_status.downloaded.contained_length_from_value(0) - >= self.shared.file_size - { - full = true; - } - - drop(download_status); + // stats::median is experimental. So we calculate the median of up to three ourselves. + let ping_time_ms: usize = match self.network_response_times_ms.len() { + 1 => self.network_response_times_ms[0] as usize, + 2 => { + ((self.network_response_times_ms[0] + self.network_response_times_ms[1]) + / 2) as usize } - - if full { - self.finish(); - return Poll::Ready(()); + 3 => { + let mut times = self.network_response_times_ms.clone(); + times.sort_unstable(); + times[1] } + _ => unreachable!(), + }; + + // store our new estimate for everyone to see + self.shared + .ping_time_ms + .store(ping_time_ms, atomic::Ordering::Relaxed); + } + ReceivedData::Data(data) => { + self.output + .as_mut() + .unwrap() + .seek(SeekFrom::Start(data.offset as u64)) + .unwrap(); + self.output + .as_mut() + .unwrap() + .write_all(data.data.as_ref()) + .unwrap(); + + let mut download_status = self.shared.download_status.lock().unwrap(); + + let received_range = Range::new(data.offset, data.data.len()); + download_status.downloaded.add_range(&received_range); + self.shared.cond.notify_all(); + + let full = download_status.downloaded.contained_length_from_value(0) + >= self.shared.file_size; + + drop(download_status); + + if full { + self.finish(); + return ControlFlow::Break; } - Poll::Pending => return Poll::Pending, } } + ControlFlow::Continue } - fn poll_stream_loader_command_rx(&mut self, cx: &mut Context<'_>) -> Poll<()> { - loop { - match self.stream_loader_command_rx.poll_recv(cx) { - Poll::Ready(None) => return Poll::Ready(()), - Poll::Ready(Some(cmd)) => match cmd { - StreamLoaderCommand::Fetch(request) => { - self.download_range(request.start, request.length); - } - StreamLoaderCommand::RandomAccessMode() => { - *(self.shared.download_strategy.lock().unwrap()) = - DownloadStrategy::RandomAccess(); - } - StreamLoaderCommand::StreamMode() => { - *(self.shared.download_strategy.lock().unwrap()) = - DownloadStrategy::Streaming(); - } - StreamLoaderCommand::Close() => return Poll::Ready(()), - }, - Poll::Pending => return Poll::Pending, + fn handle_stream_loader_command(&mut self, cmd: StreamLoaderCommand) -> ControlFlow { + match cmd { + StreamLoaderCommand::Fetch(request) => { + self.download_range(request.start, request.length); } + StreamLoaderCommand::RandomAccessMode() => { + *(self.shared.download_strategy.lock().unwrap()) = DownloadStrategy::RandomAccess(); + } + StreamLoaderCommand::StreamMode() => { + *(self.shared.download_strategy.lock().unwrap()) = DownloadStrategy::Streaming(); + self.trigger_preload(); + } + StreamLoaderCommand::Close() => return ControlFlow::Break, } + ControlFlow::Continue } fn finish(&mut self) { @@ -939,57 +723,102 @@ impl AudioFileFetch { output.seek(SeekFrom::Start(0)).unwrap(); let _ = complete_tx.send(output); } + + fn trigger_preload(&mut self) { + if self.number_of_open_requests >= MAX_PREFETCH_REQUESTS { + return; + } + + let max_requests_to_send = MAX_PREFETCH_REQUESTS - self.number_of_open_requests; + + let bytes_pending: usize = { + let download_status = self.shared.download_status.lock().unwrap(); + download_status + .requested + .minus(&download_status.downloaded) + .len() + }; + + let ping_time_seconds = + 0.001 * self.shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64; + let download_rate = self.session.channel().get_download_rate_estimate(); + + let desired_pending_bytes = max( + (PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * self.shared.stream_data_rate as f64) + as usize, + (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64) as usize, + ); + + if bytes_pending < desired_pending_bytes { + self.pre_fetch_more_data(desired_pending_bytes - bytes_pending, max_requests_to_send); + } + } } -impl Future for AudioFileFetch { - type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - if let Poll::Ready(()) = self.poll_stream_loader_command_rx(cx) { - return Poll::Ready(()); - } +async fn audio_file_fetch( + session: Session, + shared: Arc, + initial_data_rx: ChannelData, + initial_request_sent_time: Instant, + initial_data_length: usize, - if let Poll::Ready(()) = self.poll_file_data_rx(cx) { - return Poll::Ready(()); - } + output: NamedTempFile, + mut stream_loader_command_rx: mpsc::UnboundedReceiver, + complete_tx: oneshot::Sender, +) { + let (file_data_tx, mut file_data_rx) = mpsc::unbounded_channel(); + let (download_finish_tx, mut download_finish_rx) = mpsc::unbounded_channel(); - if let DownloadStrategy::Streaming() = self.get_download_strategy() { - let number_of_open_requests = self - .shared - .number_of_open_requests - .load(atomic::Ordering::SeqCst); - let max_requests_to_send = - MAX_PREFETCH_REQUESTS - min(MAX_PREFETCH_REQUESTS, number_of_open_requests); + { + let requested_range = Range::new(0, initial_data_length); + let mut download_status = shared.download_status.lock().unwrap(); + download_status.requested.add_range(&requested_range); + } - if max_requests_to_send > 0 { - let bytes_pending: usize = { - let download_status = self.shared.download_status.lock().unwrap(); - download_status - .requested - .minus(&download_status.downloaded) - .len() - }; + session.spawn(audio_file_fetch_receive_data( + shared.clone(), + file_data_tx.clone(), + initial_data_rx, + 0, + initial_data_length, + initial_request_sent_time, + true, + download_finish_tx.clone(), + )); - let ping_time_seconds = - 0.001 * self.shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64; - let download_rate = self.session.channel().get_download_rate_estimate(); + let mut fetch = AudioFileFetch { + session, + shared, + output: Some(output), - let desired_pending_bytes = max( - (PREFETCH_THRESHOLD_FACTOR - * ping_time_seconds - * self.shared.stream_data_rate as f64) as usize, - (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64) - as usize, - ); + file_data_tx, + complete_tx: Some(complete_tx), + network_response_times_ms: Vec::new(), + number_of_open_requests: 1, - if bytes_pending < desired_pending_bytes { - self.pre_fetch_more_data( - desired_pending_bytes - bytes_pending, - max_requests_to_send, - ); + download_finish_tx, + }; + + loop { + tokio::select! { + cmd = stream_loader_command_rx.recv() => { + if cmd.map_or(true, |cmd| fetch.handle_stream_loader_command(cmd) == ControlFlow::Break) { + break; + } + }, + data = file_data_rx.recv() => { + if data.map_or(true, |data| fetch.handle_file_data(data) == ControlFlow::Break) { + break; + } + }, + _ = download_finish_rx.recv() => { + fetch.number_of_open_requests -= 1; + + if fetch.get_download_strategy() == DownloadStrategy::Streaming() { + fetch.trigger_preload(); } } } - Poll::Pending } }