From 264c538a7fd1e94da80d52105642c3c06e7a0098 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:04:16 -0700 Subject: [PATCH 01/19] Add Optional Batching for Push Updates --- src/config.rs | 12 ++ src/fetch/tests.rs | 6 +- src/main.rs | 34 ++++- src/push/mod.rs | 105 +++++++++++++-- src/push/tests.rs | 232 +++++++++++++++++++++++++++++++-- src/store/adapters/postgres.rs | 39 +++++- src/store/adapters/sqlite.rs | 40 +++++- src/store/traits.rs | 7 +- 8 files changed, 430 insertions(+), 45 deletions(-) diff --git a/src/config.rs b/src/config.rs index fd9b11e8..a953e96a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -308,6 +308,15 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of status updates. pub status_update_interval_ms: u64, + /// Update claimed → processing (dispatch) updates in batches? + pub batch_push_updates: bool, + + /// The size of a batch of dispatch updates. + pub push_update_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of dispatch updates. + pub push_update_interval_ms: u64, + /// The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, @@ -421,6 +430,9 @@ impl Default for Config { batch_status_updates: false, status_update_batch_size: 1, status_update_interval_ms: 100, + batch_push_updates: false, + push_update_batch_size: 1, + push_update_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 092d3503..3e89a730 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -98,10 +98,14 @@ impl InflightActivationStore for MockStore { }) } - async fn mark_activation_processing(&self, _id: &str) -> Result<(), Error> { + async fn mark_processing(&self, _id: &str) -> Result<(), Error> { Ok(()) } + async fn mark_processing_batch(&self, _ids: &[String]) -> Result { + unimplemented!() + } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { unimplemented!() } diff --git a/src/main.rs b/src/main.rs index c3648a6d..543334d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; +use taskbroker::grpc::server::TaskbrokerServer; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -27,7 +27,6 @@ use taskbroker::kafka::inflight_activation_writer::{ ActivationWriterConfig, InflightActivationWriter, }; use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use taskbroker::logging; use taskbroker::metrics; use taskbroker::processing_strategy; use taskbroker::push::PushPool; @@ -40,6 +39,7 @@ use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; use taskbroker::{SERVICE_NAME, flusher}; +use taskbroker::{grpc, logging, push}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -203,7 +203,7 @@ async fn main() -> Result<(), Error> { rx, flusher_config.status_update_batch_size, flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + move |buffer| Box::pin(grpc::server::flush_updates(flusher_store.clone(), buffer)), ) .await }); @@ -265,8 +265,30 @@ async fn main() -> Result<(), Error> { } }); + // Push update flush task + let (push_update_tx, push_update_task) = if config.batch_push_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.push_update_batch_size, + flusher_config.push_update_interval_ms, + move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); + let push_pool = Arc::new(PushPool::new(config.clone(), store.clone(), push_update_tx)); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads @@ -305,6 +327,10 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("status_update_task", task)); } + if let Some(task) = push_update_task { + departure = departure.on_completion(log_task_completion("push_update_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index fbf04018..8a5fc001 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,4 +1,3 @@ -use chrono::Utc; use std::cmp::max; use std::collections::HashMap; use std::future::Future; @@ -8,6 +7,7 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; +use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -15,11 +15,12 @@ use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; use sha2::Sha256; +use tokio::sync::mpsc; use tokio::task::JoinSet; use tonic::async_trait; use tonic::metadata::MetadataValue; use tonic::transport::Channel; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use crate::config::Config; use crate::store::activation::InflightActivation; @@ -100,6 +101,9 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver<(InflightActivation, Instant)>, + /// Queue for batching claimed → processing updates. + update_tx: Option>, + /// Taskbroker configuration. config: Arc, @@ -111,28 +115,36 @@ pub struct PushPool { impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc, store: Arc) -> Self { + pub fn new( + config: Arc, + store: Arc, + update_tx: Option>, + ) -> Self { let worker_factory: WorkerFactory = Arc::new(|endpoint: String| { Box::pin(async move { let client = WorkerServiceClient::connect(endpoint).await?; Ok(Box::new(client) as Box) }) }); - Self::new_with_factory(config, store, worker_factory) + + Self::new_with_factory(config, store, worker_factory, update_tx) } fn new_with_factory( config: Arc, store: Arc, worker_factory: WorkerFactory, + update_tx: Option>, ) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); + Self { sender, receiver, config, store, worker_factory, + update_tx, } } @@ -148,6 +160,7 @@ impl PushPool { let receiver = self.receiver.clone(); let store = store.clone(); let worker_factory = worker_factory.clone(); + let update_tx = self.update_tx.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -244,11 +257,31 @@ impl PushPool { } let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration").record(start.elapsed()); + + // Are we batching claimed → processing updates? + if let Some(ref tx) = update_tx { + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update" + ); + } + + continue; + } + + // Fall back to individual updates + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); + metrics::counter!("push.mark_processing", "result" => "error").increment(1); error!( task_id = %id, @@ -304,12 +337,13 @@ impl PushPool { debug!(task_id = %id, "Activation sent to worker"); let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration") + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration") .record(start.elapsed()); if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); error!( task_id = %id, @@ -422,5 +456,56 @@ async fn push_task( result } +pub async fn flush_updates(store: Arc, buffer: &mut Vec) { + if buffer.is_empty() { + return; + } + + let start = Instant::now(); + + let ids = std::mem::take(buffer); + + let requested = ids.len() as u64; + metrics::histogram!("push.flush_updates.requested").record(requested as f64); + + let result = store.mark_processing_batch(&ids).await; + metrics::histogram!("push.mark_processing_batch.duration").record(start.elapsed()); + + match result { + Ok(affected) => { + metrics::histogram!("push.flush_updates.affected").record(affected as f64); + + metrics::counter!("push.flush_updates.updated").increment(affected); + metrics::counter!("push.flush_updates", "result" => "ok").increment(1); + + if affected < requested { + metrics::counter!("push.flush_updates.partial").increment(1); + + warn!( + requested, + affected, "Updated fewer rows than IDs requested from push pool" + ); + } + + debug!(affected, requested, "Flushed update batch from push pool"); + } + + Err(e) => { + metrics::counter!("push.flush_updates", "result" => "error").increment(1); + + error!( + requested, + error = ?e, + "Failed to flush update batch from push pool" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push(id); + } + } + } +} + #[cfg(test)] mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs index 888b127d..3b5ee0f4 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,10 +1,11 @@ use std::sync::{Arc, Mutex}; +use std::time::Instant; use anyhow::anyhow; use async_trait::async_trait; use chrono::{DateTime, Utc}; use sentry_protos::taskbroker::v1::PushTaskRequest; -use tokio::sync::Notify; +use tokio::sync::{Notify, mpsc}; use tokio::time::{Duration, timeout}; use crate::config::Config; @@ -63,15 +64,42 @@ impl WorkerClient for NotifyingWorkerClient { } /// Minimal fake store that records which activation IDs have been marked as processing. -#[derive(Default, Clone)] +/// All IDs marked via either `mark_processing` or successful `mark_processing_batch`. +#[derive(Clone)] struct MockStore { marked_processing: Arc>>, + mark_processing_calls: Arc>>, + mark_processing_batches: Arc>>>, + mark_processing_batch_should_fail: Arc>, +} + +impl Default for MockStore { + fn default() -> Self { + Self { + marked_processing: Arc::new(Mutex::new(vec![])), + mark_processing_calls: Arc::new(Mutex::new(vec![])), + mark_processing_batches: Arc::new(Mutex::new(vec![])), + mark_processing_batch_should_fail: Arc::new(Mutex::new(false)), + } + } } impl MockStore { fn marked_ids(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } + + fn mark_processing_direct_calls(&self) -> Vec { + self.mark_processing_calls.lock().unwrap().clone() + } + + fn mark_processing_batch_calls(&self) -> Vec> { + self.mark_processing_batches.lock().unwrap().clone() + } + + fn set_mark_processing_batch_fail(&self, fail: bool) { + *self.mark_processing_batch_should_fail.lock().unwrap() = fail; + } } #[async_trait] @@ -79,9 +107,11 @@ impl InflightActivationStore for MockStore { async fn store(&self, _batch: Vec) -> anyhow::Result { Ok(0) } + fn assign_partitions(&self, _partitions: Vec) -> anyhow::Result<()> { Ok(()) } + async fn claim_activations( &self, _application: Option<&str>, @@ -92,10 +122,35 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(vec![]) } - async fn mark_activation_processing(&self, id: &str) -> anyhow::Result<()> { + + async fn mark_processing(&self, id: &str) -> anyhow::Result<()> { + self.mark_processing_calls + .lock() + .unwrap() + .push(id.to_string()); self.marked_processing.lock().unwrap().push(id.to_string()); Ok(()) } + + async fn mark_processing_batch(&self, ids: &[String]) -> anyhow::Result { + if *self.mark_processing_batch_should_fail.lock().unwrap() { + return Err(anyhow!("mock mark_processing_batch failure")); + } + + self.mark_processing_batches + .lock() + .unwrap() + .push(ids.to_vec()); + + let mut guard = self.marked_processing.lock().unwrap(); + + for id in ids { + guard.push(id.clone()); + } + + Ok(ids.len() as u64) + } + async fn set_status( &self, _id: &str, @@ -103,6 +158,7 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(None) } + async fn set_status_batch( &self, _ids: &[String], @@ -110,18 +166,23 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result { Ok(0) } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { 0.0 } + async fn count_by_status(&self, _status: InflightActivationStatus) -> anyhow::Result { Ok(0) } + async fn count(&self) -> anyhow::Result { Ok(0) } + async fn get_by_id(&self, _id: &str) -> anyhow::Result> { Ok(None) } + async fn set_processing_deadline( &self, _id: &str, @@ -129,51 +190,66 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result<()> { Ok(()) } + async fn delete_activation(&self, _id: &str) -> anyhow::Result<()> { Ok(()) } + async fn vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn full_vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn db_size(&self) -> anyhow::Result { Ok(0) } + async fn get_retry_activations(&self) -> anyhow::Result> { Ok(vec![]) } + async fn handle_claim_expiration(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_deadline(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_attempts(&self) -> anyhow::Result { Ok(0) } + async fn handle_expires_at(&self) -> anyhow::Result { Ok(0) } + async fn handle_delay_until(&self) -> anyhow::Result { Ok(0) } + async fn handle_failed_tasks(&self) -> anyhow::Result { Ok(FailedTasksForwarder { to_discard: vec![], to_deadletter: vec![], }) } + async fn mark_completed(&self, _ids: Vec) -> anyhow::Result { Ok(0) } + async fn remove_completed(&self) -> anyhow::Result { Ok(0) } + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> anyhow::Result { Ok(0) } + async fn clear(&self) -> anyhow::Result<()> { Ok(()) } @@ -269,7 +345,7 @@ async fn push_pool_submit_enqueues_item() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let activation = make_activations(1).remove(0); let time = Instant::now(); @@ -285,7 +361,7 @@ async fn push_pool_submit_backpressures_when_queue_full() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let time = Instant::now(); let first = make_activations(1).remove(0); @@ -321,7 +397,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = PushPool::new_with_factory(config, store, failing_connect_factory()); + let pool = PushPool::new_with_factory(config, store, failing_connect_factory(), None); let result = pool.start().await; assert!( @@ -331,7 +407,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { } /// After a successful push for a first-attempt activation (processing_attempts == 0), -/// mark_activation_processing must be called on the store. +/// mark_processing must be called on the store. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_first_attempt() { let notify = Arc::new(Notify::new()); @@ -346,6 +422,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -361,7 +438,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { .await .expect("submit should succeed"); - // Wait for the worker to call send(), then give it time to call mark_activation_processing + // Wait for the worker to call send(), then give it time to call mark_processing timeout(Duration::from_secs(2), notify.notified()) .await .expect("timed out waiting for push to be delivered"); @@ -370,12 +447,12 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful first-attempt push" + "mark_processing should be called after a successful first-attempt push" ); } /// After a successful push for a retried activation (processing_attempts > 0), -/// mark_activation_processing must be called and latency recording is skipped. +/// mark_processing must be called and latency recording is skipped. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_retry() { let notify = Arc::new(Notify::new()); @@ -390,6 +467,7 @@ async fn push_pool_start_marks_activation_processing_on_retry() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -413,13 +491,13 @@ async fn push_pool_start_marks_activation_processing_on_retry() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful retry push" + "mark_processing should be called after a successful retry push" ); } -/// When the worker fails to deliver an activation, mark_activation_processing must NOT be called. +/// When the worker fails to deliver an activation, mark_processing must NOT be called. #[tokio::test] -async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { +async fn push_pool_start_does_not_mark_processing_on_push_failure() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { worker_map: [("sentry".into(), "unused".into())].into(), @@ -432,6 +510,7 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { config, store.clone(), notifying_factory(true, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -451,6 +530,131 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { assert!( store.marked_ids().is_empty(), - "mark_activation_processing should not be called when push fails" + "mark_processing should not be called when push fails" ); } + +/// With `update_tx` set, a successful push on the main loop enqueues the task ID on the channel. +/// Shutdown drain does not use batching - it applies `mark_processing` per activation, so this test +/// does not assert on direct `mark_processing` calls (those can appear only from drain under shutdown). +#[tokio::test] +async fn push_pool_forwards_successful_push_to_update_channel() { + let notify = Arc::new(Notify::new()); + let (update_tx, mut update_rx) = mpsc::channel::(8); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let id = activation.id.clone(); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!( + store.mark_processing_batch_calls().is_empty(), + "Method `mark_processing_batch` runs only via `flush_updates`, not the push worker" + ); + + let ch_id = update_rx + .recv() + .await + .expect("Task ID should be sent on update channel"); + assert_eq!(ch_id, id); +} + +/// Function `flush_updates` drains the buffer into `mark_processing_batch` and clears the buffer. +#[tokio::test] +async fn flush_updates_applies_batch_and_clears_buffer() { + let store = Arc::new(MockStore::default()); + let mut buf = vec!["id_0".to_string()]; + + flush_updates(store.clone(), &mut buf).await; + + assert!( + buf.is_empty(), + "buffer should be cleared after successful flush" + ); + assert!(store.mark_processing_direct_calls().is_empty()); + assert_eq!( + store.mark_processing_batch_calls(), + vec![vec!["id_0".to_string()]] + ); + assert_eq!(store.marked_ids(), vec!["id_0".to_string()]); +} + +/// On `mark_processing_batch` error, `flush_updates` restores IDs into the buffer for retry. +#[tokio::test] +async fn flush_updates_restores_buffer_on_batch_error() { + let store = Arc::new(MockStore::default()); + store.set_mark_processing_batch_fail(true); + + let mut buf = vec!["a".to_string(), "b".to_string()]; + flush_updates(store.clone(), &mut buf).await; + + assert_eq!(buf, vec!["a".to_string(), "b".to_string()]); + assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.marked_ids().is_empty()); +} + +/// After a successful worker push, a closed `update_tx` receiver means the main loop cannot enqueue +/// the id (no `mark_processing` there). Shutdown drain still uses individual `mark_processing` for +/// any remaining queue items, so this test does not assert the mock store stayed untouched. +#[tokio::test] +async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { + let notify = Arc::new(Notify::new()); + let (update_tx, update_rx) = mpsc::channel::(8); + drop(update_rx); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!(store.mark_processing_batch_calls().is_empty()); +} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 27ce276f..c82adf59 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -503,10 +503,8 @@ impl InflightActivationStore for PostgresActivationStore { #[instrument(skip_all)] #[framed] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -523,20 +521,47 @@ impl InflightActivationStore for PostgresActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as processing, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + #[framed] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let result = sqlx::query(&format!( + "UPDATE inflight_taskactivations SET + status = $1, + processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), + claim_expires_at = NULL + WHERE id = ANY($2) AND status = $3", + )) + .bind(InflightActivationStatus::Processing.to_string()) + .bind(ids) + .bind(InflightActivationStatus::Claimed.to_string()) + .execute(&mut *conn) + .await?; + + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index de457de4..8bca255b 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -597,10 +597,8 @@ impl InflightActivationStore for SqliteActivationStore { } #[instrument(skip_all)] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -617,20 +615,48 @@ impl InflightActivationStore for SqliteActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as sent, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET status = "); + query_builder.push_bind(InflightActivationStatus::Processing); + query_builder.push(format!( + ", processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL WHERE status = ", + )); + query_builder.push_bind(InflightActivationStatus::Claimed); + query_builder.push(" AND id IN ("); + + let mut separated = query_builder.separated(", "); + for id in ids.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/traits.rs b/src/store/traits.rs index d8bdb5e0..287fc12f 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -27,7 +27,7 @@ pub trait InflightActivationStore: Send + Sync { mark_processing: bool, ) -> Result, Error>; - /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_activation_processing` moves to `Processing`. + /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_processing` moves to `Processing`. async fn claim_activations_for_push( &self, limit: Option, @@ -69,7 +69,10 @@ pub trait InflightActivationStore: Send + Sync { } /// Record successful push. - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error>; + async fn mark_processing(&self, id: &str) -> Result<(), Error>; + + /// Record a batch of successful pushes. + async fn mark_processing_batch(&self, ids: &[String]) -> Result; /// Update the status of a specific activation async fn set_status( From 26dfa5f764886e7fff028ed268ff71f5fe7b544d Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:47:33 -0700 Subject: [PATCH 02/19] Fix Unbounded Buffering Bug --- src/fetch/mod.rs | 3 +-- src/flusher.rs | 16 +++++++++++++++- src/push/mod.rs | 2 ++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index e4ae2fab..89b12981 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -122,11 +122,10 @@ impl FetchPool { } _ = async { - let start = Instant::now(); - debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); + let start = Instant::now(); let mut backoff = false; let result = store.claim_activations_for_push(limit, bucket).await; diff --git a/src/flusher.rs b/src/flusher.rs index 33732dfa..b0e7cd5a 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -29,7 +29,8 @@ where loop { tokio::select! { - msg = rx.recv() => { + // When the buffer is NOT full, try to receive another message + msg = rx.recv(), if buffer.len() < batch_size => { match msg { Some(v) => { buffer.push(v); @@ -51,6 +52,19 @@ where } } + // If the buffer IS full, the branch above will never execute, and we will never + // discover that the channel is now closed, which is why this branch is necessary + _ = std::future::ready(()), if rx.is_closed() => { + while let Ok(update) = rx.try_recv() { + // Buffer may grow beyond configured limit, which is OK because we are shutting down + buffer.push(update); + } + + flush(&mut buffer).await; + break; + } + + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { if !buffer.is_empty() { flush(&mut buffer).await; diff --git a/src/push/mod.rs b/src/push/mod.rs index 8a5fc001..6a2535eb 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -337,6 +337,8 @@ impl PushPool { debug!(task_id = %id, "Activation sent to worker"); let start = Instant::now(); + + // We won't batch these updates to keep things simple during shutdown let result = store.mark_processing(&id).await; metrics::histogram!("push.mark_processing.duration") .record(start.elapsed()); From 04b79622c658c2e5bc5509f4e87910d431d39dc1 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:55:43 -0700 Subject: [PATCH 03/19] Add Debug Logs for Flusher --- src/flusher.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/flusher.rs b/src/flusher.rs index b0e7cd5a..741dae33 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -4,6 +4,7 @@ use std::time::Duration; use anyhow::Result; use tokio::sync::mpsc::Receiver; +use tracing::debug; /// Run flusher that receives values of type T from a channel and flushes /// them using the provided async `flush` function either when the batch is @@ -31,6 +32,8 @@ where tokio::select! { // When the buffer is NOT full, try to receive another message msg = rx.recv(), if buffer.len() < batch_size => { + debug!("Buffer is NOT full, receiving a message..."); + match msg { Some(v) => { buffer.push(v); @@ -40,12 +43,14 @@ where } if buffer.len() >= batch_size { + debug!("Flushing full buffer..."); flush(&mut buffer).await; } } None => { // Channel closed (shutdown), flush remaining and exit + debug!("Channel closed due to shutdown, flushing remaining before exit..."); flush(&mut buffer).await; break; } @@ -55,6 +60,8 @@ where // If the buffer IS full, the branch above will never execute, and we will never // discover that the channel is now closed, which is why this branch is necessary _ = std::future::ready(()), if rx.is_closed() => { + debug!("Channel is closed and buffer is full, draining channel before exiting..."); + while let Ok(update) = rx.try_recv() { // Buffer may grow beyond configured limit, which is OK because we are shutting down buffer.push(update); @@ -66,6 +73,8 @@ where // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { + debug!("Performing periodic flush..."); + if !buffer.is_empty() { flush(&mut buffer).await; } From 9c87519b8fe0bd11f9633d512cd954b11f58ad5f Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 17:01:15 -0700 Subject: [PATCH 04/19] Minor Formatting --- src/push/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index 6a2535eb..fb7acdbc 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -464,7 +464,6 @@ pub async fn flush_updates(store: Arc, buffer: &mut } let start = Instant::now(); - let ids = std::mem::take(buffer); let requested = ids.len() as u64; From 27ae1de4d573e4e0aefb0d0e942473f5fe3d53cf Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 18:56:53 -0700 Subject: [PATCH 05/19] Fix Hanging Test by Batching on Push Pool Drain --- src/push/mod.rs | 42 +++++++++++++++++++++++++++++------------- src/push/tests.rs | 11 +++++------ 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index fb7acdbc..05612c07 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -338,20 +338,36 @@ impl PushPool { let start = Instant::now(); - // We won't batch these updates to keep things simple during shutdown - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); + if let Some(ref tx) = update_tx { + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration") + .record(start.elapsed()); - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update during shutdown drain" + ); + } + } else { + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration") + .record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to mark activation as processing after push" + ); + } } } @@ -464,7 +480,7 @@ pub async fn flush_updates(store: Arc, buffer: &mut } let start = Instant::now(); - let ids = std::mem::take(buffer); + let ids: Vec<_> = std::mem::take(buffer); let requested = ids.len() as u64; metrics::histogram!("push.flush_updates.requested").record(requested as f64); diff --git a/src/push/tests.rs b/src/push/tests.rs index 3b5ee0f4..e88458eb 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -534,9 +534,7 @@ async fn push_pool_start_does_not_mark_processing_on_push_failure() { ); } -/// With `update_tx` set, a successful push on the main loop enqueues the task ID on the channel. -/// Shutdown drain does not use batching - it applies `mark_processing` per activation, so this test -/// does not assert on direct `mark_processing` calls (those can appear only from drain under shutdown). +/// With `update_tx` set, a successful push enqueues the task ID on the channel. #[tokio::test] async fn push_pool_forwards_successful_push_to_update_channel() { let notify = Arc::new(Notify::new()); @@ -618,9 +616,8 @@ async fn flush_updates_restores_buffer_on_batch_error() { assert!(store.marked_ids().is_empty()); } -/// After a successful worker push, a closed `update_tx` receiver means the main loop cannot enqueue -/// the id (no `mark_processing` there). Shutdown drain still uses individual `mark_processing` for -/// any remaining queue items, so this test does not assert the mock store stayed untouched. +/// After a successful worker push, a closed `update_tx` receiver means neither the main loop nor +/// shutdown drain can enqueue the ID. #[tokio::test] async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { let notify = Arc::new(Notify::new()); @@ -657,4 +654,6 @@ async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_clos tokio::time::sleep(Duration::from_millis(50)).await; assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.mark_processing_direct_calls().is_empty()); + assert!(store.marked_ids().is_empty()); } From 67a692d733b825270c3fb1555d6793856ff4e0f5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 13:35:32 -0700 Subject: [PATCH 06/19] Move Push Logic into Seprate Function, Add Update Queue Size Metrics, Bias Flusher --- src/flusher.rs | 2 + src/grpc/server.rs | 3 + src/push/mod.rs | 253 +++++++++++++++------------------------------ src/push/tests.rs | 27 +---- 4 files changed, 92 insertions(+), 193 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index 741dae33..c6208675 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -30,6 +30,8 @@ where loop { tokio::select! { + biased; + // When the buffer is NOT full, try to receive another message msg = rx.recv(), if buffer.len() < batch_size => { debug!("Buffer is NOT full, receiving a message..."); diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 9d669c5e..3f68de7d 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -107,6 +107,9 @@ impl ConsumerService for TaskbrokerServer { } if let Some(ref tx) = self.update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("grpc_server.update_queue.depth").set(depth as f64); + tx.send((id, status)) .await .map_err(|_| Status::internal("Status update channel closed"))?; diff --git a/src/push/mod.rs b/src/push/mod.rs index 05612c07..ae1547c9 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,4 +1,3 @@ -use std::cmp::max; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -7,7 +6,6 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; -use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -164,11 +162,6 @@ impl PushPool { let guard = get_shutdown_guard().shutdown_on_drop(); - let callback_url = format!( - "{}:{}", - self.config.callback_addr, self.config.callback_port - ); - let timeout = Duration::from_millis(self.config.push_timeout_ms); let grpc_shared_secret = self.config.grpc_shared_secret.clone(); @@ -215,174 +208,22 @@ impl PushPool { metrics::histogram!("push.queue.latency").record(time.elapsed()); - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), - activation.clone(), - callback_url, - timeout, - grpc_shared_secret.as_slice(), - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - if activation.processing_attempts < 1 { - let latency = max(0, activation.received_latency(Utc::now())); - - metrics::histogram!( - "push.received_to_push.latency", - "namespace" => activation.namespace, - "taskname" => activation.taskname, - ) - .record(latency as f64); - } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); - } - - let start = Instant::now(); - - // Are we batching claimed → processing updates? - if let Some(ref tx) = update_tx { - let result = tx.send(id.clone()).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to enqueue push update" - ); - } - - continue; - } - - // Fall back to individual updates - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as sent after push" - ); - } - } - - // Once claim expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + push_task(store.clone(), update_tx.as_ref(), activation, &mut workers, timeout, grpc_shared_secret.as_slice()).await; } } } // Drain channel before exiting without recording duration metrics since they don't matter at this time for (activation, _) in receiver.drain() { - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), + push_task( + store.clone(), + update_tx.as_ref(), activation, - callback_url, + &mut workers, timeout, grpc_shared_secret.as_slice(), ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - let start = Instant::now(); - - if let Some(ref tx) = update_tx { - let result = tx.send(id.clone()).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to enqueue push update during shutdown drain" - ); - } - } else { - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); - } - } - } - - // Once processing deadline expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + .await; } Ok(()) @@ -438,12 +279,86 @@ impl PushPool { } } -/// Decode task activation and push it to a worker. -#[framed] +/// Determine which worker should receive an activation, send the activation, and update its status. async fn push_task( + store: Arc, + update_tx: Option<&mpsc::Sender>, + activation: InflightActivation, + workers: &mut HashMap>, + timeout: Duration, + grpc_shared_secret: &[String], +) { + let id = activation.id.clone(); + + let Some(worker) = workers.get_mut(&activation.application) else { + metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); + + error!( + task_id = %id, + application = activation.application, + "Task application has no worker pool mapping" + ); + + return; + }; + + match send_task(worker.as_mut(), activation, timeout, grpc_shared_secret).await { + Ok(_) => { + metrics::counter!("push.push_task", "result" => "ok").increment(1); + debug!(task_id = %id, "Activation sent to worker"); + + let start = Instant::now(); + + if let Some(tx) = update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("push.update_queue.depth").set(depth as f64); + + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update during shutdown drain" + ); + } + } else { + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to mark activation as processing after push" + ); + } + } + } + + // Once processing deadline expires, status will be set back to pending + Err(e) => { + metrics::counter!("push.push_task", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to send activation to worker" + ) + } + }; +} + +/// Decode task activation and send it to the worker service for a particular application. +#[framed] +async fn send_task( worker: &mut (dyn WorkerClient + Send), activation: InflightActivation, - callback_url: String, timeout: Duration, grpc_shared_secret: &[String], ) -> Result<()> { @@ -461,7 +376,7 @@ async fn push_task( let request = PushTaskRequest { task: Some(task), - callback_url, + callback_url: "".into(), }; let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await diff --git a/src/push/tests.rs b/src/push/tests.rs index e88458eb..8a3f2e8d 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -279,14 +279,7 @@ async fn push_task_returns_ok_on_client_success() { let mut worker = MockWorkerClient::new(false); let callback_url = "taskbroker:50051".to_string(); - let result = push_task( - &mut worker, - activation.clone(), - callback_url.clone(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; assert!(result.is_ok(), "push_task should succeed"); assert_eq!(worker.captured_requests.len(), 1); @@ -304,14 +297,7 @@ async fn push_task_returns_err_on_invalid_payload() { activation.activation = vec![1, 2, 3, 4]; let mut worker = MockWorkerClient::new(false); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "invalid payload should fail decoding"); assert!( @@ -325,14 +311,7 @@ async fn push_task_propagates_client_error() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(true); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "worker send errors should propagate"); assert_eq!(worker.captured_requests.len(), 1); } From c2112fe62b9ecdbbd9f3abe4ef880566bbe3adc5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 13:42:26 -0700 Subject: [PATCH 07/19] Fix Callback URL Issues --- src/config.rs | 4 ++-- src/push/mod.rs | 1 + src/push/tests.rs | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index a953e96a..ed8a3da8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -317,10 +317,10 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of dispatch updates. pub push_update_interval_ms: u64, - /// The hostname used to construct `callback_url` for task push requests. + /// (DEPRECATED) The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, - /// The port used to construct `callback_url` for task push requests. + /// (DEPRECATED) The port used to construct `callback_url` for task push requests. pub callback_port: u32, /// Maps every application to its worker endpoint, both represented as strings. diff --git a/src/push/mod.rs b/src/push/mod.rs index ae1547c9..cbde41c8 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -374,6 +374,7 @@ async fn send_task( } }; + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema let request = PushTaskRequest { task: Some(task), callback_url: "".into(), diff --git a/src/push/tests.rs b/src/push/tests.rs index 8a3f2e8d..a59bc826 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -277,14 +277,12 @@ fn failing_connect_factory() -> WorkerFactory { async fn push_task_returns_ok_on_client_success() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(false); - let callback_url = "taskbroker:50051".to_string(); let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; assert!(result.is_ok(), "push_task should succeed"); assert_eq!(worker.captured_requests.len(), 1); let request = &worker.captured_requests[0]; - assert_eq!(request.callback_url, callback_url); assert_eq!( request.task.as_ref().map(|task| task.id.as_str()), Some(activation.id.as_str()) From 81f949d3effa55477dacff215cb435854df1be81 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 14:20:15 -0700 Subject: [PATCH 08/19] Fix Misleading Error Message --- src/push/mod.rs | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index cbde41c8..aee8bb8e 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -6,6 +7,7 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; +use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -302,11 +304,31 @@ async fn push_task( return; }; - match send_task(worker.as_mut(), activation, timeout, grpc_shared_secret).await { + match send_task( + worker.as_mut(), + activation.clone(), + timeout, + grpc_shared_secret, + ) + .await + { Ok(_) => { metrics::counter!("push.push_task", "result" => "ok").increment(1); debug!(task_id = %id, "Activation sent to worker"); + if activation.processing_attempts < 1 { + let latency = max(0, activation.received_latency(Utc::now())); + + metrics::histogram!( + "push.received_to_push.latency", + "namespace" => activation.namespace, + "taskname" => activation.taskname, + ) + .record(latency as f64); + } else { + debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); + } + let start = Instant::now(); if let Some(tx) = update_tx { @@ -322,7 +344,7 @@ async fn push_task( error!( task_id = %id, error = ?e, - "Failed to enqueue push update during shutdown drain" + "Failed to enqueue push update" ); } } else { From 88dce44a9bbca47f698d3de9ef85f7d75725aee0 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 19:40:41 -0700 Subject: [PATCH 09/19] Changes to Flusher to Ensure Smooth Shutdown --- src/flusher.rs | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index c6208675..8d43bc82 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use std::time::Duration; use anyhow::Result; +use elegant_departure::get_shutdown_guard; use tokio::sync::mpsc::Receiver; use tracing::debug; @@ -28,6 +29,8 @@ where let mut buffer: Vec = Vec::with_capacity(batch_size); + let guard = get_shutdown_guard().shutdown_on_drop(); + loop { tokio::select! { biased; @@ -51,38 +54,37 @@ where } None => { - // Channel closed (shutdown), flush remaining and exit - debug!("Channel closed due to shutdown, flushing remaining before exit..."); - flush(&mut buffer).await; + // Channel closed + debug!("Channel closed!"); break; } } } - // If the buffer IS full, the branch above will never execute, and we will never - // discover that the channel is now closed, which is why this branch is necessary - _ = std::future::ready(()), if rx.is_closed() => { - debug!("Channel is closed and buffer is full, draining channel before exiting..."); + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds + _ = interval.tick() => { + debug!("Performing periodic flush..."); - while let Ok(update) = rx.try_recv() { - // Buffer may grow beyond configured limit, which is OK because we are shutting down - buffer.push(update); + if rx.is_closed() { + debug!("Channel closed on tick!"); + break; } flush(&mut buffer).await; - break; } - // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds - _ = interval.tick() => { - debug!("Performing periodic flush..."); - - if !buffer.is_empty() { - flush(&mut buffer).await; - } + _ = guard.wait() => { + debug!("Shutdown guard triggered!"); + break; } } } + // Drain and flush before exit + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + flush(&mut buffer).await; Ok(()) } From bb5b66ef2ff9ae35ca6a4f979559aff018b7a7ce Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 13:19:48 -0700 Subject: [PATCH 10/19] Break Long Lines --- src/fetch/mod.rs | 11 +++++++++-- src/push/mod.rs | 8 +++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 89b12981..d898bf07 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -48,7 +48,8 @@ pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) (low, high) } -/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, +/// but it also decouples fetch logic from push logic even further. #[async_trait] pub trait TaskPusher { /// Submit a single task to the push pool. @@ -164,7 +165,13 @@ impl FetchPool { ) .record(latency as f64); } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → claimed latency recording"); + debug!( + task_id = %id, + namespace = activation.namespace, + taskname = activation.taskname, + "Activation already processed, skipping \ + received → claimed latency recording" + ); } match pusher.submit_task(activation, start).await { diff --git a/src/push/mod.rs b/src/push/mod.rs index aee8bb8e..31c0636b 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -326,7 +326,13 @@ async fn push_task( ) .record(latency as f64); } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); + debug!( + task_id = %id, + namespace = activation.namespace, + taskname = activation.taskname, + "Activation already processed, skipping \ + received → push latency recording" + ); } let start = Instant::now(); From bee392a8f4d5083ff5b7308957924abd57fe6dcb Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 14:41:41 -0700 Subject: [PATCH 11/19] Improve Claim Expiration Computation --- src/config.rs | 6 ++++++ src/store/adapters/postgres.rs | 28 +++++++++++++++++++++++++--- src/store/adapters/sqlite.rs | 24 +++++++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index ed8a3da8..75325df1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -222,6 +222,11 @@ pub struct Config { /// brokers are under load, or there are small networking delays. pub processing_deadline_grace_sec: u64, + /// The number of additional seconds that claim expirations + /// are extended by. This helps reduce claim expirations when + /// brokers are under load, or there are small networking delays. + pub claim_expiration_grace_sec: u64, + /// The frequency at which upkeep tasks /// (discarding, retrying activations, etc.) are executed. pub upkeep_task_interval_ms: u64, @@ -406,6 +411,7 @@ impl Default for Config { max_processing_count: 2048, max_processing_attempts: 5, processing_deadline_grace_sec: 3, + claim_expiration_grace_sec: 3, upkeep_task_interval_ms: 1000, upkeep_unhealthy_interval_ms: 5000, health_check_killswitched: false, diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index c82adf59..47c99c4d 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -154,6 +154,29 @@ impl PostgresActivationStoreConfig { url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); } + + // Compute the longest amount of time an activation may be claimed + let claim_lease_ms = { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; + + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; + + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; + + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; + + queue_ms + send_ms + update_ms + grace_ms + }; + Self { pg_connection: conn_opts, pg_database_name: config.pg_database_name.clone(), @@ -161,9 +184,9 @@ impl PostgresActivationStoreConfig { run_migrations: config.run_migrations, max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, - processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + processing_deadline_grace_sec: config.processing_deadline_grace_sec, + claim_lease_ms, } } } @@ -505,7 +528,6 @@ impl InflightActivationStore for PostgresActivationStore { #[framed] async fn mark_processing(&self, id: &str) -> Result<(), Error> { let mut conn = self.acquire_write_conn_metric("mark_processing").await?; - let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( "UPDATE inflight_taskactivations SET diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 8bca255b..e8a8e3eb 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -147,12 +147,34 @@ pub struct InflightActivationStoreConfig { impl InflightActivationStoreConfig { pub fn from_config(config: &Config) -> Self { + // Compute the longest amount of time an activation may be claimed + let claim_lease_ms = { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; + + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; + + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; + + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; + + queue_ms + send_ms + update_ms + grace_ms + }; + Self { max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + claim_lease_ms, } } } From 0956223cc71463ede4fa349ea210cc08f5161ed7 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:32:43 -0700 Subject: [PATCH 12/19] Fix Flusher Drain Shutdown Logic --- src/flusher.rs | 14 ++++++-------- src/main.rs | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index 8d43bc82..dbfd32d2 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -28,8 +28,7 @@ where interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut buffer: Vec = Vec::with_capacity(batch_size); - - let guard = get_shutdown_guard().shutdown_on_drop(); + let guard = get_shutdown_guard(); loop { tokio::select! { @@ -54,7 +53,7 @@ where } None => { - // Channel closed + // Channel closed because all senders were dropped debug!("Channel closed!"); break; } @@ -66,17 +65,13 @@ where debug!("Performing periodic flush..."); if rx.is_closed() { + // Channel closed because all senders were dropped debug!("Channel closed on tick!"); break; } flush(&mut buffer).await; } - - _ = guard.wait() => { - debug!("Shutdown guard triggered!"); - break; - } } } @@ -85,6 +80,9 @@ where buffer.push(update); } + // Delay shutdown until we have flushed everything in the buffer flush(&mut buffer).await; + drop(guard); + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 543334d9..8cf145d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -217,7 +217,6 @@ async fn main() -> Result<(), Error> { let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); - let grpc_status_tx = status_update_tx.clone(); async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) @@ -234,7 +233,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, config: grpc_config, - update_tx: grpc_status_tx, + update_tx: status_update_tx, })) .add_service(health_service.clone()) .serve(addr); From 8b17becde195315ead5f8650c185aed7abc17654 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:47:20 -0700 Subject: [PATCH 13/19] Fix Claim Lease Double Count Grace Seconds --- src/store/adapters/postgres.rs | 9 +++++---- src/store/adapters/sqlite.rs | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 47c99c4d..393fecc5 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -444,9 +444,6 @@ impl InflightActivationStore for PostgresActivationStore { ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; - let claim_lease_ms = self.config.claim_lease_ms as i64; - let mut query_builder = QueryBuilder::::new( "WITH selected_activations AS ( SELECT id @@ -492,6 +489,8 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push(" FOR UPDATE SKIP LOCKED)"); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "UPDATE inflight_taskactivations SET processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), @@ -501,9 +500,11 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push_bind(InflightActivationStatus::Processing.to_string()); } else { + let claim_lease = self.config.claim_lease_ms as i64; + query_builder.push(format!( "UPDATE inflight_taskactivations - SET claim_expires_at = now() + ({claim_lease_ms} * interval '1 millisecond') + (interval '{grace_period} seconds'), + SET claim_expires_at = now() + ({claim_lease} * interval '1 millisecond'), processing_deadline = NULL, status = " )); diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index e8a8e3eb..9c01e154 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -557,20 +557,22 @@ impl InflightActivationStore for SqliteActivationStore { mark_processing: bool, ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET "); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Processing); } else { + let claim_lease = self.config.claim_lease_ms as f64 / 1000.0; + query_builder.push(format!( - "claim_expires_at = unixepoch('now', '+' || {:.3} || ' seconds', '+' || {grace_period} || ' seconds'), processing_deadline = NULL, status = ", - self.config.claim_lease_ms as f64 / 1000.0, + "claim_expires_at = unixepoch('now', '+' || {claim_lease:.3} || ' seconds'), processing_deadline = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Claimed); From 9acd2b99137ca3786a461f3f46ec7dc907b81f99 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:48:46 -0700 Subject: [PATCH 14/19] Remove Unneeded Comments --- src/store/adapters/postgres.rs | 1 - src/store/adapters/sqlite.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 393fecc5..6210a914 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -135,7 +135,6 @@ pub struct PostgresActivationStoreConfig { pub run_migrations: bool, pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool, diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 9c01e154..d68c6df6 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -139,7 +139,6 @@ pub async fn create_sqlite_pool(url: &str) -> Result<(Pool, Pool pub struct InflightActivationStoreConfig { pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool, From afe507f10c2c91372be9ef7400503eb672e66428 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 16:08:31 -0700 Subject: [PATCH 15/19] Only Batch Updates in Push Mode --- src/config.rs | 10 ++++++++-- src/main.rs | 51 +++++++++++++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/config.rs b/src/config.rs index 75325df1..96c26f96 100644 --- a/src/config.rs +++ b/src/config.rs @@ -31,6 +31,12 @@ pub enum DeliveryMode { Push, } +impl DeliveryMode { + pub fn is_push(self) -> bool { + self == DeliveryMode::Push + } +} + #[derive(PartialEq, Debug, Deserialize, Serialize)] pub struct Config { /// The sentry DSN to use for error reporting. @@ -304,7 +310,7 @@ pub struct Config { /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. pub push_timeout_ms: u64, - /// Update statuses from the gRPC server in batches? + /// Update statuses from the gRPC server in batches? Only applies in PUSH mode. pub batch_status_updates: bool, /// The size of a batch of status updates. @@ -313,7 +319,7 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of status updates. pub status_update_interval_ms: u64, - /// Update claimed → processing (dispatch) updates in batches? + /// Update claimed → processing (dispatch) updates in batches? Only applies in PUSH mode. pub batch_push_updates: bool, /// The size of a batch of dispatch updates. diff --git a/src/main.rs b/src/main.rs index 8cf145d4..dec15ffd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use tonic::transport::Server; use tonic_health::ServingStatus; use tracing::{debug, error, info, warn}; -use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; +use taskbroker::config::{Config, DatabaseAdapter}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; @@ -192,7 +192,9 @@ async fn main() -> Result<(), Error> { }); // Status update flush task - let (status_update_tx, status_update_task) = if config.batch_status_updates { + let (status_update_tx, status_update_task) = if config.batch_status_updates + && config.delivery_mode.is_push() + { let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size.max(1)); let flusher_store = store.clone(); @@ -265,40 +267,41 @@ async fn main() -> Result<(), Error> { }); // Push update flush task - let (push_update_tx, push_update_task) = if config.batch_push_updates { - let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); - - let flusher_store = store.clone(); - let flusher_config = config.clone(); - - let handle = tokio::spawn(async move { - flusher::run_flusher( - rx, - flusher_config.push_update_batch_size, - flusher_config.push_update_interval_ms, - move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), - ) - .await - }); - - (Some(tx), Some(handle)) - } else { - (None, None) - }; + let (push_update_tx, push_update_task) = + if config.batch_push_updates && config.delivery_mode.is_push() { + let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.push_update_batch_size, + flusher_config.push_update_interval_ms, + move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; // Initialize push and fetch pools let push_pool = Arc::new(PushPool::new(config.clone(), store.clone(), push_update_tx)); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads - let push_task = if config.delivery_mode == DeliveryMode::Push { + let push_task = if config.delivery_mode.is_push() { Some(tokio::spawn(async move { push_pool.start().await })) } else { None }; // Initialize fetch threads - let fetch_task = if config.delivery_mode == DeliveryMode::Push { + let fetch_task = if config.delivery_mode.is_push() { Some(tokio::spawn(async move { fetch_pool.start().await })) } else { None From 85d89cf97aa46de9aa0a0989bf7fc39fed724627 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 20 May 2026 02:42:44 -0700 Subject: [PATCH 16/19] Push Pool Abstraction Overhaul --- src/config.rs | 2 +- src/main.rs | 36 +++- src/push/mod.rs | 401 +++++++-------------------------------------- src/push/tests.rs | 4 +- src/push/thread.rs | 184 +++++++++++++++++++++ src/tokio.rs | 4 +- 6 files changed, 283 insertions(+), 348 deletions(-) create mode 100644 src/push/thread.rs diff --git a/src/config.rs b/src/config.rs index 96c26f96..40b0f416 100644 --- a/src/config.rs +++ b/src/config.rs @@ -326,7 +326,7 @@ pub struct Config { pub push_update_batch_size: usize, /// Maximum milliseconds to wait before flushing a batch of dispatch updates. - pub push_update_interval_ms: u64, + pub push_update_interval_ms: u32, /// (DEPRECATED) The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, diff --git a/src/main.rs b/src/main.rs index dec15ffd..de88bf38 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -278,7 +279,7 @@ async fn main() -> Result<(), Error> { flusher::run_flusher( rx, flusher_config.push_update_batch_size, - flusher_config.push_update_interval_ms, + flusher_config.push_update_interval_ms as u64, move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), ) .await @@ -290,12 +291,41 @@ async fn main() -> Result<(), Error> { }; // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone(), push_update_tx)); + let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads let push_task = if config.delivery_mode.is_push() { - Some(tokio::spawn(async move { push_pool.start().await })) + let mut workers: Vec = vec![]; + + // For every push thread, create a map from applications to worker connections + for i in config.push_threads { + let map = HashMap::new(); + + for (application, endpoint) in config.worker_map.clone() { + let worker = match Worker::connect(endpoint).await { + Ok(w) => { + metrics::counter!("worker.connect", "result" => "ok", "application" => application.clone()).increment(1); + debug!("Connected to worker!"); + + w + } + + Err(e) => { + metrics::counter!("worker.connect", "result" => "error", "application" => application.clone()).increment(1); + error!(error = ?e, "Failed to connect to worker"); + + return Err(e); + } + }; + + map.insert(application, worker); + } + + workers.push(map); + } + + Some(tokio::spawn(async move { push_pool.start(workers).await })) } else { None }; diff --git a/src/push/mod.rs b/src/push/mod.rs index 31c0636b..bde80358 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use async_backtrace::framed; use chrono::Utc; use elegant_departure::get_shutdown_guard; @@ -17,22 +17,21 @@ use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; use sha2::Sha256; use tokio::sync::mpsc; use tokio::task::JoinSet; -use tonic::async_trait; use tonic::metadata::MetadataValue; use tonic::transport::Channel; +use tonic::{Request, async_trait}; use tracing::{debug, error, info, warn}; use crate::config::Config; +use crate::push::thread::PushThread; use crate::store::activation::InflightActivation; use crate::store::traits::InflightActivationStore; +pub mod thread; + type HmacSha256 = Hmac; -type WorkerFactory = Arc< - dyn Fn(String) -> Pin>> + Send>> - + Send - + Sync, ->; +pub type WorkerMap = HashMap>; /// gRPC path for `WorkerService::PushTask` — keep in sync with `sentry_protos` generated client. const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; @@ -47,47 +46,56 @@ fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String hex::encode(mac.finalize().into_bytes()) } -/// Error returned when enqueueing an activation for the push workers fails. -#[derive(Debug)] -#[allow(clippy::large_enum_variant)] -pub enum PushError { - /// The bounded queue stayed full until `push_queue_timeout_ms` elapsed. - Timeout, - - /// Channel disconnected (no receivers) or another failure. - Channel(SendError<(InflightActivation, Instant)>), -} - -/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, +/// but it also decouples the actual client implementation from our pushing logic. #[async_trait] -trait WorkerClient { +trait WorkerClient: Send + Sync { /// Send a single `PushTaskRequest` to the worker service. - /// When `grpc_shared_secret` is not empty, it signs the request with `grpc_shared_secret[0]` and sets `sentry-signature` metadata (same scheme as Python pull client and broker `AuthLayer`). - async fn send(&mut self, request: PushTaskRequest, grpc_shared_secret: &[String]) - -> Result<()>; + async fn push_task(&mut self, activation: InflightActivation) -> Result<()>; +} + +/// Wrapper around worker connection that provides authentication and timeouts. +struct Worker { + /// Connection to the worker service. + client: WorkerServiceClient, + + /// List of shared secrets. + secrets: Vec, + + /// Wait this much time on push. + timeout: Duration, } #[async_trait] -impl WorkerClient for WorkerServiceClient { +impl WorkerClient for Worker { #[framed] - async fn send( - &mut self, - request: PushTaskRequest, - grpc_shared_secret: &[String], - ) -> Result<()> { - let mut req = tonic::Request::new(request); - - if let Some(secret) = grpc_shared_secret.first() { - let body = req.get_ref().encode_to_vec(); + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + // Try to decode activation + let task = + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema + let request = PushTaskRequest { + task: Some(task), + callback_url: "".into(), + }; + + // Wrap inside a Tonic request + let mut request = Request::new(request); + + // Sign if secrets are present + if let Some(secret) = self.secrets.first() { + let body = request.get_ref().encode_to_vec(); let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); let value = MetadataValue::try_from(signature.as_str()) .context("sentry-signature metadata value must be valid ASCII")?; - req.metadata_mut().insert("sentry-signature", value); + + request.metadata_mut().insert("sentry-signature", value); } - self.push_task(req) - .await - .map_err(|status| anyhow::anyhow!(status))?; + // Push with timeout + let future = self.client.push_task(request); + tokio::time::timeout(self.timeout, future).await??; Ok(()) } @@ -101,41 +109,16 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver<(InflightActivation, Instant)>, - /// Queue for batching claimed → processing updates. - update_tx: Option>, - /// Taskbroker configuration. config: Arc, /// Activation store, which we need for marking tasks as sent. store: Arc, - - worker_factory: WorkerFactory, } impl PushPool { /// Initialize a new push pool. - pub fn new( - config: Arc, - store: Arc, - update_tx: Option>, - ) -> Self { - let worker_factory: WorkerFactory = Arc::new(|endpoint: String| { - Box::pin(async move { - let client = WorkerServiceClient::connect(endpoint).await?; - Ok(Box::new(client) as Box) - }) - }); - - Self::new_with_factory(config, store, worker_factory, update_tx) - } - - fn new_with_factory( - config: Arc, - store: Arc, - worker_factory: WorkerFactory, - update_tx: Option>, - ) -> Self { + pub fn new(config: Arc, store: Arc) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); Self { @@ -143,104 +126,29 @@ impl PushPool { receiver, config, store, - worker_factory, - update_tx, } } /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. #[framed] - pub async fn start(&self) -> Result<()> { - let store = self.store.clone(); - let worker_factory = self.worker_factory.clone(); - let mut push_pool: JoinSet> = crate::tokio::spawn_pool( - self.config.push_threads, - |_| { - let worker_map = self.config.worker_map.clone(); - let receiver = self.receiver.clone(); - let store = store.clone(); - let worker_factory = worker_factory.clone(); - let update_tx = self.update_tx.clone(); - - let guard = get_shutdown_guard().shutdown_on_drop(); - - let timeout = Duration::from_millis(self.config.push_timeout_ms); - let grpc_shared_secret = self.config.grpc_shared_secret.clone(); - - async_backtrace::frame!(async move { - metrics::counter!("push.worker.connect.attempt").increment(1); - - let mut workers: HashMap> = HashMap::new(); - - for (application, endpoint) in worker_map.clone() { - let worker = match worker_factory(endpoint).await { - Ok(w) => { - metrics::counter!("push.worker.connect", "result" => "ok", "application" => application.clone()) - .increment(1); - w - } - - Err(e) => { - metrics::counter!("push.worker.connect", "result" => "error", "application" => application.clone()) - .increment(1); - error!(error = ?e, "Failed to connect to worker"); - - return Err(e); - } - }; - - workers.insert(application, worker); - } - - loop { - tokio::select! { - _ = guard.wait() => { - info!("Push worker received shutdown signal"); - break; - } - - message = receiver.recv_async() => { - let (activation, time) = match message { - // Received activation from fetch thread - Ok(a) => a, - - // Channel closed - Err(_) => break, - }; - - metrics::histogram!("push.queue.latency").record(time.elapsed()); - - push_task(store.clone(), update_tx.as_ref(), activation, &mut workers, timeout, grpc_shared_secret.as_slice()).await; - } - } - } - - // Drain channel before exiting without recording duration metrics since they don't matter at this time - for (activation, _) in receiver.drain() { - push_task( - store.clone(), - update_tx.as_ref(), - activation, - &mut workers, - timeout, - grpc_shared_secret.as_slice(), - ) - .await; - } - - Ok(()) - }) - }, - ); + pub async fn start(&self, workers: Vec) -> Result<()> { + let mut workers = workers.into_iter(); + + let mut push_pool: JoinSet> = + crate::tokio::spawn_pool(self.config.push_threads, |_| { + let mut thread = PushThread::new( + self.config.clone(), + self.store.clone(), + workers.next().unwrap(), + self.receiver.clone(), + ); + + async move { thread.start().await } + }); while let Some(result) = push_pool.join_next().await { match result { - Ok(r) => { - // Connection failed - r? - } - - // Join failed + Ok(r) => r?, Err(e) => return Err(e.into()), } } @@ -281,192 +189,5 @@ impl PushPool { } } -/// Determine which worker should receive an activation, send the activation, and update its status. -async fn push_task( - store: Arc, - update_tx: Option<&mpsc::Sender>, - activation: InflightActivation, - workers: &mut HashMap>, - timeout: Duration, - grpc_shared_secret: &[String], -) { - let id = activation.id.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - return; - }; - - match send_task( - worker.as_mut(), - activation.clone(), - timeout, - grpc_shared_secret, - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - if activation.processing_attempts < 1 { - let latency = max(0, activation.received_latency(Utc::now())); - - metrics::histogram!( - "push.received_to_push.latency", - "namespace" => activation.namespace, - "taskname" => activation.taskname, - ) - .record(latency as f64); - } else { - debug!( - task_id = %id, - namespace = activation.namespace, - taskname = activation.taskname, - "Activation already processed, skipping \ - received → push latency recording" - ); - } - - let start = Instant::now(); - - if let Some(tx) = update_tx { - let depth = tx.max_capacity() - tx.capacity(); - metrics::gauge!("push.update_queue.depth").set(depth as f64); - - let result = tx.send(id.clone()).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to enqueue push update" - ); - } - } else { - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); - } - } - } - - // Once processing deadline expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; -} - -/// Decode task activation and send it to the worker service for a particular application. -#[framed] -async fn send_task( - worker: &mut (dyn WorkerClient + Send), - activation: InflightActivation, - timeout: Duration, - grpc_shared_secret: &[String], -) -> Result<()> { - let start = Instant::now(); - metrics::counter!("push.push_task.attempts").increment(1); - - // Try to decode activation (if it fails, we will see the error where `push_task` is called) - let task = match TaskActivation::decode(&activation.activation as &[u8]) { - Ok(task) => task, - Err(err) => { - metrics::histogram!("push.push_task.duration").record(start.elapsed()); - return Err(err.into()); - } - }; - - // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema - let request = PushTaskRequest { - task: Some(task), - callback_url: "".into(), - }; - - let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await - { - Ok(r) => r, - Err(e) => Err(e.into()), - }; - - metrics::histogram!("push.push_task.duration").record(start.elapsed()); - result -} - -pub async fn flush_updates(store: Arc, buffer: &mut Vec) { - if buffer.is_empty() { - return; - } - - let start = Instant::now(); - let ids: Vec<_> = std::mem::take(buffer); - - let requested = ids.len() as u64; - metrics::histogram!("push.flush_updates.requested").record(requested as f64); - - let result = store.mark_processing_batch(&ids).await; - metrics::histogram!("push.mark_processing_batch.duration").record(start.elapsed()); - - match result { - Ok(affected) => { - metrics::histogram!("push.flush_updates.affected").record(affected as f64); - - metrics::counter!("push.flush_updates.updated").increment(affected); - metrics::counter!("push.flush_updates", "result" => "ok").increment(1); - - if affected < requested { - metrics::counter!("push.flush_updates.partial").increment(1); - - warn!( - requested, - affected, "Updated fewer rows than IDs requested from push pool" - ); - } - - debug!(affected, requested, "Flushed update batch from push pool"); - } - - Err(e) => { - metrics::counter!("push.flush_updates", "result" => "error").increment(1); - - error!( - requested, - error = ?e, - "Failed to flush update batch from push pool" - ); - - // Push failed updates back into the buffer so they can be retried on next flush - for id in ids { - buffer.push(id); - } - } - } -} - #[cfg(test)] mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs index a59bc826..99829fab 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -33,7 +33,7 @@ impl MockWorkerClient { #[async_trait] impl WorkerClient for MockWorkerClient { - async fn send( + async fn push_task( &mut self, request: PushTaskRequest, _grpc_shared_secret: &[String], @@ -54,7 +54,7 @@ struct NotifyingWorkerClient { #[async_trait] impl WorkerClient for NotifyingWorkerClient { - async fn send(&mut self, _request: PushTaskRequest, _: &[String]) -> Result<()> { + async fn push_task(&mut self, _request: PushTaskRequest, _: &[String]) -> Result<()> { self.notify.notify_one(); if self.should_fail { return Err(anyhow!("mock send failure")); diff --git a/src/push/thread.rs b/src/push/thread.rs new file mode 100644 index 00000000..ff9457d0 --- /dev/null +++ b/src/push/thread.rs @@ -0,0 +1,184 @@ +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use chrono::{DateTime, Utc}; +use elegant_departure::get_shutdown_guard; +use flume::Receiver; +use tokio::sync::{Mutex, MutexGuard}; +use tracing::{info, warn}; + +use crate::config::Config; +use crate::push::WorkerMap; +use crate::store::activation::InflightActivation; +use crate::store::traits::InflightActivationStore; + +/// Alias for documentation. +type Application = String; + +/// Alias for ergonomics. +type Submission = (InflightActivation, Instant); + +pub struct PushThread { + /// The taskbroker configuration. + config: Arc, + + /// The activation store. + store: Arc, + + /// Maps every application to its worker service. + workers: WorkerMap, + + /// Last time the buffer was flushed. + last_flush: DateTime, + + /// Sent activations that need to be updated. + buffer: Arc>>, + + /// Queue of claimed activations to be pushed. + queue: Receiver, +} + +impl PushThread { + pub fn new( + config: Arc, + store: Arc, + workers: WorkerMap, + queue: Receiver, + ) -> Self { + let buffer = Arc::new(Mutex::new(vec![])); + let last_flush = Utc::now(); + + Self { + config, + store, + workers, + last_flush, + buffer, + queue, + } + } + + pub async fn start(&mut self) -> Result<()> { + // Exit when shutdown initiated + let guard = get_shutdown_guard().shutdown_on_drop(); + + // Flush every `interval` milliseconds + let period = Duration::from_millis(self.config.push_update_interval_ms as u64); + let mut interval = tokio::time::interval(period); + + loop { + tokio::select! { + _ = guard.wait() => { + info!("Push thread received shutdown signal!"); + break; + } + + _ = interval.tick(), if self.config.batch_push_updates => { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + // Make sure we aren't flushing too soon + let now = Utc::now().timestamp_millis(); + let elapsed = self.last_flush.timestamp_millis() - now; + + if elapsed < (self.config.push_update_interval_ms as i64) { + // Too soon! + continue; + } + + // We can propagate the error upwards here if desired + self.flush(&mut buffer).await; + } + + message = self.queue.recv_async() => { + let (activation, time) = match message { + // Received activation from fetch thread + Ok(a) => a, + + // Channel closed + Err(_) => break, + }; + + metrics::histogram!("push.queue.latency").record(time.elapsed()); + + // Push the task and mark it as processing + self.push_task(activation).await; + } + } + } + + // Drain channel before exiting + let activations: Vec<_> = self.queue.drain().collect(); + + for (activation, time) in activations { + metrics::histogram!("push.queue.latency").record(time.elapsed()); + + // Push the task and mark it as processing + self.push_task(activation).await; + } + + Ok(()) + } + + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + // Store the ID for later since `push_task` claims ownership over `activation` + let id = activation.id.clone(); + + // First, determine the correct worker service + let Some(worker) = self.workers.get_mut(&activation.application) else { + // Missing application to worker mapping + return Ok(()); + }; + + // Then, push the task to that service + worker.push_task(activation).await?; + + // Finally, mark the activation as processing + self.update(id).await + } + + /// Update one activation from claimed to processing. + async fn update(&self, id: String) -> Result<()> { + if self.config.batch_push_updates { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + if buffer.len() >= self.config.push_update_batch_size { + // Flush first + self.flush(&mut buffer).await?; + } + + buffer.push(id); + Ok(()) + } else { + // We aren't batching claimed → processing updates + self.store.mark_processing(&id).await + } + } + + /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. + async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { + let ids: Vec<_> = buffer.drain(..).collect(); + let expected = ids.len() as u64; + + match self.store.mark_processing_batch(&ids).await { + Ok(actual) => { + if actual < expected { + // This may happen if tasks are reverted back to pending OR completed too quickly + warn!( + "Push thread update batch contained {expected} records, but only {actual} were updated" + ); + } + + Ok(()) + } + + Err(e) => { + // Flush failed, return IDs to buffer + buffer.extend(ids); + Err(e) + } + } + } +} diff --git a/src/tokio.rs b/src/tokio.rs index 54c593a9..50fe2ed8 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -3,9 +3,9 @@ use tokio::task::JoinSet; /// Spawns `max(n, 1)` tasks, each running the future produced by `f` with the task's index. /// Returns a [`JoinSet`] containing all spawned tasks. -pub fn spawn_pool(n: usize, f: F) -> JoinSet +pub fn spawn_pool(n: usize, mut f: F) -> JoinSet where - F: Fn(usize) -> Fut, + F: FnMut(usize) -> Fut, Fut: Future + Send + 'static, Fut::Output: Send, { From 6611714e6f56c95bf4115a7256c1398cdb8acf43 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 20 May 2026 12:29:38 -0700 Subject: [PATCH 17/19] More Refactoring (WIP) --- src/fetch/mod.rs | 52 +---- src/fetch/tests.rs | 1 - src/lib.rs | 1 + src/main.rs | 12 +- src/push/mod.rs | 166 +++++--------- src/push/tests.rs | 386 +++++++-------------------------- src/push/thread.rs | 133 ++---------- src/push/updater.rs | 139 ++++++++++++ src/store/adapters/postgres.rs | 26 +-- src/store/adapters/sqlite.rs | 25 +-- src/worker.rs | 143 ++++++++++++ src/worker/tests.rs | 104 +++++++++ 12 files changed, 565 insertions(+), 623 deletions(-) create mode 100644 src/push/updater.rs create mode 100644 src/worker.rs create mode 100644 src/worker/tests.rs diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index d898bf07..6bccce9c 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -11,7 +11,7 @@ use tonic::async_trait; use tracing::{debug, info, warn}; use crate::config::Config; -use crate::push::{PushError, PushPool}; +use crate::push::{PushPool, Pusher}; use crate::store::activation::InflightActivation; use crate::store::traits::InflightActivationStore; use crate::store::types::BucketRange; @@ -48,48 +48,24 @@ pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) (low, high) } -/// Thin interface for the push pool. It mostly serves to enable proper unit testing, -/// but it also decouples fetch logic from push logic even further. -#[async_trait] -pub trait TaskPusher { - /// Submit a single task to the push pool. - async fn submit_task( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError>; -} - -#[async_trait] -impl TaskPusher for PushPool { - #[framed] - async fn submit_task( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError> { - self.submit(activation, time).await - } -} - /// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches batches of pending activations from the store, passes them to the push pool, and repeats. -pub struct FetchPool { +pub struct FetchPool { /// Inflight activation store. store: Arc, /// Pool of push threads that push activations to the worker service. - pusher: Arc, + pusher: Arc

, /// Taskbroker configuration. config: Arc, } -impl FetchPool { +impl FetchPool

{ /// Initialize a new fetch pool. pub fn new( store: Arc, config: Arc, - pusher: Arc, + pusher: Arc

, ) -> Self { Self { store, @@ -174,10 +150,10 @@ impl FetchPool { ); } - match pusher.submit_task(activation, start).await { + match pusher.push_task(activation, start).await { Ok(()) => metrics::counter!("fetch.submit", "result" => "ok").increment(1), - Err(PushError::Timeout) => { + Err(_) => { metrics::counter!("fetch.submit", "result" => "timeout") .increment(1); @@ -190,20 +166,6 @@ impl FetchPool { // Wait for push queue to empty backoff = true; } - - Err(PushError::Channel(e)) => { - metrics::counter!("fetch.submit", "result" => "channel_error") - .increment(1); - - warn!( - task_id = %id, - error = ?e, - "Submit to push pool failed due to channel error", - ); - - // Wait before trying again - backoff = true; - } } } } diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 3e89a730..aed385c6 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -7,7 +7,6 @@ use tokio::time::{Duration, sleep}; use tonic::async_trait; use crate::config::Config; -use crate::push::PushError; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, FailedTasksForwarder}; diff --git a/src/lib.rs b/src/lib.rs index 89a17421..33fe2919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub mod store; pub mod test_utils; pub mod tokio; pub mod upkeep; +pub mod worker; /// Name of the grpc service. /// Using the service type to get a name wasn't working across modules. diff --git a/src/main.rs b/src/main.rs index de88bf38..ff8af47a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; +use taskbroker::push::updater::LazyUpdater; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; use tokio::{select, time}; @@ -325,7 +326,16 @@ async fn main() -> Result<(), Error> { workers.push(map); } - Some(tokio::spawn(async move { push_pool.start(workers).await })) + // Create the correct kind of push updater + let updater = if config.batch_push_updates { + Arc::new(LazyUpdater::new(config.clone(), store.clone())) + } else { + Arc::new(EagerUpdater::new(store.clone())) + }; + + Some(tokio::spawn(async move { + push_pool.start(workers, updater).await + })) } else { None }; diff --git a/src/push/mod.rs b/src/push/mod.rs index bde80358..65e70468 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,104 +1,50 @@ -use std::cmp::max; -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::{Context, Result, anyhow}; +use anyhow::Result; use async_backtrace::framed; -use chrono::Utc; -use elegant_departure::get_shutdown_guard; -use flume::{Receiver, SendError, Sender}; -use hmac::{Hmac, Mac}; -use prost::Message; -use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; -use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; -use sha2::Sha256; -use tokio::sync::mpsc; +use flume::{Receiver, Sender}; use tokio::task::JoinSet; -use tonic::metadata::MetadataValue; -use tonic::transport::Channel; -use tonic::{Request, async_trait}; -use tracing::{debug, error, info, warn}; +use tonic::async_trait; use crate::config::Config; use crate::push::thread::PushThread; +use crate::push::updater::Updater; use crate::store::activation::InflightActivation; use crate::store::traits::InflightActivationStore; +use crate::worker::WorkerMap; pub mod thread; +pub mod updater; -type HmacSha256 = Hmac; +// Helper to compute the longest number of milliseconds an activation may be claimed. +pub fn compute_claim_lease_ms(config: &Config) -> u64 { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; -pub type WorkerMap = HashMap>; + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; -/// gRPC path for `WorkerService::PushTask` — keep in sync with `sentry_protos` generated client. -const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms as u64 + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; -/// HMAC-SHA256(secret, grpc_path + ":" + message), hex-encoded. Matches Python `RequestSignatureInterceptor` and broker [`crate::grpc::auth_middleware`]. -fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { - let mut mac = - HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); - mac.update(grpc_path.as_bytes()); - mac.update(b":"); - mac.update(message); - hex::encode(mac.finalize().into_bytes()) -} - -/// Thin interface for the worker client. It mostly serves to enable proper unit testing, -/// but it also decouples the actual client implementation from our pushing logic. -#[async_trait] -trait WorkerClient: Send + Sync { - /// Send a single `PushTaskRequest` to the worker service. - async fn push_task(&mut self, activation: InflightActivation) -> Result<()>; -} - -/// Wrapper around worker connection that provides authentication and timeouts. -struct Worker { - /// Connection to the worker service. - client: WorkerServiceClient, - - /// List of shared secrets. - secrets: Vec, + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; - /// Wait this much time on push. - timeout: Duration, + queue_ms + send_ms + update_ms + grace_ms } +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, +/// but it also decouples fetch logic from push logic even further. #[async_trait] -impl WorkerClient for Worker { - #[framed] - async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { - // Try to decode activation - let task = - TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; - - // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema - let request = PushTaskRequest { - task: Some(task), - callback_url: "".into(), - }; - - // Wrap inside a Tonic request - let mut request = Request::new(request); - - // Sign if secrets are present - if let Some(secret) = self.secrets.first() { - let body = request.get_ref().encode_to_vec(); - let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); - let value = MetadataValue::try_from(signature.as_str()) - .context("sentry-signature metadata value must be valid ASCII")?; - - request.metadata_mut().insert("sentry-signature", value); - } - - // Push with timeout - let future = self.client.push_task(request); - tokio::time::timeout(self.timeout, future).await??; - - Ok(()) - } +pub trait Pusher { + /// Submit a single task to the push pool. + async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()>; } /// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. @@ -131,22 +77,32 @@ impl PushPool { /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. #[framed] - pub async fn start(&self, workers: Vec) -> Result<()> { + pub async fn start(&self, workers: Vec, updater: Arc) -> Result<()> { let mut workers = workers.into_iter(); - let mut push_pool: JoinSet> = - crate::tokio::spawn_pool(self.config.push_threads, |_| { - let mut thread = PushThread::new( - self.config.clone(), - self.store.clone(), - workers.next().unwrap(), - self.receiver.clone(), - ); + // Group the asynchronous tasks we spawn in this method using a `JoinSet` + let mut tasks = JoinSet::new(); + + tasks.spawn({ + let updater = updater.clone(); + async move { updater.start().await } + }); + + for _ in 0..self.config.push_threads { + tasks.spawn({ + let mut thread = PushThread { + config: self.config.clone(), + store: self.store.clone(), + workers: workers.next().unwrap(), + receiver: self.receiver.clone(), + updater: updater.clone(), + }; async move { thread.start().await } }); + } - while let Some(result) = push_pool.join_next().await { + while let Some(result) = tasks.join_next().await { match result { Ok(r) => r?, Err(e) => return Err(e.into()), @@ -155,35 +111,33 @@ impl PushPool { Ok(()) } +} - /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. Times out after `config.push_queue_timeout_ms` milliseconds. +#[async_trait] +impl Pusher for PushPool { #[framed] - pub async fn submit( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError> { + async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()> { let duration = Duration::from_millis(self.config.push_queue_timeout_ms); let start = Instant::now(); metrics::gauge!("push.queue.depth").set(self.sender.len() as f64); match tokio::time::timeout(duration, self.sender.send_async((activation, time))).await { - Ok(Ok(())) => { - metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Ok(()) + // The channel has closed because all receivers were dropped + Ok(Err(e)) => { + // The only way this can happen is if the push threads have already exited, which is an invalid state + unreachable!("{}", e); } - // The channel has a problem - Ok(Err(e)) => { + // The channel was full so the pushing timed out + Err(e) => { metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Err(PushError::Channel(e)) + Err(e.into()) } - // The channel was full so the send timed out - Err(_elapsed) => { + Ok(_) => { metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Err(PushError::Timeout) + Ok(()) } } } diff --git a/src/push/tests.rs b/src/push/tests.rs index 99829fab..e8a7fa56 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,11 +1,15 @@ +use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::Instant; -use anyhow::anyhow; +use anyhow::{Result, anyhow}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sentry_protos::taskbroker::v1::PushTaskRequest; -use tokio::sync::{Notify, mpsc}; +use hmac::{Hmac, Mac}; +use prost::Message; +use sentry_protos::taskbroker::v1::TaskActivation; +use sha2::Sha256; +use tokio::sync::Notify; use tokio::time::{Duration, timeout}; use crate::config::Config; @@ -13,73 +17,84 @@ use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::FailedTasksForwarder; use crate::test_utils::{create_test_store, make_activations}; +use crate::worker::{WorkerClient, WorkerMap}; -use super::*; +use super::PushPool; -/// Fake worker client that records requests and optionally fails. +/// Fake worker client that records pushed activation IDs and optionally fails. struct MockWorkerClient { - captured_requests: Vec, - should_fail: bool, + /// Pushed activation IDs. + pushed: Vec, + + /// Should `push_task` fail? + fail: bool, } impl MockWorkerClient { - fn new(should_fail: bool) -> Self { + fn new(fail: bool) -> Self { Self { - captured_requests: vec![], - should_fail, + pushed: vec![], + fail, } } } #[async_trait] impl WorkerClient for MockWorkerClient { - async fn push_task( - &mut self, - request: PushTaskRequest, - _grpc_shared_secret: &[String], - ) -> Result<()> { - self.captured_requests.push(request); - if self.should_fail { + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + self.pushed.push(activation.id); + + if self.fail { return Err(anyhow!("mock send failure")); } + Ok(()) } } -/// Fake worker client that fires a Notify when send() is called. +/// Fake worker client that fires a `Notify` when `push_task` is called. struct NotifyingWorkerClient { - should_fail: bool, + /// Fire off notification when `push_task` is called notify: Arc, + + /// Should `push_task` fail? + fail: bool, } #[async_trait] impl WorkerClient for NotifyingWorkerClient { - async fn push_task(&mut self, _request: PushTaskRequest, _: &[String]) -> Result<()> { + async fn push_task(&mut self, _activation: InflightActivation) -> Result<()> { self.notify.notify_one(); - if self.should_fail { + + if self.fail { return Err(anyhow!("mock send failure")); } + Ok(()) } } +/// Create a map of notifying worker clients for tests. +fn test_worker_map(fail: bool, notify: Arc) -> WorkerMap { + let mut workers = HashMap::new(); + + let client = NotifyingWorkerClient { fail, notify }; + + workers.insert("sentry".into(), Box::new(client) as Box); + workers +} + /// Minimal fake store that records which activation IDs have been marked as processing. -/// All IDs marked via either `mark_processing` or successful `mark_processing_batch`. #[derive(Clone)] struct MockStore { marked_processing: Arc>>, - mark_processing_calls: Arc>>, - mark_processing_batches: Arc>>>, - mark_processing_batch_should_fail: Arc>, } impl Default for MockStore { fn default() -> Self { Self { marked_processing: Arc::new(Mutex::new(vec![])), - mark_processing_calls: Arc::new(Mutex::new(vec![])), - mark_processing_batches: Arc::new(Mutex::new(vec![])), - mark_processing_batch_should_fail: Arc::new(Mutex::new(false)), } } } @@ -88,18 +103,6 @@ impl MockStore { fn marked_ids(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } - - fn mark_processing_direct_calls(&self) -> Vec { - self.mark_processing_calls.lock().unwrap().clone() - } - - fn mark_processing_batch_calls(&self) -> Vec> { - self.mark_processing_batches.lock().unwrap().clone() - } - - fn set_mark_processing_batch_fail(&self, fail: bool) { - *self.mark_processing_batch_should_fail.lock().unwrap() = fail; - } } #[async_trait] @@ -124,24 +127,11 @@ impl InflightActivationStore for MockStore { } async fn mark_processing(&self, id: &str) -> anyhow::Result<()> { - self.mark_processing_calls - .lock() - .unwrap() - .push(id.to_string()); self.marked_processing.lock().unwrap().push(id.to_string()); Ok(()) } async fn mark_processing_batch(&self, ids: &[String]) -> anyhow::Result { - if *self.mark_processing_batch_should_fail.lock().unwrap() { - return Err(anyhow!("mock mark_processing_batch failure")); - } - - self.mark_processing_batches - .lock() - .unwrap() - .push(ids.to_vec()); - let mut guard = self.marked_processing.lock().unwrap(); for id in ids { @@ -255,155 +245,65 @@ impl InflightActivationStore for MockStore { } } -/// Factory that fires `notify` when send() is called, then succeeds or fails per `should_fail`. -fn notifying_factory(should_fail: bool, notify: Arc) -> WorkerFactory { - Arc::new(move |_: String| { - let notify = notify.clone(); - Box::pin(async move { - Ok(Box::new(NotifyingWorkerClient { - should_fail, - notify, - }) as Box) - }) - }) -} - -/// Factory that always fails to connect (simulates a broken endpoint). -fn failing_connect_factory() -> WorkerFactory { - Arc::new(|_: String| Box::pin(async { Err(anyhow::anyhow!("simulated connect failure")) })) -} - -#[tokio::test] -async fn push_task_returns_ok_on_client_success() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(false); - - let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; - assert!(result.is_ok(), "push_task should succeed"); - assert_eq!(worker.captured_requests.len(), 1); - - let request = &worker.captured_requests[0]; - assert_eq!( - request.task.as_ref().map(|task| task.id.as_str()), - Some(activation.id.as_str()) - ); -} +// --- PushPool tests --- #[tokio::test] -async fn push_task_returns_err_on_invalid_payload() { - let mut activation = make_activations(1).remove(0); - activation.activation = vec![1, 2, 3, 4]; - - let mut worker = MockWorkerClient::new(false); - let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; - - assert!(result.is_err(), "invalid payload should fail decoding"); - assert!( - worker.captured_requests.is_empty(), - "worker should not be called if decode fails" - ); -} - -#[tokio::test] -async fn push_task_propagates_client_error() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(true); - - let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; - assert!(result.is_err(), "worker send errors should propagate"); - assert_eq!(worker.captured_requests.len(), 1); -} - -#[tokio::test] -async fn push_pool_submit_enqueues_item() { +async fn push_pool_push_task_enqueues_item() { let config = Arc::new(Config { push_queue_size: 2, ..Config::default() }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store, None); + let pool = PushPool::new(config, store); let activation = make_activations(1).remove(0); let time = Instant::now(); - let result = pool.submit(activation, time).await; - assert!(result.is_ok(), "submit should enqueue activation"); + let result = pool.push_task(activation, time).await; + assert!(result.is_ok(), "push_task should enqueue activation"); } #[tokio::test] -async fn push_pool_submit_backpressures_when_queue_full() { +async fn push_pool_push_task_backpressures_when_queue_full() { let config = Arc::new(Config { push_queue_size: 1, ..Config::default() }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store, None); + let pool = PushPool::new(config, store); let time = Instant::now(); let first = make_activations(1).remove(0); let second = make_activations(1).remove(0); - pool.submit(first, time) + pool.push_task(first, time) .await - .expect("first submit should fill queue"); - - let second_submit = timeout(Duration::from_millis(50), pool.submit(second, time)).await; - assert!( - second_submit.is_err(), - "second submit should block when queue is full" - ); -} - -#[test] -fn sentry_signature_hex_matches_hmac_contract() { - let digest = sentry_signature_hex("super secret", "/test/path", b"hello"); - assert_eq!( - digest, - "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" - ); -} - -/// When the worker factory fails to connect, start() returns an error immediately. -#[tokio::test] -async fn push_pool_start_worker_connect_failure_returns_error() { - let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), - push_threads: 1, - push_queue_size: 10, - ..Config::default() - }); - let store = Arc::new(MockStore::default()); - let pool = PushPool::new_with_factory(config, store, failing_connect_factory(), None); + .expect("first push_task should fill queue"); - let result = pool.start().await; + let second_push = timeout(Duration::from_millis(50), pool.push_task(second, time)).await; assert!( - result.is_err(), - "start() should return Err when the worker factory fails to connect" + second_push.is_err(), + "second push_task should time out when queue is full" ); } -/// After a successful push for a first-attempt activation (processing_attempts == 0), -/// mark_processing must be called on the store. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_first_attempt() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - None, - )); + let pool = Arc::new(PushPool::new(config, store.clone())); + let workers = vec![test_worker_map(false, notify.clone())]; let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start.start(workers).await.expect("push pool start"); + }); let activation = make_activations(1).remove(0); assert_eq!(activation.processing_attempts, 0); @@ -411,11 +311,10 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { let id = activation.id.clone(); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); - // Wait for the worker to call send(), then give it time to call mark_processing timeout(Duration::from_secs(2), notify.notified()) .await .expect("timed out waiting for push to be delivered"); @@ -428,27 +327,22 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { ); } -/// After a successful push for a retried activation (processing_attempts > 0), -/// mark_processing must be called and latency recording is skipped. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_retry() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - None, - )); + let pool = Arc::new(PushPool::new(config, store.clone())); + let workers = vec![test_worker_map(false, notify.clone())]; let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start.start(workers).await.expect("push pool start"); + }); let mut activation = make_activations(1).remove(0); activation.processing_attempts = 1; @@ -456,9 +350,9 @@ async fn push_pool_start_marks_activation_processing_on_retry() { let id = activation.id.clone(); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); timeout(Duration::from_secs(2), notify.notified()) .await @@ -472,33 +366,29 @@ async fn push_pool_start_marks_activation_processing_on_retry() { ); } -/// When the worker fails to deliver an activation, mark_processing must NOT be called. #[tokio::test] async fn push_pool_start_does_not_mark_processing_on_push_failure() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(true, notify.clone()), - None, - )); + let pool = Arc::new(PushPool::new(config, store.clone())); + let workers = vec![test_worker_map(true, notify.clone())]; let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start.start(workers).await.expect("push pool start"); + }); let activation = make_activations(1).remove(0); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); timeout(Duration::from_secs(2), notify.notified()) .await @@ -510,127 +400,3 @@ async fn push_pool_start_does_not_mark_processing_on_push_failure() { "mark_processing should not be called when push fails" ); } - -/// With `update_tx` set, a successful push enqueues the task ID on the channel. -#[tokio::test] -async fn push_pool_forwards_successful_push_to_update_channel() { - let notify = Arc::new(Notify::new()); - let (update_tx, mut update_rx) = mpsc::channel::(8); - - let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), - push_threads: 1, - push_queue_size: 10, - ..Config::default() - }); - let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - Some(update_tx), - )); - - let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); - - let activation = make_activations(1).remove(0); - let id = activation.id.clone(); - let time = Instant::now(); - - pool.submit(activation, time) - .await - .expect("Submit should succeed"); - - timeout(Duration::from_secs(2), notify.notified()) - .await - .expect("Timed out waiting for push to be delivered"); - tokio::time::sleep(Duration::from_millis(50)).await; - - assert!( - store.mark_processing_batch_calls().is_empty(), - "Method `mark_processing_batch` runs only via `flush_updates`, not the push worker" - ); - - let ch_id = update_rx - .recv() - .await - .expect("Task ID should be sent on update channel"); - assert_eq!(ch_id, id); -} - -/// Function `flush_updates` drains the buffer into `mark_processing_batch` and clears the buffer. -#[tokio::test] -async fn flush_updates_applies_batch_and_clears_buffer() { - let store = Arc::new(MockStore::default()); - let mut buf = vec!["id_0".to_string()]; - - flush_updates(store.clone(), &mut buf).await; - - assert!( - buf.is_empty(), - "buffer should be cleared after successful flush" - ); - assert!(store.mark_processing_direct_calls().is_empty()); - assert_eq!( - store.mark_processing_batch_calls(), - vec![vec!["id_0".to_string()]] - ); - assert_eq!(store.marked_ids(), vec!["id_0".to_string()]); -} - -/// On `mark_processing_batch` error, `flush_updates` restores IDs into the buffer for retry. -#[tokio::test] -async fn flush_updates_restores_buffer_on_batch_error() { - let store = Arc::new(MockStore::default()); - store.set_mark_processing_batch_fail(true); - - let mut buf = vec!["a".to_string(), "b".to_string()]; - flush_updates(store.clone(), &mut buf).await; - - assert_eq!(buf, vec!["a".to_string(), "b".to_string()]); - assert!(store.mark_processing_batch_calls().is_empty()); - assert!(store.marked_ids().is_empty()); -} - -/// After a successful worker push, a closed `update_tx` receiver means neither the main loop nor -/// shutdown drain can enqueue the ID. -#[tokio::test] -async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { - let notify = Arc::new(Notify::new()); - let (update_tx, update_rx) = mpsc::channel::(8); - drop(update_rx); - - let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), - push_threads: 1, - push_queue_size: 10, - ..Config::default() - }); - let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - Some(update_tx), - )); - - let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); - - let activation = make_activations(1).remove(0); - let time = Instant::now(); - - pool.submit(activation, time) - .await - .expect("Submit should succeed"); - - timeout(Duration::from_secs(2), notify.notified()) - .await - .expect("Timed out waiting for push to be delivered"); - tokio::time::sleep(Duration::from_millis(50)).await; - - assert!(store.mark_processing_batch_calls().is_empty()); - assert!(store.mark_processing_direct_calls().is_empty()); - assert!(store.marked_ids().is_empty()); -} diff --git a/src/push/thread.rs b/src/push/thread.rs index ff9457d0..b517375e 100644 --- a/src/push/thread.rs +++ b/src/push/thread.rs @@ -1,102 +1,50 @@ use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use anyhow::Result; -use chrono::{DateTime, Utc}; use elegant_departure::get_shutdown_guard; use flume::Receiver; -use tokio::sync::{Mutex, MutexGuard}; -use tracing::{info, warn}; use crate::config::Config; -use crate::push::WorkerMap; +use crate::push::updater::Updater; use crate::store::activation::InflightActivation; use crate::store::traits::InflightActivationStore; - -/// Alias for documentation. -type Application = String; +use crate::worker::WorkerMap; /// Alias for ergonomics. -type Submission = (InflightActivation, Instant); +pub type Submission = (InflightActivation, Instant); +/// Abstraction for a single push thread. pub struct PushThread { /// The taskbroker configuration. - config: Arc, + pub(super) config: Arc, /// The activation store. - store: Arc, + pub(super) store: Arc, /// Maps every application to its worker service. - workers: WorkerMap, - - /// Last time the buffer was flushed. - last_flush: DateTime, + pub(super) workers: WorkerMap, - /// Sent activations that need to be updated. - buffer: Arc>>, + /// Channel containing claimed activations to be pushed. + pub(super) receiver: Receiver, - /// Queue of claimed activations to be pushed. - queue: Receiver, + /// Entity that marks tasks as processing. + pub(super) updater: Arc, } impl PushThread { - pub fn new( - config: Arc, - store: Arc, - workers: WorkerMap, - queue: Receiver, - ) -> Self { - let buffer = Arc::new(Mutex::new(vec![])); - let last_flush = Utc::now(); - - Self { - config, - store, - workers, - last_flush, - buffer, - queue, - } - } - pub async fn start(&mut self) -> Result<()> { - // Exit when shutdown initiated - let guard = get_shutdown_guard().shutdown_on_drop(); - - // Flush every `interval` milliseconds - let period = Duration::from_millis(self.config.push_update_interval_ms as u64); - let mut interval = tokio::time::interval(period); + let guard = get_shutdown_guard(); loop { + // We cannot exit before every fetch thread has exited, so don't exit on `guard.wait()` here tokio::select! { - _ = guard.wait() => { - info!("Push thread received shutdown signal!"); - break; - } - - _ = interval.tick(), if self.config.batch_push_updates => { - // Lock the ID buffer - let mut buffer = self.buffer.lock().await; - - // Make sure we aren't flushing too soon - let now = Utc::now().timestamp_millis(); - let elapsed = self.last_flush.timestamp_millis() - now; - - if elapsed < (self.config.push_update_interval_ms as i64) { - // Too soon! - continue; - } - - // We can propagate the error upwards here if desired - self.flush(&mut buffer).await; - } - - message = self.queue.recv_async() => { + message = self.receiver.recv_async() => { let (activation, time) = match message { // Received activation from fetch thread Ok(a) => a, - // Channel closed + // We only exit when the push queue is closed, which happens when the fetch pool has shut down Err(_) => break, }; @@ -109,7 +57,7 @@ impl PushThread { } // Drain channel before exiting - let activations: Vec<_> = self.queue.drain().collect(); + let activations: Vec<_> = self.receiver.drain().collect(); for (activation, time) in activations { metrics::histogram!("push.queue.latency").record(time.elapsed()); @@ -118,6 +66,7 @@ impl PushThread { self.push_task(activation).await; } + drop(guard); Ok(()) } @@ -135,50 +84,6 @@ impl PushThread { worker.push_task(activation).await?; // Finally, mark the activation as processing - self.update(id).await - } - - /// Update one activation from claimed to processing. - async fn update(&self, id: String) -> Result<()> { - if self.config.batch_push_updates { - // Lock the ID buffer - let mut buffer = self.buffer.lock().await; - - if buffer.len() >= self.config.push_update_batch_size { - // Flush first - self.flush(&mut buffer).await?; - } - - buffer.push(id); - Ok(()) - } else { - // We aren't batching claimed → processing updates - self.store.mark_processing(&id).await - } - } - - /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. - async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { - let ids: Vec<_> = buffer.drain(..).collect(); - let expected = ids.len() as u64; - - match self.store.mark_processing_batch(&ids).await { - Ok(actual) => { - if actual < expected { - // This may happen if tasks are reverted back to pending OR completed too quickly - warn!( - "Push thread update batch contained {expected} records, but only {actual} were updated" - ); - } - - Ok(()) - } - - Err(e) => { - // Flush failed, return IDs to buffer - buffer.extend(ids); - Err(e) - } - } + self.updater.update(id).await } } diff --git a/src/push/updater.rs b/src/push/updater.rs new file mode 100644 index 00000000..159e6bc5 --- /dev/null +++ b/src/push/updater.rs @@ -0,0 +1,139 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use chrono::{DateTime, Utc}; +use elegant_departure::get_shutdown_guard; +use tokio::sync::{Mutex, MutexGuard}; +use tonic::async_trait; +use tracing::{info, warn}; + +use crate::{config::Config, store::traits::InflightActivationStore}; + +#[async_trait] +pub trait Updater: Send + Sync { + async fn start(&self) -> Result<()>; + + async fn update(&self, id: String) -> Result<()>; +} + +pub struct LazyUpdater { + /// The taskbroker configuration. + config: Arc, + + /// The activation store. + store: Arc, + + /// Last time the buffer was flushed. + last_flush: DateTime, + + /// Sent activations that need to be updated. + buffer: Arc>>, +} + +#[async_trait] +impl Updater for LazyUpdater { + async fn start(&self) -> Result<()> { + let guard = get_shutdown_guard(); + + // Flush every `interval` milliseconds + let period = Duration::from_millis(self.config.push_update_interval_ms as u64); + let mut interval = tokio::time::interval(period); + + loop { + tokio::select! { + _ = guard.wait() => { + info!("Shutting down batched updater..."); + break; + } + + _ = interval.tick(), if self.config.batch_push_updates => { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + // Make sure we aren't flushing too soon + let now = Utc::now().timestamp_millis(); + let elapsed = self.last_flush.timestamp_millis() - now; + + if elapsed < (self.config.push_update_interval_ms as i64) { + // Too soon! + continue; + } + + // We can propagate the error upwards here if desired + self.flush(&mut buffer).await; + } + } + } + + Ok(()) + } + + async fn update(&self, id: String) -> Result<()> { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + if buffer.len() >= self.config.push_update_batch_size { + // Flush first + self.flush(&mut buffer).await?; + } + + buffer.push(id); + Ok(()) + } +} + +impl LazyUpdater { + pub fn new(config: Arc, store: Arc) -> Self { + let buffer = Arc::new(Mutex::new(vec![])); + let last_flush = Utc::now(); + + Self { + config, + store, + buffer, + last_flush, + } + } + + /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. + async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { + let ids: Vec<_> = buffer.drain(..).collect(); + let expected = ids.len() as u64; + + match self.store.mark_processing_batch(&ids).await { + Ok(actual) => { + if actual < expected { + // This may happen if tasks are reverted back to pending OR completed too quickly + warn!( + "Push thread update batch contained {expected} records, but only {actual} were updated" + ); + } + + Ok(()) + } + + Err(e) => { + // Flush failed, return IDs to buffer + buffer.extend(ids); + Err(e) + } + } + } +} + +pub struct EagerUpdater { + /// The activation store. + store: Arc, +} + +#[async_trait] +impl Updater for EagerUpdater { + async fn start(&self) -> Result<()> { + // There is nothing to do in the background, so we can return immediately + todo!() + } + + async fn update(&self, id: String) -> Result<()> { + self.store.mark_processing(&id).await + } +} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 6210a914..08e3bfed 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -15,6 +15,7 @@ use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use tracing::{instrument, warn}; use crate::config::Config; +use crate::push::compute_claim_lease_ms; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, DepthCounts, FailedTasksForwarder}; @@ -147,6 +148,7 @@ impl PostgresActivationStoreConfig { .password(&config.pg_password) .host(&config.pg_host) .port(config.pg_port); + if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { let url = conn_opts.to_url_lossy(); let new_url = @@ -154,28 +156,6 @@ impl PostgresActivationStoreConfig { conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); } - // Compute the longest amount of time an activation may be claimed - let claim_lease_ms = { - // In the worst case, every activation in the batch will time out when appending to the push queue - let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; - - // In the worst case, every activation in the push queue will time out when sending - let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; - - let update_ms = if config.batch_push_updates { - // In the worst case, we will need to wait an entire interval before flushing a batch of push updates - config.push_update_interval_ms - } else { - // Grace seconds will cover the update query duration until we decide to implement query timeouts - 0 - }; - - // Account for grace seconds specified in configuration - let grace_ms = config.claim_expiration_grace_sec * 1000; - - queue_ms + send_ms + update_ms + grace_ms - }; - Self { pg_connection: conn_opts, pg_database_name: config.pg_database_name.clone(), @@ -185,7 +165,7 @@ impl PostgresActivationStoreConfig { vacuum_page_count: config.vacuum_page_count, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms, + claim_lease_ms: compute_claim_lease_ms(config), } } } diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index d68c6df6..f558e9c6 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -24,6 +24,7 @@ use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use tracing::{instrument, warn}; use crate::config::Config; +use crate::push::compute_claim_lease_ms; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, FailedTasksForwarder}; @@ -146,34 +147,12 @@ pub struct InflightActivationStoreConfig { impl InflightActivationStoreConfig { pub fn from_config(config: &Config) -> Self { - // Compute the longest amount of time an activation may be claimed - let claim_lease_ms = { - // In the worst case, every activation in the batch will time out when appending to the push queue - let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; - - // In the worst case, every activation in the push queue will time out when sending - let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; - - let update_ms = if config.batch_push_updates { - // In the worst case, we will need to wait an entire interval before flushing a batch of push updates - config.push_update_interval_ms - } else { - // Grace seconds will cover the update query duration until we decide to implement query timeouts - 0 - }; - - // Account for grace seconds specified in configuration - let grace_ms = config.claim_expiration_grace_sec * 1000; - - queue_ms + send_ms + update_ms + grace_ms - }; - Self { max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, processing_deadline_grace_sec: config.processing_deadline_grace_sec, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, - claim_lease_ms, + claim_lease_ms: compute_claim_lease_ms(config), } } } diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 00000000..7f60f0d1 --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,143 @@ +use std::{collections::HashMap, time::Duration}; + +use anyhow::{Context, Result, anyhow}; +use async_backtrace::framed; +use hmac::{Hmac, Mac}; +use prost::Message; +use sentry_protos::taskbroker::v1::{ + PushTaskRequest, TaskActivation, worker_service_client::WorkerServiceClient, +}; +use sha2::Sha256; +use tonic::metadata::MetadataValue; +use tonic::transport::Channel; +use tonic::{Request, async_trait}; + +use crate::store::activation::InflightActivation; + +// Alias for ergonomics. +pub type WorkerMap = HashMap>; + +/// gRPC path for `WorkerService::PushTask` that should be kept in sync with `sentry_protos` generated client. +pub const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; + +/// HMAC-SHA256(secret, grpc_path + ":" + message), hex-encoded. Matches Python `RequestSignatureInterceptor` and broker [`crate::grpc::auth_middleware`]. +fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { + let mut mac = + Hmac::::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); + mac.update(grpc_path.as_bytes()); + mac.update(b":"); + mac.update(message); + hex::encode(mac.finalize().into_bytes()) +} + +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, +/// but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +pub trait WorkerClient: Send + Sync { + /// Send a single activation to the worker service. + async fn push_task(&mut self, activation: InflightActivation) -> Result<()>; +} + +/// Wrapper around worker connection that provides authentication and timeouts. +pub struct Worker { + /// Connection to the worker service. + client: WorkerServiceClient, + + /// List of secrets shared with the worker. + secrets: Vec, + + /// Wait this much time before giving up on a push. + timeout: Duration, +} + +#[async_trait] +impl WorkerClient for Worker { + #[framed] + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + // Try to decode activation + let task = + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema + let request = PushTaskRequest { + task: Some(task), + callback_url: "".into(), + }; + + // Wrap inside a Tonic request + let mut request = Request::new(request); + + // Sign if secrets are present + if let Some(secret) = self.secrets.first() { + let body = request.get_ref().encode_to_vec(); + let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); + let value = MetadataValue::try_from(signature.as_str()) + .context("sentry-signature metadata value must be valid ASCII")?; + + request.metadata_mut().insert("sentry-signature", value); + } + + // Push with timeout + let future = self.client.push_task(request); + tokio::time::timeout(self.timeout, future).await??; + + Ok(()) + } +} + +mod tests { + use hmac::Hmac; + use sha2::Sha256; + + use crate::{test_utils::make_activations, worker::WORKER_PUSH_TASK_PATH}; + + #[tokio::test] + async fn worker_push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.pushed, vec![activation.id]); + } + + #[tokio::test] + async fn worker_push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = worker.push_task(activation).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.pushed.is_empty(), + "worker should not record a push if decode fails" + ); + } + + #[tokio::test] + async fn worker_push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_err(), "worker push errors should propagate"); + assert_eq!(worker.pushed, vec![activation.id]); + } + + #[test] + fn worker_sentry_signature_hex_matches_hmac_contract() { + let mut mac = Hmac::::new_from_slice(b"super secret") + .expect("HMAC accepts keys of any length"); + mac.update(WORKER_PUSH_TASK_PATH.as_bytes()); + mac.update(b":"); + mac.update(b"hello"); + let digest = hex::encode(mac.finalize().into_bytes()); + + assert_eq!( + digest, + "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" + ); + } +} diff --git a/src/worker/tests.rs b/src/worker/tests.rs new file mode 100644 index 00000000..edd7d872 --- /dev/null +++ b/src/worker/tests.rs @@ -0,0 +1,104 @@ +use std::time::Duration; + +use anyhow::anyhow; +use async_trait::async_trait; +use prost::Message; +use sentry_protos::taskbroker::v1::TaskActivation; + +use crate::store::activation::InflightActivation; +use crate::test_utils::make_activations; +use crate::worker::{Worker, WorkerClient, sentry_signature_hex}; + +/// Fake worker client that records task IDs and optionally fails. +struct MockWorkerClient { + captured_task_ids: Vec, + should_fail: bool, +} + +impl MockWorkerClient { + fn new(should_fail: bool) -> Self { + Self { + captured_task_ids: vec![], + should_fail, + } + } +} + +#[async_trait] +impl WorkerClient for MockWorkerClient { + async fn push_task(&mut self, activation: InflightActivation) -> anyhow::Result<()> { + let task = + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + self.captured_task_ids.push(task.id); + + if self.should_fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +#[tokio::test] +async fn push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.captured_task_ids, vec![activation.id]); +} + +#[tokio::test] +async fn push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = worker.push_task(activation).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.captured_task_ids.is_empty(), + "worker should not record a task id if decode fails" + ); +} + +#[tokio::test] +async fn push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_err(), "worker send errors should propagate"); + assert_eq!(worker.captured_task_ids, vec![activation.id]); +} + +#[test] +fn sentry_signature_hex_matches_hmac_contract() { + let digest = sentry_signature_hex("super secret", "/test/path", b"hello"); + assert_eq!( + digest, + "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" + ); +} + +#[tokio::test] +async fn worker_connect_fails_for_unreachable_endpoint() { + let result = Worker::connect("http://127.0.0.1:1".into()).await; + assert!( + result.is_err(), + "connect should return Err for an unreachable endpoint" + ); +} + +#[tokio::test] +async fn worker_connect_with_options_fails_for_unreachable_endpoint() { + let result = + Worker::connect_with_options("http://127.0.0.1:1".into(), vec![], Duration::from_secs(1)) + .await; + assert!( + result.is_err(), + "connect_with_options should return Err for an unreachable endpoint" + ); +} From addc616085febf334caf5d3b615936d67c095eb8 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 20 May 2026 16:08:57 -0700 Subject: [PATCH 18/19] More Work --- src/fetch/mod.rs | 4 +- src/fetch/tests.rs | 10 +--- src/main.rs | 52 ++++------------- src/push/mod.rs | 44 ++++++++------- src/push/tests.rs | 126 ++++++++++++----------------------------- src/push/thread.rs | 22 ++++---- src/push/updater.rs | 135 ++++++++++++++++++++++++++++---------------- src/worker.rs | 108 ++++++++++++++++++++++++++++++++--- 8 files changed, 272 insertions(+), 229 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 6bccce9c..eeedaf2c 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -7,12 +7,10 @@ use async_backtrace::framed; use chrono::Utc; use elegant_departure::get_shutdown_guard; use tokio::time::sleep; -use tonic::async_trait; use tracing::{debug, info, warn}; use crate::config::Config; -use crate::push::{PushPool, Pusher}; -use crate::store::activation::InflightActivation; +use crate::push::Pusher; use crate::store::traits::InflightActivationStore; use crate::store::types::BucketRange; diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index aed385c6..444f6bfa 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -207,16 +207,12 @@ impl RecordingPusher { } #[async_trait] -impl TaskPusher for RecordingPusher { - async fn submit_task( - &self, - activation: InflightActivation, - _time: Instant, - ) -> Result<(), PushError> { +impl Pusher for RecordingPusher { + async fn push_task(&self, activation: InflightActivation, _time: Instant) -> Result<()> { self.pushed_ids.lock().await.push(activation.id.clone()); if self.fail { - return Err(PushError::Timeout); + return Err(anyhow!("timeout")); } Ok(()) diff --git a/src/main.rs b/src/main.rs index ff8af47a..fd6d220e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,6 @@ use anyhow::{Error, anyhow}; use chrono::Utc; use clap::Parser; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer; -use taskbroker::push::updater::LazyUpdater; use tokio::signal::unix::SignalKind; use tokio::task::JoinHandle; use tokio::{select, time}; @@ -32,6 +31,7 @@ use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; use taskbroker::metrics; use taskbroker::processing_strategy; use taskbroker::push::PushPool; +use taskbroker::push::updater::{EagerUpdater, LazyUpdater, Updater}; use taskbroker::runtime_config::RuntimeConfigManager; use taskbroker::store::adapters::postgres::{ PostgresActivationStore, PostgresActivationStoreConfig, @@ -39,9 +39,10 @@ use taskbroker::store::adapters::postgres::{ use taskbroker::store::adapters::sqlite::{InflightActivationStoreConfig, SqliteActivationStore}; use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; +use taskbroker::worker::{Worker, WorkerClient, WorkerMap}; use taskbroker::{Args, get_version}; use taskbroker::{SERVICE_NAME, flusher}; -use taskbroker::{grpc, logging, push}; +use taskbroker::{grpc, logging}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -268,31 +269,8 @@ async fn main() -> Result<(), Error> { } }); - // Push update flush task - let (push_update_tx, push_update_task) = - if config.batch_push_updates && config.delivery_mode.is_push() { - let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); - - let flusher_store = store.clone(); - let flusher_config = config.clone(); - - let handle = tokio::spawn(async move { - flusher::run_flusher( - rx, - flusher_config.push_update_batch_size, - flusher_config.push_update_interval_ms as u64, - move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), - ) - .await - }); - - (Some(tx), Some(handle)) - } else { - (None, None) - }; - // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); + let push_pool = Arc::new(PushPool::new(config.clone())); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads @@ -300,22 +278,18 @@ async fn main() -> Result<(), Error> { let mut workers: Vec = vec![]; // For every push thread, create a map from applications to worker connections - for i in config.push_threads { - let map = HashMap::new(); + for _ in 0..config.push_threads { + let mut map = HashMap::new(); for (application, endpoint) in config.worker_map.clone() { - let worker = match Worker::connect(endpoint).await { + let worker = match Worker::connect(config.clone(), endpoint).await { Ok(w) => { - metrics::counter!("worker.connect", "result" => "ok", "application" => application.clone()).increment(1); debug!("Connected to worker!"); - - w + Box::new(w) as Box } Err(e) => { - metrics::counter!("worker.connect", "result" => "error", "application" => application.clone()).increment(1); error!(error = ?e, "Failed to connect to worker"); - return Err(e); } }; @@ -328,9 +302,11 @@ async fn main() -> Result<(), Error> { // Create the correct kind of push updater let updater = if config.batch_push_updates { - Arc::new(LazyUpdater::new(config.clone(), store.clone())) + let lazy = LazyUpdater::new(config.clone(), store.clone()); + Arc::new(lazy) as Arc } else { - Arc::new(EagerUpdater::new(store.clone())) + let eager = EagerUpdater::new(store.clone()); + Arc::new(eager) as Arc }; Some(tokio::spawn(async move { @@ -369,10 +345,6 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("status_update_task", task)); } - if let Some(task) = push_update_task { - departure = departure.on_completion(log_task_completion("push_update_task", task)); - } - departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index 65e70468..c196684a 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -11,7 +11,6 @@ use crate::config::Config; use crate::push::thread::PushThread; use crate::push::updater::Updater; use crate::store::activation::InflightActivation; -use crate::store::traits::InflightActivationStore; use crate::worker::WorkerMap; pub mod thread; @@ -57,21 +56,17 @@ pub struct PushPool { /// Taskbroker configuration. config: Arc, - - /// Activation store, which we need for marking tasks as sent. - store: Arc, } impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc, store: Arc) -> Self { + pub fn new(config: Arc) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); Self { sender, receiver, config, - store, } } @@ -80,19 +75,15 @@ impl PushPool { pub async fn start(&self, workers: Vec, updater: Arc) -> Result<()> { let mut workers = workers.into_iter(); - // Group the asynchronous tasks we spawn in this method using a `JoinSet` - let mut tasks = JoinSet::new(); - - tasks.spawn({ + // Start the updater + let updaterd = tokio::spawn({ let updater = updater.clone(); async move { updater.start().await } }); - for _ in 0..self.config.push_threads { - tasks.spawn({ + let mut threads: JoinSet> = + crate::tokio::spawn_pool(self.config.push_threads, |_| { let mut thread = PushThread { - config: self.config.clone(), - store: self.store.clone(), workers: workers.next().unwrap(), receiver: self.receiver.clone(), updater: updater.clone(), @@ -100,13 +91,28 @@ impl PushPool { async move { thread.start().await } }); - } - while let Some(result) = tasks.join_next().await { - match result { - Ok(r) => r?, - Err(e) => return Err(e.into()), + while let Some(result) = threads.join_next().await { + if let Err(_) = result { + todo!() } + + if let Ok(Err(_)) = result { + todo!() + } + } + + // Now that the push threads have shut down, we can stop the updater + updater.stop(); + + let result = updaterd.await; + + if let Err(_) = result { + todo!() + } + + if let Ok(Err(_)) = result { + todo!() } Ok(()) diff --git a/src/push/tests.rs b/src/push/tests.rs index e8a7fa56..68ce72f6 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,90 +1,22 @@ -use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::Instant; -use anyhow::{Result, anyhow}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use hmac::{Hmac, Mac}; -use prost::Message; -use sentry_protos::taskbroker::v1::TaskActivation; -use sha2::Sha256; use tokio::sync::Notify; use tokio::time::{Duration, timeout}; use crate::config::Config; +use crate::push::Pusher; +use crate::push::updater::test_eager_updater; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::FailedTasksForwarder; -use crate::test_utils::{create_test_store, make_activations}; -use crate::worker::{WorkerClient, WorkerMap}; +use crate::test_utils::make_activations; +use crate::worker::test_worker_map; use super::PushPool; -/// Fake worker client that records pushed activation IDs and optionally fails. -struct MockWorkerClient { - /// Pushed activation IDs. - pushed: Vec, - - /// Should `push_task` fail? - fail: bool, -} - -impl MockWorkerClient { - fn new(fail: bool) -> Self { - Self { - pushed: vec![], - fail, - } - } -} - -#[async_trait] -impl WorkerClient for MockWorkerClient { - async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { - TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; - self.pushed.push(activation.id); - - if self.fail { - return Err(anyhow!("mock send failure")); - } - - Ok(()) - } -} - -/// Fake worker client that fires a `Notify` when `push_task` is called. -struct NotifyingWorkerClient { - /// Fire off notification when `push_task` is called - notify: Arc, - - /// Should `push_task` fail? - fail: bool, -} - -#[async_trait] -impl WorkerClient for NotifyingWorkerClient { - async fn push_task(&mut self, _activation: InflightActivation) -> Result<()> { - self.notify.notify_one(); - - if self.fail { - return Err(anyhow!("mock send failure")); - } - - Ok(()) - } -} - -/// Create a map of notifying worker clients for tests. -fn test_worker_map(fail: bool, notify: Arc) -> WorkerMap { - let mut workers = HashMap::new(); - - let client = NotifyingWorkerClient { fail, notify }; - - workers.insert("sentry".into(), Box::new(client) as Box); - workers -} - /// Minimal fake store that records which activation IDs have been marked as processing. #[derive(Clone)] struct MockStore { @@ -100,7 +32,7 @@ impl Default for MockStore { } impl MockStore { - fn marked_ids(&self) -> Vec { + fn marked(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } } @@ -245,8 +177,6 @@ impl InflightActivationStore for MockStore { } } -// --- PushPool tests --- - #[tokio::test] async fn push_pool_push_task_enqueues_item() { let config = Arc::new(Config { @@ -254,10 +184,10 @@ async fn push_pool_push_task_enqueues_item() { ..Config::default() }); - let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); let activation = make_activations(1).remove(0); + let pool = PushPool::new(config); + let time = Instant::now(); let result = pool.push_task(activation, time).await; assert!(result.is_ok(), "push_task should enqueue activation"); @@ -270,16 +200,15 @@ async fn push_pool_push_task_backpressures_when_queue_full() { ..Config::default() }); - let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); - - let time = Instant::now(); let first = make_activations(1).remove(0); let second = make_activations(1).remove(0); + let pool = PushPool::new(config); + + let time = Instant::now(); pool.push_task(first, time) .await - .expect("first push_task should fill queue"); + .expect("first task should fill the queue"); let second_push = timeout(Duration::from_millis(50), pool.push_task(second, time)).await; assert!( @@ -297,12 +226,17 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new(config, store.clone())); + let pool = Arc::new(PushPool::new(config)); let workers = vec![test_worker_map(false, notify.clone())]; + let updater = test_eager_updater(store.clone()); + let pool_start = pool.clone(); tokio::spawn(async move { - pool_start.start(workers).await.expect("push pool start"); + pool_start + .start(workers, updater) + .await + .expect("push pool start"); }); let activation = make_activations(1).remove(0); @@ -321,7 +255,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { tokio::time::sleep(Duration::from_millis(50)).await; assert_eq!( - store.marked_ids(), + store.marked(), vec![id], "mark_processing should be called after a successful first-attempt push" ); @@ -336,12 +270,17 @@ async fn push_pool_start_marks_activation_processing_on_retry() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new(config, store.clone())); + let pool = Arc::new(PushPool::new(config)); let workers = vec![test_worker_map(false, notify.clone())]; + let updater = test_eager_updater(store.clone()); + let pool_start = pool.clone(); tokio::spawn(async move { - pool_start.start(workers).await.expect("push pool start"); + pool_start + .start(workers, updater) + .await + .expect("push pool start"); }); let mut activation = make_activations(1).remove(0); @@ -360,7 +299,7 @@ async fn push_pool_start_marks_activation_processing_on_retry() { tokio::time::sleep(Duration::from_millis(50)).await; assert_eq!( - store.marked_ids(), + store.marked(), vec![id], "mark_processing should be called after a successful retry push" ); @@ -375,12 +314,17 @@ async fn push_pool_start_does_not_mark_processing_on_push_failure() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new(config, store.clone())); + let pool = Arc::new(PushPool::new(config)); let workers = vec![test_worker_map(true, notify.clone())]; + let updater = test_eager_updater(store.clone()); + let pool_start = pool.clone(); tokio::spawn(async move { - pool_start.start(workers).await.expect("push pool start"); + pool_start + .start(workers, updater) + .await + .expect("push pool start"); }); let activation = make_activations(1).remove(0); @@ -396,7 +340,7 @@ async fn push_pool_start_does_not_mark_processing_on_push_failure() { tokio::time::sleep(Duration::from_millis(50)).await; assert!( - store.marked_ids().is_empty(), + store.marked().is_empty(), "mark_processing should not be called when push fails" ); } diff --git a/src/push/thread.rs b/src/push/thread.rs index b517375e..e5b1304c 100644 --- a/src/push/thread.rs +++ b/src/push/thread.rs @@ -5,10 +5,8 @@ use anyhow::Result; use elegant_departure::get_shutdown_guard; use flume::Receiver; -use crate::config::Config; use crate::push::updater::Updater; use crate::store::activation::InflightActivation; -use crate::store::traits::InflightActivationStore; use crate::worker::WorkerMap; /// Alias for ergonomics. @@ -16,12 +14,6 @@ pub type Submission = (InflightActivation, Instant); /// Abstraction for a single push thread. pub struct PushThread { - /// The taskbroker configuration. - pub(super) config: Arc, - - /// The activation store. - pub(super) store: Arc, - /// Maps every application to its worker service. pub(super) workers: WorkerMap, @@ -51,7 +43,9 @@ impl PushThread { metrics::histogram!("push.queue.latency").record(time.elapsed()); // Push the task and mark it as processing - self.push_task(activation).await; + if let Err(_) = self.push_task(activation).await { + todo!() + } } } } @@ -63,7 +57,9 @@ impl PushThread { metrics::histogram!("push.queue.latency").record(time.elapsed()); // Push the task and mark it as processing - self.push_task(activation).await; + if let Err(_) = self.push_task(activation).await { + todo!() + } } drop(guard); @@ -77,11 +73,13 @@ impl PushThread { // First, determine the correct worker service let Some(worker) = self.workers.get_mut(&activation.application) else { // Missing application to worker mapping - return Ok(()); + todo!() }; // Then, push the task to that service - worker.push_task(activation).await?; + if let Err(__) = worker.push_task(activation).await { + todo!() + } // Finally, mark the activation as processing self.updater.update(id).await diff --git a/src/push/updater.rs b/src/push/updater.rs index 159e6bc5..69cff54c 100644 --- a/src/push/updater.rs +++ b/src/push/updater.rs @@ -1,21 +1,33 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; +use std::time::Duration; use anyhow::Result; use chrono::{DateTime, Utc}; use elegant_departure::get_shutdown_guard; -use tokio::sync::{Mutex, MutexGuard}; +use tokio::sync::{Mutex, MutexGuard, Notify}; use tonic::async_trait; use tracing::{info, warn}; -use crate::{config::Config, store::traits::InflightActivationStore}; +use crate::config::Config; +use crate::store::traits::InflightActivationStore; +/// Represents an entity that can update tasks in some way. Meant to abstract away +/// the update logic, which varies between delivery modes, batching configurations, and so on. #[async_trait] pub trait Updater: Send + Sync { - async fn start(&self) -> Result<()>; + /// Start the updater. Useful for updaters that run a background task. + async fn start(&self) -> Result<()> { + Ok(()) + } + /// Update activation in some way given its ID. async fn update(&self, id: String) -> Result<()>; + + /// Stop the updater. Useful for updaters that run a background task. + fn stop(&self) {} } +/// Used by push threads to update sent activations from "claimed" to "processing" in batches. pub struct LazyUpdater { /// The taskbroker configuration. config: Arc, @@ -28,11 +40,56 @@ pub struct LazyUpdater { /// Sent activations that need to be updated. buffer: Arc>>, + + /// Signals the background task to stop. + stop: Notify, +} + +impl LazyUpdater { + pub fn new(config: Arc, store: Arc) -> Self { + let buffer = Arc::new(Mutex::new(vec![])); + let last_flush = Utc::now(); + let stop = Notify::new(); + + Self { + config, + store, + buffer, + last_flush, + stop, + } + } + + /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. + async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { + let ids: Vec<_> = buffer.drain(..).collect(); + let expected = ids.len() as u64; + + match self.store.mark_processing_batch(&ids).await { + Ok(actual) => { + if actual < expected { + // This may happen if tasks are reverted back to pending OR completed too quickly + warn!( + "Push thread update batch contained {expected} records, but only {actual} were updated" + ); + } + + Ok(()) + } + + Err(e) => { + // Flush failed, return IDs to buffer + buffer.extend(ids); + Err(e) + } + } + } } #[async_trait] impl Updater for LazyUpdater { async fn start(&self) -> Result<()> { + // Hold guard until the updater has stopped to delay shutdown let guard = get_shutdown_guard(); // Flush every `interval` milliseconds @@ -41,8 +98,8 @@ impl Updater for LazyUpdater { loop { tokio::select! { - _ = guard.wait() => { - info!("Shutting down batched updater..."); + _ = self.stop.notified() => { + info!("Stopping lazy updater..."); break; } @@ -60,11 +117,14 @@ impl Updater for LazyUpdater { } // We can propagate the error upwards here if desired - self.flush(&mut buffer).await; + if let Err(_) = self.flush(&mut buffer).await { + todo!() + } } } } + drop(guard); Ok(()) } @@ -74,66 +134,41 @@ impl Updater for LazyUpdater { if buffer.len() >= self.config.push_update_batch_size { // Flush first - self.flush(&mut buffer).await?; + if let Err(_) = self.flush(&mut buffer).await { + todo!() + } } buffer.push(id); Ok(()) } -} - -impl LazyUpdater { - pub fn new(config: Arc, store: Arc) -> Self { - let buffer = Arc::new(Mutex::new(vec![])); - let last_flush = Utc::now(); - Self { - config, - store, - buffer, - last_flush, - } - } - - /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. - async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { - let ids: Vec<_> = buffer.drain(..).collect(); - let expected = ids.len() as u64; - - match self.store.mark_processing_batch(&ids).await { - Ok(actual) => { - if actual < expected { - // This may happen if tasks are reverted back to pending OR completed too quickly - warn!( - "Push thread update batch contained {expected} records, but only {actual} were updated" - ); - } - - Ok(()) - } - - Err(e) => { - // Flush failed, return IDs to buffer - buffer.extend(ids); - Err(e) - } - } + fn stop(&self) { + self.stop.notify_one(); } } +/// Used by push threads to update sent activations from "claimed" to "processing" individually. pub struct EagerUpdater { /// The activation store. store: Arc, } -#[async_trait] -impl Updater for EagerUpdater { - async fn start(&self) -> Result<()> { - // There is nothing to do in the background, so we can return immediately - todo!() +impl EagerUpdater { + pub fn new(store: Arc) -> Self { + Self { store } } +} +#[async_trait] +impl Updater for EagerUpdater { async fn update(&self, id: String) -> Result<()> { self.store.mark_processing(&id).await } } + +#[cfg(test)] +pub fn test_eager_updater(store: Arc) -> Arc { + let eager = EagerUpdater::new(store); + Arc::new(eager) as Arc +} diff --git a/src/worker.rs b/src/worker.rs index 7f60f0d1..a6409306 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,17 +1,21 @@ -use std::{collections::HashMap, time::Duration}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; use anyhow::{Context, Result, anyhow}; use async_backtrace::framed; use hmac::{Hmac, Mac}; use prost::Message; -use sentry_protos::taskbroker::v1::{ - PushTaskRequest, TaskActivation, worker_service_client::WorkerServiceClient, -}; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; use sha2::Sha256; +#[cfg(test)] +use tokio::sync::Notify; use tonic::metadata::MetadataValue; use tonic::transport::Channel; use tonic::{Request, async_trait}; +use crate::config::Config; use crate::store::activation::InflightActivation; // Alias for ergonomics. @@ -20,7 +24,7 @@ pub type WorkerMap = HashMap>; /// gRPC path for `WorkerService::PushTask` that should be kept in sync with `sentry_protos` generated client. pub const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; -/// HMAC-SHA256(secret, grpc_path + ":" + message), hex-encoded. Matches Python `RequestSignatureInterceptor` and broker [`crate::grpc::auth_middleware`]. +/// Helper to compute HMAC-SHA256(secret, grpc_path + ":" + message) in hexadecimal. Matches Python `RequestSignatureInterceptor`. fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { let mut mac = Hmac::::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); @@ -50,6 +54,21 @@ pub struct Worker { timeout: Duration, } +impl Worker { + pub async fn connect(config: Arc, endpoint: String) -> Result { + let client = WorkerServiceClient::connect(endpoint).await?; + + let secrets = config.grpc_shared_secret.clone(); + let timeout = Duration::from_millis(config.push_timeout_ms); + + Ok(Self { + client, + secrets, + timeout, + }) + } +} + #[async_trait] impl WorkerClient for Worker { #[framed] @@ -85,11 +104,86 @@ impl WorkerClient for Worker { } } +/// Fake worker client that records pushed activation IDs and optionally fails. +#[cfg(test)] +struct MockWorkerClient { + /// Pushed activation IDs. + pushed: Vec, + + /// Should `push_task` fail? + fail: bool, +} + +#[cfg(test)] +impl MockWorkerClient { + fn new(fail: bool) -> Self { + Self { + pushed: vec![], + fail, + } + } +} + +#[cfg(test)] +#[async_trait] +impl WorkerClient for MockWorkerClient { + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + self.pushed.push(activation.id); + + if self.fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +/// Fake worker client that fires a `Notify` when `push_task` is called. +#[cfg(test)] +struct NotifyingWorkerClient { + /// Fire off notification when `push_task` is called + notify: Arc, + + /// Should `push_task` fail? + fail: bool, +} + +#[cfg(test)] +#[async_trait] +impl WorkerClient for NotifyingWorkerClient { + async fn push_task(&mut self, _activation: InflightActivation) -> Result<()> { + self.notify.notify_one(); + + if self.fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +/// Create a map of notifying worker clients for tests. +#[cfg(test)] +pub fn test_worker_map(fail: bool, notify: Arc) -> WorkerMap { + let mut workers = HashMap::new(); + + let client = NotifyingWorkerClient { fail, notify }; + + workers.insert("sentry".into(), Box::new(client) as Box); + workers +} + +#[cfg(test)] mod tests { - use hmac::Hmac; + use hmac::{Hmac, Mac}; use sha2::Sha256; - use crate::{test_utils::make_activations, worker::WORKER_PUSH_TASK_PATH}; + use crate::test_utils::make_activations; + + use super::MockWorkerClient; + use super::WORKER_PUSH_TASK_PATH; + use super::WorkerClient; #[tokio::test] async fn worker_push_task_returns_ok_on_client_success() { From a1456556b682fe23c8cc12a73a668ab634f89cc4 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Wed, 20 May 2026 16:15:21 -0700 Subject: [PATCH 19/19] Undo Unneeded Changes --- src/fetch/mod.rs | 10 ++--- src/fetch/tests.rs | 2 +- src/push/mod.rs | 4 +- src/push/tests.rs | 2 +- src/worker/tests.rs | 104 -------------------------------------------- 5 files changed, 9 insertions(+), 113 deletions(-) delete mode 100644 src/worker/tests.rs diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index eeedaf2c..7aeec718 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -10,7 +10,7 @@ use tokio::time::sleep; use tracing::{debug, info, warn}; use crate::config::Config; -use crate::push::Pusher; +use crate::push::TaskPusher; use crate::store::traits::InflightActivationStore; use crate::store::types::BucketRange; @@ -47,23 +47,23 @@ pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) } /// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches batches of pending activations from the store, passes them to the push pool, and repeats. -pub struct FetchPool { +pub struct FetchPool { /// Inflight activation store. store: Arc, /// Pool of push threads that push activations to the worker service. - pusher: Arc

, + pusher: Arc, /// Taskbroker configuration. config: Arc, } -impl FetchPool

{ +impl FetchPool { /// Initialize a new fetch pool. pub fn new( store: Arc, config: Arc, - pusher: Arc

, + pusher: Arc, ) -> Self { Self { store, diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 444f6bfa..9e5a0d49 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -207,7 +207,7 @@ impl RecordingPusher { } #[async_trait] -impl Pusher for RecordingPusher { +impl TaskPusher for RecordingPusher { async fn push_task(&self, activation: InflightActivation, _time: Instant) -> Result<()> { self.pushed_ids.lock().await.push(activation.id.clone()); diff --git a/src/push/mod.rs b/src/push/mod.rs index c196684a..45cbec90 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -41,7 +41,7 @@ pub fn compute_claim_lease_ms(config: &Config) -> u64 { /// Thin interface for the push pool. It mostly serves to enable proper unit testing, /// but it also decouples fetch logic from push logic even further. #[async_trait] -pub trait Pusher { +pub trait TaskPusher { /// Submit a single task to the push pool. async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()>; } @@ -120,7 +120,7 @@ impl PushPool { } #[async_trait] -impl Pusher for PushPool { +impl TaskPusher for PushPool { #[framed] async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()> { let duration = Duration::from_millis(self.config.push_queue_timeout_ms); diff --git a/src/push/tests.rs b/src/push/tests.rs index 68ce72f6..defc93e8 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -7,7 +7,7 @@ use tokio::sync::Notify; use tokio::time::{Duration, timeout}; use crate::config::Config; -use crate::push::Pusher; +use crate::push::TaskPusher; use crate::push::updater::test_eager_updater; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; diff --git a/src/worker/tests.rs b/src/worker/tests.rs deleted file mode 100644 index edd7d872..00000000 --- a/src/worker/tests.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::time::Duration; - -use anyhow::anyhow; -use async_trait::async_trait; -use prost::Message; -use sentry_protos::taskbroker::v1::TaskActivation; - -use crate::store::activation::InflightActivation; -use crate::test_utils::make_activations; -use crate::worker::{Worker, WorkerClient, sentry_signature_hex}; - -/// Fake worker client that records task IDs and optionally fails. -struct MockWorkerClient { - captured_task_ids: Vec, - should_fail: bool, -} - -impl MockWorkerClient { - fn new(should_fail: bool) -> Self { - Self { - captured_task_ids: vec![], - should_fail, - } - } -} - -#[async_trait] -impl WorkerClient for MockWorkerClient { - async fn push_task(&mut self, activation: InflightActivation) -> anyhow::Result<()> { - let task = - TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; - self.captured_task_ids.push(task.id); - - if self.should_fail { - return Err(anyhow!("mock send failure")); - } - - Ok(()) - } -} - -#[tokio::test] -async fn push_task_returns_ok_on_client_success() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(false); - - let result = worker.push_task(activation.clone()).await; - assert!(result.is_ok(), "push_task should succeed"); - assert_eq!(worker.captured_task_ids, vec![activation.id]); -} - -#[tokio::test] -async fn push_task_returns_err_on_invalid_payload() { - let mut activation = make_activations(1).remove(0); - activation.activation = vec![1, 2, 3, 4]; - - let mut worker = MockWorkerClient::new(false); - let result = worker.push_task(activation).await; - - assert!(result.is_err(), "invalid payload should fail decoding"); - assert!( - worker.captured_task_ids.is_empty(), - "worker should not record a task id if decode fails" - ); -} - -#[tokio::test] -async fn push_task_propagates_client_error() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(true); - - let result = worker.push_task(activation.clone()).await; - assert!(result.is_err(), "worker send errors should propagate"); - assert_eq!(worker.captured_task_ids, vec![activation.id]); -} - -#[test] -fn sentry_signature_hex_matches_hmac_contract() { - let digest = sentry_signature_hex("super secret", "/test/path", b"hello"); - assert_eq!( - digest, - "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" - ); -} - -#[tokio::test] -async fn worker_connect_fails_for_unreachable_endpoint() { - let result = Worker::connect("http://127.0.0.1:1".into()).await; - assert!( - result.is_err(), - "connect should return Err for an unreachable endpoint" - ); -} - -#[tokio::test] -async fn worker_connect_with_options_fails_for_unreachable_endpoint() { - let result = - Worker::connect_with_options("http://127.0.0.1:1".into(), vec![], Duration::from_secs(1)) - .await; - assert!( - result.is_err(), - "connect_with_options should return Err for an unreachable endpoint" - ); -}