|
| 1 | +use std::io::Write; |
| 2 | +use std::net::TcpListener; |
| 3 | +use std::thread::JoinHandle; |
| 4 | +use std::time::{Duration, Instant}; |
| 5 | + |
| 6 | +use spacetimedb_smoketests::Smoketest; |
| 7 | + |
| 8 | +fn module_code_http_disallowed_ip(addr: &str, port: u16) -> String { |
| 9 | + format!( |
| 10 | + r#" |
| 11 | +use spacetimedb::ProcedureContext; |
| 12 | +
|
| 13 | +#[spacetimedb::procedure] |
| 14 | +pub fn request_disallowed_ip(ctx: &mut ProcedureContext) -> Result<(), String> {{ |
| 15 | + match ctx.http.get("http://{addr}:{port}/") {{ |
| 16 | + Ok(_) => Err("request unexpectedly succeeded".to_owned()), |
| 17 | + Err(err) => {{ |
| 18 | + let message = err.to_string(); |
| 19 | + if message.contains("refusing to connect to private or special-purpose addresses") {{ |
| 20 | + Ok(()) |
| 21 | + }} else {{ |
| 22 | + Err(format!("unexpected error from http request: {{message}}")) |
| 23 | + }} |
| 24 | + }} |
| 25 | + }} |
| 26 | +}} |
| 27 | +"# |
| 28 | + ) |
| 29 | +} |
| 30 | + |
| 31 | +fn spawn_redirect_server(location: &str) -> (u16, JoinHandle<std::io::Result<()>>) { |
| 32 | + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("failed to bind test redirect server"); |
| 33 | + listener |
| 34 | + .set_nonblocking(true) |
| 35 | + .expect("failed to set test redirect server nonblocking mode"); |
| 36 | + let port = listener |
| 37 | + .local_addr() |
| 38 | + .expect("failed to read test redirect server address") |
| 39 | + .port(); |
| 40 | + let location = location.to_owned(); |
| 41 | + let handle = std::thread::spawn(move || -> std::io::Result<()> { |
| 42 | + let deadline = Instant::now() + Duration::from_secs(10); |
| 43 | + let (mut stream, _) = loop { |
| 44 | + match listener.accept() { |
| 45 | + Ok(pair) => break pair, |
| 46 | + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { |
| 47 | + if Instant::now() >= deadline { |
| 48 | + return Err(std::io::Error::new( |
| 49 | + std::io::ErrorKind::TimedOut, |
| 50 | + "redirect test server did not receive a request; rebuild standalone with allow_loopback_http_for_tests", |
| 51 | + )); |
| 52 | + } |
| 53 | + std::thread::sleep(Duration::from_millis(10)); |
| 54 | + } |
| 55 | + Err(err) => return Err(err), |
| 56 | + } |
| 57 | + }; |
| 58 | + let response = |
| 59 | + format!("HTTP/1.1 302 Found\r\nLocation: {location}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"); |
| 60 | + stream.write_all(response.as_bytes())?; |
| 61 | + stream.flush()?; |
| 62 | + Ok(()) |
| 63 | + }); |
| 64 | + (port, handle) |
| 65 | +} |
| 66 | + |
| 67 | +#[test] |
| 68 | +fn test_http_disallowed_ip_is_blocked() { |
| 69 | + let module_code = module_code_http_disallowed_ip("10.0.0.1", 80); |
| 70 | + let test = Smoketest::builder().module_code(&module_code).build(); |
| 71 | + |
| 72 | + let output = test.call_output("request_disallowed_ip", &[]); |
| 73 | + let stdout = String::from_utf8_lossy(&output.stdout); |
| 74 | + let stderr = String::from_utf8_lossy(&output.stderr); |
| 75 | + assert!( |
| 76 | + output.status.success(), |
| 77 | + "Expected request_disallowed_ip to succeed after observing blocked egress error.\nstdout:\n{}\nstderr:\n{}", |
| 78 | + stdout, |
| 79 | + stderr |
| 80 | + ); |
| 81 | +} |
| 82 | + |
| 83 | +#[test] |
| 84 | +fn test_http_redirect_to_disallowed_ip_is_blocked() { |
| 85 | + let (port, redirect_server) = spawn_redirect_server("http://10.0.0.1:80/"); |
| 86 | + let module_code = module_code_http_disallowed_ip("localhost", port); |
| 87 | + let test = Smoketest::builder().module_code(&module_code).build(); |
| 88 | + |
| 89 | + let output = test.call_output("request_disallowed_ip", &[]); |
| 90 | + let stdout = String::from_utf8_lossy(&output.stdout); |
| 91 | + let stderr = String::from_utf8_lossy(&output.stderr); |
| 92 | + assert!( |
| 93 | + output.status.success(), |
| 94 | + "Expected request_disallowed_ip to succeed after observing blocked egress error.\nstdout:\n{}\nstderr:\n{}", |
| 95 | + stdout, |
| 96 | + stderr |
| 97 | + ); |
| 98 | + |
| 99 | + redirect_server |
| 100 | + .join() |
| 101 | + .expect("redirect test server thread panicked") |
| 102 | + .expect("redirect test server failed"); |
| 103 | +} |
0 commit comments