diff --git a/library/bytesio/src/bytesio.rs b/library/bytesio/src/bytesio.rs index 4ec66250..cd454679 100644 --- a/library/bytesio/src/bytesio.rs +++ b/library/bytesio/src/bytesio.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use std::time::Duration; +use std::time::Duration; use async_trait::async_trait; use bytes::BufMut; use bytes::Bytes; @@ -24,6 +24,11 @@ pub trait TNetIO: Send + Sync { async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>; async fn read(&mut self) -> Result; async fn read_timeout(&mut self, duration: Duration) -> Result; + async fn read_min_bytes_with_timeout( + &mut self, + duration: Duration, + size: usize + ) -> Result; fn get_net_type(&self) -> NetType; } @@ -86,6 +91,20 @@ impl TNetIO for UdpIO { } } + // As udp is based on packet, we can't read part of them, so the size parameter is not effective + async fn read_min_bytes_with_timeout( + &mut self, + duration: Duration, + _size: usize + ) -> Result { + match tokio::time::timeout(duration, self.read()).await { + Ok(data) => data, + Err(err) => Err(BytesIOError { + value: BytesIOErrorValue::TimeoutError(err), + }), + } + } + async fn read(&mut self) -> Result { let mut buf = vec![0; 4096]; let len = self.socket.recv(&mut buf).await?; @@ -131,6 +150,50 @@ impl TNetIO for TcpIO { } } + async fn read_min_bytes_with_timeout( + &mut self, + duration: Duration, + size: usize + ) -> Result { + let mut result: BytesMut = BytesMut::new(); + let start_time = tokio::time::Instant::now(); + loop { + let current_time = tokio::time::Instant::now(); + let remaining_duration = match duration.checked_sub(current_time - start_time) { + Some(remaining) => remaining, + None => Duration::from_secs(0), + }; + let message_result = tokio::time::timeout(remaining_duration, self.stream.next()).await; + match message_result { + Ok(message) => match message { + Some(data) => match data { + Ok(bytes) => { + result.extend_from_slice(&bytes); + if result.len() >= size { + return Ok(result); + } + } + Err(err) => { + return Err(BytesIOError { + value: BytesIOErrorValue::IOError(err), + }) + } + }, + None => { + return Err(BytesIOError { + value: BytesIOErrorValue::NoneReturn, + }) + } + }, + Err(err) => { + return Err(BytesIOError { + value: BytesIOErrorValue::TimeoutError(err), + }) + } + } + } + } + async fn read(&mut self) -> Result { let message = self.stream.next().await; diff --git a/protocol/rtsp/src/session/mod.rs b/protocol/rtsp/src/session/mod.rs index b3fa44f5..ee8c26aa 100644 --- a/protocol/rtsp/src/session/mod.rs +++ b/protocol/rtsp/src/session/mod.rs @@ -48,6 +48,7 @@ use define::rtsp_method_name; use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::mpsc; use commonlib::auth::Auth; @@ -150,7 +151,12 @@ impl RtspServerSession { match data { Some(a) => { if self.reader.len() < a.length as usize { - let data = self.io.lock().await.read().await?; + let data = self + .io + .lock() + .await + .read_min_bytes_with_timeout(Duration::from_millis(1000), a.length.into()) + .await?; self.reader.extend_from_slice(&data[..]); } self.on_rtp_over_rtsp_message(a.channel_identifier, a.length as usize)