From f7994350f1f0450c419c2eee727ae35dfecc283a Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Sun, 4 Oct 2020 10:51:07 -0400 Subject: [PATCH 1/2] Fix flaky pub_sub test due to ordering issue of bind and connect --- tests/pub_sub.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pub_sub.rs b/tests/pub_sub.rs index 3890d5a..459ff06 100644 --- a/tests/pub_sub.rs +++ b/tests/pub_sub.rs @@ -11,12 +11,14 @@ async fn test_pub_sub_sockets() { let cloned_payload = payload.clone(); let (server_stop_sender, mut server_stop) = oneshot::channel::<()>(); + let (has_bound_sender, has_bound) = oneshot::channel::<()>(); tokio::spawn(async move { let mut pub_socket = zeromq::PubSocket::new(); pub_socket .bind(bind_addr) .await .unwrap_or_else(|_| panic!("Failed to bind to {}", bind_addr)); + has_bound_sender.send(()).expect("channel was dropped"); loop { if let Ok(Some(_)) = server_stop.try_recv() { @@ -29,6 +31,10 @@ async fn test_pub_sub_sockets() { tokio::time::delay_for(Duration::from_millis(1)).await; } }); + // Block until the pub has finished binding + // TODO: ZMQ sockets should not care about this sort of ordering. + // See https://github.com/zeromq/zmq.rs/issues/73 + has_bound.await.expect("channel was cancelled"); let (sub_results_sender, sub_results) = mpsc::channel(100); for _ in 0..10 { From 83507603b32f8a2b6dd1b7841f4561ec8019164f Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Sun, 4 Oct 2020 11:03:10 -0400 Subject: [PATCH 2/2] Reorg endpoint module, improve host display&parse and add tests --- src/endpoint/host.rs | 185 +++++++++++++++++++++++++++ src/{endpoint.rs => endpoint/mod.rs} | 166 +++++------------------- src/endpoint/transport.rs | 41 ++++++ src/pub.rs | 2 +- src/sub.rs | 2 +- src/util.rs | 2 +- 6 files changed, 261 insertions(+), 137 deletions(-) create mode 100644 src/endpoint/host.rs rename src/{endpoint.rs => endpoint/mod.rs} (60%) create mode 100644 src/endpoint/transport.rs diff --git a/src/endpoint/host.rs b/src/endpoint/host.rs new file mode 100644 index 0000000..608eea8 --- /dev/null +++ b/src/endpoint/host.rs @@ -0,0 +1,185 @@ +use std::convert::TryFrom; +use std::fmt; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; + +use super::EndpointError; +use crate::ZmqError; + +/// Represents a host address. Does not include the port, and may be either an +/// ip address or a domain name +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum Host { + /// An IPv4 address + Ipv4(Ipv4Addr), + /// An Ipv6 address + Ipv6(Ipv6Addr), + /// A domain name, such as `example.com` in `tcp://example.com:4567`. + Domain(String), +} + +impl fmt::Display for Host { + fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { + match self { + Host::Ipv4(addr) => write!(f, "{}", addr), + Host::Ipv6(addr) => write!(f, "{}", addr), + Host::Domain(name) => write!(f, "{}", name), + } + } +} + +impl TryFrom for IpAddr { + type Error = ZmqError; + + fn try_from(h: Host) -> Result { + match h { + Host::Ipv4(a) => Ok(IpAddr::V4(a)), + Host::Ipv6(a) => Ok(IpAddr::V6(a)), + Host::Domain(_) => Err(ZmqError::Other("Host was neither Ipv4 nor Ipv6")), + } + } +} + +impl From for Host { + fn from(a: IpAddr) -> Self { + match a { + IpAddr::V4(a) => Host::Ipv4(a), + IpAddr::V6(a) => Host::Ipv6(a), + } + } +} + +impl TryFrom for Host { + type Error = EndpointError; + + /// An Ipv6 address must be enclosed by `[` and `]`. + fn try_from(s: String) -> Result { + if s.is_empty() { + return Err(EndpointError::Syntax("Host string should not be empty")); + } + if let Ok(addr) = s.parse::() { + return Ok(Host::Ipv4(addr)); + } + + // Attempt to parse ipv6 from either ::1 or [::1] using ascii + let ipv6_substr = + if s.starts_with('[') && s.len() >= 4 && *s.as_bytes().last().unwrap() == b']' { + let substr = &s[1..s.len() - 1]; + debug_assert_eq!(substr.len(), s.len() - 2); + substr + } else { + &s + }; + if let Ok(addr) = ipv6_substr.parse::() { + return Ok(Host::Ipv6(addr)); + } + + Ok(Host::Domain(s)) + } +} + +impl FromStr for Host { + type Err = EndpointError; + + /// Equivalent to [`Self::try_from()`] + fn from_str(s: &str) -> Result { + let s = s.to_string(); + Self::try_from(s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // These two tests on std are more for reference than any real test of + // functionality + #[test] + fn std_ipv6_parse() { + assert_eq!(Ipv6Addr::LOCALHOST, "::1".parse::().unwrap()); + assert!("[::1]".parse::().is_err()); + } + + #[test] + fn std_ipv6_display() { + assert_eq!("::1", &Ipv6Addr::LOCALHOST.to_string()); + } + + #[test] + fn parse_and_display_nobracket_ipv6_same_as_std() { + let valid_addr_strs = vec![ + "::1", + "::", + "2001:db8:a::123", + "2001:db8:0:0:0:0:2:1", + "2001:db8::2:1", + ]; + let invalid_addr_strs = vec!["", "[]", "[:]", ":"]; + + for valid in valid_addr_strs { + let parsed_std = valid.parse::().unwrap(); + let parsed_host = valid.parse::().unwrap(); + if let Host::Ipv6(parsed_host) = &parsed_host { + // Check that both are structurally the same + assert_eq!(&parsed_std, parsed_host); + } else { + panic!("Did not parse as IPV6!"); + } + // Check that both display as the same + assert_eq!(parsed_std.to_string(), parsed_host.to_string()); + } + + for invalid in invalid_addr_strs { + invalid.parse::().unwrap_err(); + let parsed_host = invalid.parse::(); + if parsed_host.is_err() { + continue; + } + let parsed_host = parsed_host.unwrap(); + if let Host::Domain(_) = parsed_host { + continue; + } + panic!( + "Expected that \"{}\" would not parse as Ipv6 or Ipv4, but instead it parsed as {:?}", + invalid, parsed_host + ); + } + } + + #[test] + fn parse_and_display_bracket_ipv6() { + let addr_strs = vec![ + "[::1]", + "[::]", + "[2001:db8:a::123]", + "[2001:db8:0:0:0:0:2:1]", + "[2001:db8::2:1]", + ]; + + fn remove_brackets(s: &str) -> &str { + assert!(s.starts_with('[')); + assert!(s.ends_with(']')); + let result = &s[1..s.len() - 1]; + assert_eq!(result.len(), s.len() - 2); + result + } + + for addr_str in addr_strs { + let parsed_host: Host = addr_str.parse().unwrap(); + assert!(addr_str.parse::().is_err()); + + if let Host::Ipv6(host_ipv6) = parsed_host { + assert_eq!( + host_ipv6, + remove_brackets(addr_str).parse::().unwrap() + ); + assert_eq!(parsed_host.to_string(), host_ipv6.to_string()); + } else { + panic!( + "Expected host to parse as Ipv6, but instead got {:?}", + parsed_host + ); + } + } + } +} diff --git a/src/endpoint.rs b/src/endpoint/mod.rs similarity index 60% rename from src/endpoint.rs rename to src/endpoint/mod.rs index 16a8204..4f3bd64 100644 --- a/src/endpoint.rs +++ b/src/endpoint/mod.rs @@ -1,124 +1,18 @@ -use crate::error::{EndpointError, ZmqError}; +mod host; +mod transport; + +pub use host::Host; +pub use transport::Transport; + use lazy_static::lazy_static; use regex::Regex; -use std::convert::{TryFrom, TryInto}; use std::fmt; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; -// TODO: Figure out better error types for this module. +use crate::error::EndpointError; pub type Port = u16; -/// Represents a host address. Does not include the port, and may be either an -/// ip address or a domain name -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub enum Host { - /// An IPv4 address - Ipv4(Ipv4Addr), - /// An Ipv6 address - Ipv6(Ipv6Addr), - /// A domain name, such as `example.com` in `tcp://example.com:4567`. - Domain(String), -} - -impl fmt::Display for Host { - fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { - match self { - Host::Ipv4(addr) => write!(f, "{}", addr), - Host::Ipv6(addr) => write!(f, "[{}]", addr), - Host::Domain(name) => write!(f, "{}", name), - } - } -} - -impl TryFrom for IpAddr { - type Error = ZmqError; - - fn try_from(h: Host) -> Result { - match h { - Host::Ipv4(a) => Ok(IpAddr::V4(a)), - Host::Ipv6(a) => Ok(IpAddr::V6(a)), - Host::Domain(_) => Err(ZmqError::Other("Host was neither Ipv4 nor Ipv6")), - } - } -} - -impl From for Host { - fn from(a: IpAddr) -> Self { - match a { - IpAddr::V4(a) => Host::Ipv4(a), - IpAddr::V6(a) => Host::Ipv6(a), - } - } -} - -impl TryFrom for Host { - type Error = EndpointError; - - /// An Ipv6 address must be enclosed by `[` and `]`. - fn try_from(s: String) -> Result { - if s.is_empty() { - return Err(EndpointError::Syntax("Host string should not be empty")); - } - if let Ok(addr) = s.parse::() { - return Ok(Host::Ipv4(addr)); - } - if s.len() >= 4 { - if let Ok(addr) = s[1..s.len() - 1].parse::() { - return Ok(Host::Ipv6(addr)); - } - } - Ok(Host::Domain(s)) - } -} - -impl FromStr for Host { - type Err = EndpointError; - - /// Equivalent to [`Self::try_from()`] - fn from_str(s: &str) -> Result { - let s = s.to_string(); - Self::try_from(s) - } -} - -/// The type of transport used by a given endpoint -#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] -#[non_exhaustive] -pub enum Transport { - /// TCP transport - Tcp, -} - -impl FromStr for Transport { - type Err = EndpointError; - - fn from_str(s: &str) -> Result { - let result = match s { - "tcp" => Transport::Tcp, - _ => return Err(EndpointError::UnknownTransport(s.to_string())), - }; - Ok(result) - } -} -impl TryFrom<&str> for Transport { - type Error = EndpointError; - - fn try_from(s: &str) -> Result { - s.parse() - } -} - -impl fmt::Display for Transport { - fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { - let s = match self { - Transport::Tcp => "tcp", - }; - write!(f, "{}", s) - } -} - /// Represents a ZMQ Endpoint. /// /// # Examples @@ -190,44 +84,48 @@ impl FromStr for Endpoint { impl fmt::Display for Endpoint { fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { match self { - Endpoint::Tcp(host, port) => write!(f, "tcp://{}:{}", host, port), + Endpoint::Tcp(host, port) => { + if let Host::Ipv6(_) = host { + write!(f, "tcp://[{}]:{}", host, port) + } else { + write!(f, "tcp://{}:{}", host, port) + } + } } } } -// Trait aliases (https://github.com/rust-lang/rust/issues/41517) would make this unecessary -/// Any type that can be converted into an [`Endpoint`] should implement this -pub trait TryIntoEndpoint: Send { +/// Represents a type that can be converted into an [`Endpoint`]. +/// +/// This trait is intentionally sealed to prevent implementation on third-party +/// types. +// TODO: Is sealing this trait actually necessary? +pub trait TryIntoEndpoint: Send + private::Sealed { + /// Convert into an `Endpoint` via an owned `Self`. + /// + /// Enables efficient `Endpoint` -> `Endpoint` conversion, while permitting + /// the creation of a new `Endpoint` when given types like `&str`. fn try_into(self) -> Result; } - -impl TryIntoEndpoint for T -where - T: TryInto + Send, -{ - fn try_into(self) -> Result { - self.try_into() - } -} - impl TryIntoEndpoint for &str { fn try_into(self) -> Result { self.parse() } } - -impl TryIntoEndpoint for String { - fn try_into(self) -> Result { - self.parse() - } -} - impl TryIntoEndpoint for Endpoint { fn try_into(self) -> Result { Ok(self) } } +impl private::Sealed for str {} +impl private::Sealed for &str {} +impl private::Sealed for Endpoint {} + +mod private { + pub trait Sealed {} +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/endpoint/transport.rs b/src/endpoint/transport.rs new file mode 100644 index 0000000..9b000ad --- /dev/null +++ b/src/endpoint/transport.rs @@ -0,0 +1,41 @@ +use std::convert::TryFrom; +use std::fmt; +use std::str::FromStr; + +use super::EndpointError; + +/// The type of transport used by a given endpoint +#[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum Transport { + /// TCP transport + Tcp, +} + +impl FromStr for Transport { + type Err = EndpointError; + + fn from_str(s: &str) -> Result { + let result = match s { + "tcp" => Transport::Tcp, + _ => return Err(EndpointError::UnknownTransport(s.to_string())), + }; + Ok(result) + } +} +impl TryFrom<&str> for Transport { + type Error = EndpointError; + + fn try_from(s: &str) -> Result { + s.parse() + } +} + +impl fmt::Display for Transport { + fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { + let s = match self { + Transport::Tcp => "tcp", + }; + write!(f, "{}", s) + } +} diff --git a/src/pub.rs b/src/pub.rs index ac4656a..6eaefc3 100644 --- a/src/pub.rs +++ b/src/pub.rs @@ -160,7 +160,7 @@ impl Socket for PubSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect(format!("{}:{}", host, port)).await?; + let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; util::peer_connected(raw_socket, self.backend.clone()).await; Ok(()) } diff --git a/src/sub.rs b/src/sub.rs index cd49f83..57db0de 100644 --- a/src/sub.rs +++ b/src/sub.rs @@ -162,7 +162,7 @@ impl Socket for SubSocket { let endpoint = endpoint.try_into()?; let Endpoint::Tcp(host, port) = endpoint; - let raw_socket = tokio::net::TcpStream::connect(format!("{}:{}", host, port)).await?; + let raw_socket = tokio::net::TcpStream::connect((host.to_string().as_str(), port)).await?; util::peer_connected(raw_socket, self.backend.clone()).await; Ok(()) } diff --git a/src/util.rs b/src/util.rs index 3446f61..cbee146 100644 --- a/src/util.rs +++ b/src/util.rs @@ -205,7 +205,7 @@ pub(crate) async fn start_accepting_connections( ) -> ZmqResult<(Endpoint, futures::channel::oneshot::Sender)> { let Endpoint::Tcp(mut host, port) = endpoint; - let mut listener = tokio::net::TcpListener::bind(format!("{}:{}", host, port)).await?; + let mut listener = tokio::net::TcpListener::bind((host.to_string().as_str(), port)).await?; let resolved_addr = listener.local_addr()?; let (stop_handle, stop_callback) = futures::channel::oneshot::channel::(); tokio::spawn(async move {