diff --git a/.github/workflows/cont_integration.yml b/.github/workflows/cont_integration.yml index 3d2844f..5fe5477 100644 --- a/.github/workflows/cont_integration.yml +++ b/.github/workflows/cont_integration.yml @@ -33,6 +33,7 @@ jobs: run: | cargo update -p openssl --precise "0.10.78" cargo update -p openssl-sys --precise "0.9.114" + cargo update -p zeroize --precise "1.8.2" - name: Test run: cargo test --verbose --all-features @@ -72,6 +73,7 @@ jobs: run: | cargo update -p openssl --precise "0.10.78" cargo update -p openssl-sys --precise "0.9.114" + cargo update -p zeroize --precise "1.8.2" - name: Check features run: cargo check --verbose ${{ matrix.features }} diff --git a/README.md b/README.md index d543c15..b9f554e 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ To build with the MSRV you will need to pin dependencies by running: ``` bash cargo update -p openssl --precise "0.10.78" cargo update -p openssl-sys --precise "0.9.114" +cargo update -p zeroize --precise "1.8.2" ``` ## License diff --git a/src/client.rs b/src/client.rs index 1f39417..268c51e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -34,7 +34,7 @@ pub enum ClientType { /// Generalized Electrum client that supports multiple backends. Can re-instantiate client_type if connections /// drops pub struct Client { - client_type: RwLock, + client_type: RwLock>, config: Config, url: String, } @@ -44,8 +44,9 @@ macro_rules! impl_inner_call { { let mut errors = vec![]; loop { + $self.connect()?; let read_client = $self.client_type.read().unwrap(); - let res = match &*read_client { + let res = match read_client.as_ref().expect("connect populated client") { ClientType::TCP(inner) => inner.$name( $($args, )* ), #[cfg(any(feature = "openssl", feature = "rustls", feature = "rustls-ring"))] ClientType::SSL(inner) => inner.$name( $($args, )* ), @@ -79,7 +80,7 @@ macro_rules! impl_inner_call { match ClientType::from_config(&$self.url, &$self.config) { Ok(new_client) => { info!("Succesfully created new client"); - *write_client = new_client; + *write_client = Some(new_client); break; }, Err(e) => { @@ -199,15 +200,41 @@ impl Client { /// Generic constructor that supports multiple backends and allows configuration through /// the [Config] + /// + /// This stores the URL and configuration without opening a network connection. pub fn from_config(url: &str, config: Config) -> Result { - let client_type = RwLock::new(ClientType::from_config(url, &config)?); + #[cfg(not(any(feature = "openssl", feature = "rustls", feature = "rustls-ring")))] + if url.starts_with("ssl://") { + return Err(Error::Message( + "SSL connections require one of the following features to be enabled: openssl, rustls, or rustls-ring".to_string() + )); + } Ok(Client { - client_type, + client_type: RwLock::new(None), config, url: url.to_string(), }) } + + /// Establishes the Electrum connection and negotiates the protocol version. + /// + /// Does nothing if the client is already connected. + pub fn connect(&self) -> Result<(), Error> { + { + let client_type = self.client_type.read().unwrap(); + if client_type.is_some() { + return Ok(()); + } + } + + let mut client_type = self.client_type.write().unwrap(); + if client_type.is_none() { + *client_type = Some(ClientType::from_config(&self.url, &self.config)?); + } + + Ok(()) + } } impl ElectrumApi for Client { @@ -434,6 +461,28 @@ impl ElectrumApi for Client { #[cfg(test)] mod tests { use super::*; + use std::io::{BufRead, BufReader, ErrorKind, Write}; + use std::net::TcpListener; + use std::thread; + + const VERSION_RESPONSE: &[u8] = br#"{"jsonrpc":"2.0","result":["test-server","1.6"],"id":0}"#; + const FEATURES_RESPONSE: &[u8] = + br#"{"jsonrpc":"2.0","result":{"server_version":"test-server","genesis_hash":"0000000000000000000000000000000000000000000000000000000000000000","protocol_min":"1.4","protocol_max":"1.6","hash_function":"sha256","pruning":null},"id":1}"#; + + fn listener_url(listener: &TcpListener) -> String { + format!("tcp://{}", listener.local_addr().unwrap()) + } + + fn read_request(reader: &mut impl BufRead) -> String { + let mut request = String::new(); + reader.read_line(&mut request).unwrap(); + request + } + + fn write_response(stream: &mut impl Write, response: &[u8]) { + stream.write_all(response).unwrap(); + stream.write_all(b"\n").unwrap(); + } #[test] fn more_failed_attempts_than_retries_means_exhausted() { @@ -464,4 +513,69 @@ mod tests { assert!(!exhausted) } + + #[test] + fn client_constructor_does_not_connect() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + listener.set_nonblocking(true).unwrap(); + let url = listener_url(&listener); + + let client = Client::new(&url).unwrap(); + + assert!(client.client_type.read().unwrap().is_none()); + match listener.accept() { + Ok(_) => panic!("constructor opened a connection"), + Err(err) if err.kind() == ErrorKind::WouldBlock => {} + Err(err) => panic!("unexpected accept error: {err}"), + } + } + + #[test] + fn client_connect_opens_connection_and_negotiates_protocol() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let url = listener_url(&listener); + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut reader = BufReader::new(stream.try_clone().unwrap()); + let request = read_request(&mut reader); + write_response(&mut stream, VERSION_RESPONSE); + + request + }); + + let client = Client::new(&url).unwrap(); + + client.connect().unwrap(); + + assert!(client.client_type.read().unwrap().is_some()); + let request = server.join().unwrap(); + assert!(request.contains(r#""method":"server.version""#)); + } + + #[test] + fn first_api_call_connects_before_dispatching_request() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let url = listener_url(&listener); + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut reader = BufReader::new(stream.try_clone().unwrap()); + + let version_request = read_request(&mut reader); + write_response(&mut stream, VERSION_RESPONSE); + + let features_request = read_request(&mut reader); + write_response(&mut stream, FEATURES_RESPONSE); + + (version_request, features_request) + }); + + let client = Client::new(&url).unwrap(); + + let features = client.server_features().unwrap(); + + assert_eq!(features.server_version, "test-server"); + let (version_request, features_request) = server.join().unwrap(); + assert!(version_request.contains(r#""method":"server.version""#)); + assert!(features_request.contains(r#""method":"server.features""#)); + } }