diff --git a/src/config.rs b/src/config.rs index fd9b11e8..40b0f416 100644 --- a/src/config.rs +++ b/src/config.rs @@ -31,6 +31,12 @@ pub enum DeliveryMode { Push, } +impl DeliveryMode { + pub fn is_push(self) -> bool { + self == DeliveryMode::Push + } +} + #[derive(PartialEq, Debug, Deserialize, Serialize)] pub struct Config { /// The sentry DSN to use for error reporting. @@ -222,6 +228,11 @@ pub struct Config { /// brokers are under load, or there are small networking delays. pub processing_deadline_grace_sec: u64, + /// The number of additional seconds that claim expirations + /// are extended by. This helps reduce claim expirations when + /// brokers are under load, or there are small networking delays. + pub claim_expiration_grace_sec: u64, + /// The frequency at which upkeep tasks /// (discarding, retrying activations, etc.) are executed. pub upkeep_task_interval_ms: u64, @@ -299,7 +310,7 @@ pub struct Config { /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. pub push_timeout_ms: u64, - /// Update statuses from the gRPC server in batches? + /// Update statuses from the gRPC server in batches? Only applies in PUSH mode. pub batch_status_updates: bool, /// The size of a batch of status updates. @@ -308,10 +319,19 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of status updates. pub status_update_interval_ms: u64, - /// The hostname used to construct `callback_url` for task push requests. + /// Update claimed → processing (dispatch) updates in batches? Only applies in PUSH mode. + pub batch_push_updates: bool, + + /// The size of a batch of dispatch updates. + pub push_update_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of dispatch updates. + pub push_update_interval_ms: u32, + + /// (DEPRECATED) The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, - /// The port used to construct `callback_url` for task push requests. + /// (DEPRECATED) The port used to construct `callback_url` for task push requests. pub callback_port: u32, /// Maps every application to its worker endpoint, both represented as strings. @@ -397,6 +417,7 @@ impl Default for Config { max_processing_count: 2048, max_processing_attempts: 5, processing_deadline_grace_sec: 3, + claim_expiration_grace_sec: 3, upkeep_task_interval_ms: 1000, upkeep_unhealthy_interval_ms: 5000, health_check_killswitched: false, @@ -421,6 +442,9 @@ impl Default for Config { batch_status_updates: false, status_update_batch_size: 1, status_update_interval_ms: 100, + batch_push_updates: false, + push_update_batch_size: 1, + push_update_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index e4ae2fab..7aeec718 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -7,12 +7,10 @@ use async_backtrace::framed; use chrono::Utc; use elegant_departure::get_shutdown_guard; use tokio::time::sleep; -use tonic::async_trait; use tracing::{debug, info, warn}; use crate::config::Config; -use crate::push::{PushError, PushPool}; -use crate::store::activation::InflightActivation; +use crate::push::TaskPusher; use crate::store::traits::InflightActivationStore; use crate::store::types::BucketRange; @@ -48,29 +46,6 @@ pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) (low, high) } -/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. -#[async_trait] -pub trait TaskPusher { - /// Submit a single task to the push pool. - async fn submit_task( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError>; -} - -#[async_trait] -impl TaskPusher for PushPool { - #[framed] - async fn submit_task( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError> { - self.submit(activation, time).await - } -} - /// Wrapper around `config.fetch_threads` asynchronous tasks, each of which fetches batches of pending activations from the store, passes them to the push pool, and repeats. pub struct FetchPool { /// Inflight activation store. @@ -122,11 +97,10 @@ impl FetchPool { } _ = async { - let start = Instant::now(); - debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); + let start = Instant::now(); let mut backoff = false; let result = store.claim_activations_for_push(limit, bucket).await; @@ -165,13 +139,19 @@ impl FetchPool { ) .record(latency as f64); } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → claimed latency recording"); + debug!( + task_id = %id, + namespace = activation.namespace, + taskname = activation.taskname, + "Activation already processed, skipping \ + received → claimed latency recording" + ); } - match pusher.submit_task(activation, start).await { + match pusher.push_task(activation, start).await { Ok(()) => metrics::counter!("fetch.submit", "result" => "ok").increment(1), - Err(PushError::Timeout) => { + Err(_) => { metrics::counter!("fetch.submit", "result" => "timeout") .increment(1); @@ -184,20 +164,6 @@ impl FetchPool { // Wait for push queue to empty backoff = true; } - - Err(PushError::Channel(e)) => { - metrics::counter!("fetch.submit", "result" => "channel_error") - .increment(1); - - warn!( - task_id = %id, - error = ?e, - "Submit to push pool failed due to channel error", - ); - - // Wait before trying again - backoff = true; - } } } } diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 092d3503..9e5a0d49 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -7,7 +7,6 @@ use tokio::time::{Duration, sleep}; use tonic::async_trait; use crate::config::Config; -use crate::push::PushError; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, FailedTasksForwarder}; @@ -98,10 +97,14 @@ impl InflightActivationStore for MockStore { }) } - async fn mark_activation_processing(&self, _id: &str) -> Result<(), Error> { + async fn mark_processing(&self, _id: &str) -> Result<(), Error> { Ok(()) } + async fn mark_processing_batch(&self, _ids: &[String]) -> Result { + unimplemented!() + } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { unimplemented!() } @@ -205,15 +208,11 @@ impl RecordingPusher { #[async_trait] impl TaskPusher for RecordingPusher { - async fn submit_task( - &self, - activation: InflightActivation, - _time: Instant, - ) -> Result<(), PushError> { + async fn push_task(&self, activation: InflightActivation, _time: Instant) -> Result<()> { self.pushed_ids.lock().await.push(activation.id.clone()); if self.fail { - return Err(PushError::Timeout); + return Err(anyhow!("timeout")); } Ok(()) diff --git a/src/flusher.rs b/src/flusher.rs index 33732dfa..dbfd32d2 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -3,7 +3,9 @@ use std::pin::Pin; use std::time::Duration; use anyhow::Result; +use elegant_departure::get_shutdown_guard; use tokio::sync::mpsc::Receiver; +use tracing::debug; /// Run flusher that receives values of type T from a channel and flushes /// them using the provided async `flush` function either when the batch is @@ -26,10 +28,16 @@ where interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut buffer: Vec = Vec::with_capacity(batch_size); + let guard = get_shutdown_guard(); loop { tokio::select! { - msg = rx.recv() => { + biased; + + // When the buffer is NOT full, try to receive another message + msg = rx.recv(), if buffer.len() < batch_size => { + debug!("Buffer is NOT full, receiving a message..."); + match msg { Some(v) => { buffer.push(v); @@ -39,25 +47,42 @@ where } if buffer.len() >= batch_size { + debug!("Flushing full buffer..."); flush(&mut buffer).await; } } None => { - // Channel closed (shutdown), flush remaining and exit - flush(&mut buffer).await; + // Channel closed because all senders were dropped + debug!("Channel closed!"); break; } } } + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { - if !buffer.is_empty() { - flush(&mut buffer).await; + debug!("Performing periodic flush..."); + + if rx.is_closed() { + // Channel closed because all senders were dropped + debug!("Channel closed on tick!"); + break; } + + flush(&mut buffer).await; } } } + // Drain and flush before exit + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + // Delay shutdown until we have flushed everything in the buffer + flush(&mut buffer).await; + drop(guard); + Ok(()) } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 9d669c5e..3f68de7d 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -107,6 +107,9 @@ impl ConsumerService for TaskbrokerServer { } if let Some(ref tx) = self.update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("grpc_server.update_queue.depth").set(depth as f64); + tx.send((id, status)) .await .map_err(|_| Status::internal("Status update channel closed"))?; diff --git a/src/lib.rs b/src/lib.rs index 89a17421..33fe2919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub mod store; pub mod test_utils; pub mod tokio; pub mod upkeep; +pub mod worker; /// Name of the grpc service. /// Using the service type to get a name wasn't working across modules. diff --git a/src/main.rs b/src/main.rs index c3648a6d..fd6d220e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -12,11 +13,11 @@ use tonic::transport::Server; use tonic_health::ServingStatus; use tracing::{debug, error, info, warn}; -use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; +use taskbroker::config::{Config, DatabaseAdapter}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; +use taskbroker::grpc::server::TaskbrokerServer; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -27,10 +28,10 @@ use taskbroker::kafka::inflight_activation_writer::{ ActivationWriterConfig, InflightActivationWriter, }; use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use taskbroker::logging; use taskbroker::metrics; use taskbroker::processing_strategy; use taskbroker::push::PushPool; +use taskbroker::push::updater::{EagerUpdater, LazyUpdater, Updater}; use taskbroker::runtime_config::RuntimeConfigManager; use taskbroker::store::adapters::postgres::{ PostgresActivationStore, PostgresActivationStoreConfig, @@ -38,8 +39,10 @@ use taskbroker::store::adapters::postgres::{ use taskbroker::store::adapters::sqlite::{InflightActivationStoreConfig, SqliteActivationStore}; use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; +use taskbroker::worker::{Worker, WorkerClient, WorkerMap}; use taskbroker::{Args, get_version}; use taskbroker::{SERVICE_NAME, flusher}; +use taskbroker::{grpc, logging}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -192,7 +195,9 @@ async fn main() -> Result<(), Error> { }); // Status update flush task - let (status_update_tx, status_update_task) = if config.batch_status_updates { + let (status_update_tx, status_update_task) = if config.batch_status_updates + && config.delivery_mode.is_push() + { let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size.max(1)); let flusher_store = store.clone(); @@ -203,7 +208,7 @@ async fn main() -> Result<(), Error> { rx, flusher_config.status_update_batch_size, flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + move |buffer| Box::pin(grpc::server::flush_updates(flusher_store.clone(), buffer)), ) .await }); @@ -217,7 +222,6 @@ async fn main() -> Result<(), Error> { let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); - let grpc_status_tx = status_update_tx.clone(); async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) @@ -234,7 +238,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, config: grpc_config, - update_tx: grpc_status_tx, + update_tx: status_update_tx, })) .add_service(health_service.clone()) .serve(addr); @@ -266,18 +270,54 @@ async fn main() -> Result<(), Error> { }); // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); + let push_pool = Arc::new(PushPool::new(config.clone())); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads - let push_task = if config.delivery_mode == DeliveryMode::Push { - Some(tokio::spawn(async move { push_pool.start().await })) + let push_task = if config.delivery_mode.is_push() { + let mut workers: Vec = vec![]; + + // For every push thread, create a map from applications to worker connections + for _ in 0..config.push_threads { + let mut map = HashMap::new(); + + for (application, endpoint) in config.worker_map.clone() { + let worker = match Worker::connect(config.clone(), endpoint).await { + Ok(w) => { + debug!("Connected to worker!"); + Box::new(w) as Box + } + + Err(e) => { + error!(error = ?e, "Failed to connect to worker"); + return Err(e); + } + }; + + map.insert(application, worker); + } + + workers.push(map); + } + + // Create the correct kind of push updater + let updater = if config.batch_push_updates { + let lazy = LazyUpdater::new(config.clone(), store.clone()); + Arc::new(lazy) as Arc + } else { + let eager = EagerUpdater::new(store.clone()); + Arc::new(eager) as Arc + }; + + Some(tokio::spawn(async move { + push_pool.start(workers, updater).await + })) } else { None }; // Initialize fetch threads - let fetch_task = if config.delivery_mode == DeliveryMode::Push { + let fetch_task = if config.delivery_mode.is_push() { Some(tokio::spawn(async move { fetch_pool.start().await })) } else { None diff --git a/src/push/mod.rs b/src/push/mod.rs index fbf04018..45cbec90 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,95 +1,49 @@ -use chrono::Utc; -use std::cmp::max; -use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::{Context, Result}; +use anyhow::Result; use async_backtrace::framed; -use elegant_departure::get_shutdown_guard; -use flume::{Receiver, SendError, Sender}; -use hmac::{Hmac, Mac}; -use prost::Message; -use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; -use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; -use sha2::Sha256; +use flume::{Receiver, Sender}; use tokio::task::JoinSet; use tonic::async_trait; -use tonic::metadata::MetadataValue; -use tonic::transport::Channel; -use tracing::{debug, error, info}; use crate::config::Config; +use crate::push::thread::PushThread; +use crate::push::updater::Updater; use crate::store::activation::InflightActivation; -use crate::store::traits::InflightActivationStore; +use crate::worker::WorkerMap; -type HmacSha256 = Hmac; +pub mod thread; +pub mod updater; -type WorkerFactory = Arc< - dyn Fn(String) -> Pin>> + Send>> - + Send - + Sync, ->; +// Helper to compute the longest number of milliseconds an activation may be claimed. +pub fn compute_claim_lease_ms(config: &Config) -> u64 { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; -/// gRPC path for `WorkerService::PushTask` — keep in sync with `sentry_protos` generated client. -const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; -/// HMAC-SHA256(secret, grpc_path + ":" + message), hex-encoded. Matches Python `RequestSignatureInterceptor` and broker [`crate::grpc::auth_middleware`]. -fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { - let mut mac = - HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); - mac.update(grpc_path.as_bytes()); - mac.update(b":"); - mac.update(message); - hex::encode(mac.finalize().into_bytes()) -} + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms as u64 + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; -/// Error returned when enqueueing an activation for the push workers fails. -#[derive(Debug)] -#[allow(clippy::large_enum_variant)] -pub enum PushError { - /// The bounded queue stayed full until `push_queue_timeout_ms` elapsed. - Timeout, + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; - /// Channel disconnected (no receivers) or another failure. - Channel(SendError<(InflightActivation, Instant)>), -} - -/// Thin interface for the worker client. It mostly serves to enable proper unit testing, but it also decouples the actual client implementation from our pushing logic. -#[async_trait] -trait WorkerClient { - /// Send a single `PushTaskRequest` to the worker service. - /// When `grpc_shared_secret` is not empty, it signs the request with `grpc_shared_secret[0]` and sets `sentry-signature` metadata (same scheme as Python pull client and broker `AuthLayer`). - async fn send(&mut self, request: PushTaskRequest, grpc_shared_secret: &[String]) - -> Result<()>; + queue_ms + send_ms + update_ms + grace_ms } +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, +/// but it also decouples fetch logic from push logic even further. #[async_trait] -impl WorkerClient for WorkerServiceClient { - #[framed] - async fn send( - &mut self, - request: PushTaskRequest, - grpc_shared_secret: &[String], - ) -> Result<()> { - let mut req = tonic::Request::new(request); - - if let Some(secret) = grpc_shared_secret.first() { - let body = req.get_ref().encode_to_vec(); - let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); - let value = MetadataValue::try_from(signature.as_str()) - .context("sentry-signature metadata value must be valid ASCII")?; - req.metadata_mut().insert("sentry-signature", value); - } - - self.push_task(req) - .await - .map_err(|status| anyhow::anyhow!(status))?; - - Ok(()) - } +pub trait TaskPusher { + /// Submit a single task to the push pool. + async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()>; } /// Wrapper around `config.push_threads` asynchronous tasks, each of which receives an activation from the channel, sends it to the worker service, and repeats. @@ -102,325 +56,98 @@ pub struct PushPool { /// Taskbroker configuration. config: Arc, - - /// Activation store, which we need for marking tasks as sent. - store: Arc, - - worker_factory: WorkerFactory, } impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc, store: Arc) -> Self { - let worker_factory: WorkerFactory = Arc::new(|endpoint: String| { - Box::pin(async move { - let client = WorkerServiceClient::connect(endpoint).await?; - Ok(Box::new(client) as Box) - }) - }); - Self::new_with_factory(config, store, worker_factory) - } - - fn new_with_factory( - config: Arc, - store: Arc, - worker_factory: WorkerFactory, - ) -> Self { + pub fn new(config: Arc) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); + Self { sender, receiver, config, - store, - worker_factory, } } /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. #[framed] - pub async fn start(&self) -> Result<()> { - let store = self.store.clone(); - let worker_factory = self.worker_factory.clone(); - let mut push_pool: JoinSet> = crate::tokio::spawn_pool( - self.config.push_threads, - |_| { - let worker_map = self.config.worker_map.clone(); - let receiver = self.receiver.clone(); - let store = store.clone(); - let worker_factory = worker_factory.clone(); - - let guard = get_shutdown_guard().shutdown_on_drop(); - - let callback_url = format!( - "{}:{}", - self.config.callback_addr, self.config.callback_port - ); - - let timeout = Duration::from_millis(self.config.push_timeout_ms); - let grpc_shared_secret = self.config.grpc_shared_secret.clone(); - - async_backtrace::frame!(async move { - metrics::counter!("push.worker.connect.attempt").increment(1); - - let mut workers: HashMap> = HashMap::new(); - - for (application, endpoint) in worker_map.clone() { - let worker = match worker_factory(endpoint).await { - Ok(w) => { - metrics::counter!("push.worker.connect", "result" => "ok", "application" => application.clone()) - .increment(1); - w - } - - Err(e) => { - metrics::counter!("push.worker.connect", "result" => "error", "application" => application.clone()) - .increment(1); - error!(error = ?e, "Failed to connect to worker"); - - return Err(e); - } - }; - - workers.insert(application, worker); - } - - loop { - tokio::select! { - _ = guard.wait() => { - info!("Push worker received shutdown signal"); - break; - } - - message = receiver.recv_async() => { - let (activation, time) = match message { - // Received activation from fetch thread - Ok(a) => a, - - // Channel closed - Err(_) => break, - }; - - metrics::histogram!("push.queue.latency").record(time.elapsed()); - - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; + pub async fn start(&self, workers: Vec, updater: Arc) -> Result<()> { + let mut workers = workers.into_iter(); - match push_task( - worker.as_mut(), - activation.clone(), - callback_url, - timeout, - grpc_shared_secret.as_slice(), - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - if activation.processing_attempts < 1 { - let latency = max(0, activation.received_latency(Utc::now())); - - metrics::histogram!( - "push.received_to_push.latency", - "namespace" => activation.namespace, - "taskname" => activation.taskname, - ) - .record(latency as f64); - } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); - } - - let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as sent after push" - ); - } - } - - // Once claim expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; - } - } - } - - // Drain channel before exiting without recording duration metrics since they don't matter at this time - for (activation, _) in receiver.drain() { - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), - activation, - callback_url, - timeout, - grpc_shared_secret.as_slice(), - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); + // Start the updater + let updaterd = tokio::spawn({ + let updater = updater.clone(); + async move { updater.start().await } + }); - let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration") - .record(start.elapsed()); + let mut threads: JoinSet> = + crate::tokio::spawn_pool(self.config.push_threads, |_| { + let mut thread = PushThread { + workers: workers.next().unwrap(), + receiver: self.receiver.clone(), + updater: updater.clone(), + }; - if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); + async move { thread.start().await } + }); - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); - } - } + while let Some(result) = threads.join_next().await { + if let Err(_) = result { + todo!() + } - // Once processing deadline expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error") - .increment(1); + if let Ok(Err(_)) = result { + todo!() + } + } - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; - } + // Now that the push threads have shut down, we can stop the updater + updater.stop(); - Ok(()) - }) - }, - ); + let result = updaterd.await; - while let Some(result) = push_pool.join_next().await { - match result { - Ok(r) => { - // Connection failed - r? - } + if let Err(_) = result { + todo!() + } - // Join failed - Err(e) => return Err(e.into()), - } + if let Ok(Err(_)) = result { + todo!() } Ok(()) } +} - /// Send an activation to the internal asynchronous MPMC channel used by all running push threads. Times out after `config.push_queue_timeout_ms` milliseconds. +#[async_trait] +impl TaskPusher for PushPool { #[framed] - pub async fn submit( - &self, - activation: InflightActivation, - time: Instant, - ) -> Result<(), PushError> { + async fn push_task(&self, activation: InflightActivation, time: Instant) -> Result<()> { let duration = Duration::from_millis(self.config.push_queue_timeout_ms); let start = Instant::now(); metrics::gauge!("push.queue.depth").set(self.sender.len() as f64); match tokio::time::timeout(duration, self.sender.send_async((activation, time))).await { - Ok(Ok(())) => { - metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Ok(()) + // The channel has closed because all receivers were dropped + Ok(Err(e)) => { + // The only way this can happen is if the push threads have already exited, which is an invalid state + unreachable!("{}", e); } - // The channel has a problem - Ok(Err(e)) => { + // The channel was full so the pushing timed out + Err(e) => { metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Err(PushError::Channel(e)) + Err(e.into()) } - // The channel was full so the send timed out - Err(_elapsed) => { + Ok(_) => { metrics::histogram!("push.queue.wait_duration").record(start.elapsed()); - Err(PushError::Timeout) + Ok(()) } } } } -/// Decode task activation and push it to a worker. -#[framed] -async fn push_task( - worker: &mut (dyn WorkerClient + Send), - activation: InflightActivation, - callback_url: String, - timeout: Duration, - grpc_shared_secret: &[String], -) -> Result<()> { - let start = Instant::now(); - metrics::counter!("push.push_task.attempts").increment(1); - - // Try to decode activation (if it fails, we will see the error where `push_task` is called) - let task = match TaskActivation::decode(&activation.activation as &[u8]) { - Ok(task) => task, - Err(err) => { - metrics::histogram!("push.push_task.duration").record(start.elapsed()); - return Err(err.into()); - } - }; - - let request = PushTaskRequest { - task: Some(task), - callback_url, - }; - - let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await - { - Ok(r) => r, - Err(e) => Err(e.into()), - }; - - metrics::histogram!("push.push_task.duration").record(start.elapsed()); - result -} - #[cfg(test)] mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs index 888b127d..defc93e8 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,75 +1,38 @@ use std::sync::{Arc, Mutex}; +use std::time::Instant; -use anyhow::anyhow; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use sentry_protos::taskbroker::v1::PushTaskRequest; use tokio::sync::Notify; use tokio::time::{Duration, timeout}; use crate::config::Config; +use crate::push::TaskPusher; +use crate::push::updater::test_eager_updater; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::FailedTasksForwarder; -use crate::test_utils::{create_test_store, make_activations}; +use crate::test_utils::make_activations; +use crate::worker::test_worker_map; -use super::*; +use super::PushPool; -/// Fake worker client that records requests and optionally fails. -struct MockWorkerClient { - captured_requests: Vec, - should_fail: bool, +/// Minimal fake store that records which activation IDs have been marked as processing. +#[derive(Clone)] +struct MockStore { + marked_processing: Arc>>, } -impl MockWorkerClient { - fn new(should_fail: bool) -> Self { +impl Default for MockStore { + fn default() -> Self { Self { - captured_requests: vec![], - should_fail, + marked_processing: Arc::new(Mutex::new(vec![])), } } } -#[async_trait] -impl WorkerClient for MockWorkerClient { - async fn send( - &mut self, - request: PushTaskRequest, - _grpc_shared_secret: &[String], - ) -> Result<()> { - self.captured_requests.push(request); - if self.should_fail { - return Err(anyhow!("mock send failure")); - } - Ok(()) - } -} - -/// Fake worker client that fires a Notify when send() is called. -struct NotifyingWorkerClient { - should_fail: bool, - notify: Arc, -} - -#[async_trait] -impl WorkerClient for NotifyingWorkerClient { - async fn send(&mut self, _request: PushTaskRequest, _: &[String]) -> Result<()> { - self.notify.notify_one(); - if self.should_fail { - return Err(anyhow!("mock send failure")); - } - Ok(()) - } -} - -/// Minimal fake store that records which activation IDs have been marked as processing. -#[derive(Default, Clone)] -struct MockStore { - marked_processing: Arc>>, -} - impl MockStore { - fn marked_ids(&self) -> Vec { + fn marked(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } } @@ -79,9 +42,11 @@ impl InflightActivationStore for MockStore { async fn store(&self, _batch: Vec) -> anyhow::Result { Ok(0) } + fn assign_partitions(&self, _partitions: Vec) -> anyhow::Result<()> { Ok(()) } + async fn claim_activations( &self, _application: Option<&str>, @@ -92,10 +57,22 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(vec![]) } - async fn mark_activation_processing(&self, id: &str) -> anyhow::Result<()> { + + async fn mark_processing(&self, id: &str) -> anyhow::Result<()> { self.marked_processing.lock().unwrap().push(id.to_string()); Ok(()) } + + async fn mark_processing_batch(&self, ids: &[String]) -> anyhow::Result { + let mut guard = self.marked_processing.lock().unwrap(); + + for id in ids { + guard.push(id.clone()); + } + + Ok(ids.len() as u64) + } + async fn set_status( &self, _id: &str, @@ -103,6 +80,7 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(None) } + async fn set_status_batch( &self, _ids: &[String], @@ -110,18 +88,23 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result { Ok(0) } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { 0.0 } + async fn count_by_status(&self, _status: InflightActivationStatus) -> anyhow::Result { Ok(0) } + async fn count(&self) -> anyhow::Result { Ok(0) } + async fn get_by_id(&self, _id: &str) -> anyhow::Result> { Ok(None) } + async fn set_processing_deadline( &self, _id: &str, @@ -129,227 +112,132 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result<()> { Ok(()) } + async fn delete_activation(&self, _id: &str) -> anyhow::Result<()> { Ok(()) } + async fn vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn full_vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn db_size(&self) -> anyhow::Result { Ok(0) } + async fn get_retry_activations(&self) -> anyhow::Result> { Ok(vec![]) } + async fn handle_claim_expiration(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_deadline(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_attempts(&self) -> anyhow::Result { Ok(0) } + async fn handle_expires_at(&self) -> anyhow::Result { Ok(0) } + async fn handle_delay_until(&self) -> anyhow::Result { Ok(0) } + async fn handle_failed_tasks(&self) -> anyhow::Result { Ok(FailedTasksForwarder { to_discard: vec![], to_deadletter: vec![], }) } + async fn mark_completed(&self, _ids: Vec) -> anyhow::Result { Ok(0) } + async fn remove_completed(&self) -> anyhow::Result { Ok(0) } + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> anyhow::Result { Ok(0) } + async fn clear(&self) -> anyhow::Result<()> { Ok(()) } } -/// Factory that fires `notify` when send() is called, then succeeds or fails per `should_fail`. -fn notifying_factory(should_fail: bool, notify: Arc) -> WorkerFactory { - Arc::new(move |_: String| { - let notify = notify.clone(); - Box::pin(async move { - Ok(Box::new(NotifyingWorkerClient { - should_fail, - notify, - }) as Box) - }) - }) -} - -/// Factory that always fails to connect (simulates a broken endpoint). -fn failing_connect_factory() -> WorkerFactory { - Arc::new(|_: String| Box::pin(async { Err(anyhow::anyhow!("simulated connect failure")) })) -} - -#[tokio::test] -async fn push_task_returns_ok_on_client_success() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(false); - let callback_url = "taskbroker:50051".to_string(); - - let result = push_task( - &mut worker, - activation.clone(), - callback_url.clone(), - Duration::from_secs(5), - &[], - ) - .await; - assert!(result.is_ok(), "push_task should succeed"); - assert_eq!(worker.captured_requests.len(), 1); - - let request = &worker.captured_requests[0]; - assert_eq!(request.callback_url, callback_url); - assert_eq!( - request.task.as_ref().map(|task| task.id.as_str()), - Some(activation.id.as_str()) - ); -} - #[tokio::test] -async fn push_task_returns_err_on_invalid_payload() { - let mut activation = make_activations(1).remove(0); - activation.activation = vec![1, 2, 3, 4]; - - let mut worker = MockWorkerClient::new(false); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; - - assert!(result.is_err(), "invalid payload should fail decoding"); - assert!( - worker.captured_requests.is_empty(), - "worker should not be called if decode fails" - ); -} - -#[tokio::test] -async fn push_task_propagates_client_error() { - let activation = make_activations(1).remove(0); - let mut worker = MockWorkerClient::new(true); - - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; - assert!(result.is_err(), "worker send errors should propagate"); - assert_eq!(worker.captured_requests.len(), 1); -} - -#[tokio::test] -async fn push_pool_submit_enqueues_item() { +async fn push_pool_push_task_enqueues_item() { let config = Arc::new(Config { push_queue_size: 2, ..Config::default() }); - let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); let activation = make_activations(1).remove(0); + let pool = PushPool::new(config); + let time = Instant::now(); - let result = pool.submit(activation, time).await; - assert!(result.is_ok(), "submit should enqueue activation"); + let result = pool.push_task(activation, time).await; + assert!(result.is_ok(), "push_task should enqueue activation"); } #[tokio::test] -async fn push_pool_submit_backpressures_when_queue_full() { +async fn push_pool_push_task_backpressures_when_queue_full() { let config = Arc::new(Config { push_queue_size: 1, ..Config::default() }); - let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); - - let time = Instant::now(); let first = make_activations(1).remove(0); let second = make_activations(1).remove(0); - pool.submit(first, time) - .await - .expect("first submit should fill queue"); + let pool = PushPool::new(config); - let second_submit = timeout(Duration::from_millis(50), pool.submit(second, time)).await; - assert!( - second_submit.is_err(), - "second submit should block when queue is full" - ); -} - -#[test] -fn sentry_signature_hex_matches_hmac_contract() { - let digest = sentry_signature_hex("super secret", "/test/path", b"hello"); - assert_eq!( - digest, - "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" - ); -} - -/// When the worker factory fails to connect, start() returns an error immediately. -#[tokio::test] -async fn push_pool_start_worker_connect_failure_returns_error() { - let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), - push_threads: 1, - push_queue_size: 10, - ..Config::default() - }); - let store = Arc::new(MockStore::default()); - let pool = PushPool::new_with_factory(config, store, failing_connect_factory()); + let time = Instant::now(); + pool.push_task(first, time) + .await + .expect("first task should fill the queue"); - let result = pool.start().await; + let second_push = timeout(Duration::from_millis(50), pool.push_task(second, time)).await; assert!( - result.is_err(), - "start() should return Err when the worker factory fails to connect" + second_push.is_err(), + "second push_task should time out when queue is full" ); } -/// After a successful push for a first-attempt activation (processing_attempts == 0), -/// mark_activation_processing must be called on the store. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_first_attempt() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - )); + let pool = Arc::new(PushPool::new(config)); + + let workers = vec![test_worker_map(false, notify.clone())]; + let updater = test_eager_updater(store.clone()); let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start + .start(workers, updater) + .await + .expect("push pool start"); + }); let activation = make_activations(1).remove(0); assert_eq!(activation.processing_attempts, 0); @@ -357,43 +245,43 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { let id = activation.id.clone(); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); - // Wait for the worker to call send(), then give it time to call mark_activation_processing timeout(Duration::from_secs(2), notify.notified()) .await .expect("timed out waiting for push to be delivered"); tokio::time::sleep(Duration::from_millis(50)).await; assert_eq!( - store.marked_ids(), + store.marked(), vec![id], - "mark_activation_processing should be called after a successful first-attempt push" + "mark_processing should be called after a successful first-attempt push" ); } -/// After a successful push for a retried activation (processing_attempts > 0), -/// mark_activation_processing must be called and latency recording is skipped. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_retry() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(false, notify.clone()), - )); + let pool = Arc::new(PushPool::new(config)); + + let workers = vec![test_worker_map(false, notify.clone())]; + let updater = test_eager_updater(store.clone()); let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start + .start(workers, updater) + .await + .expect("push pool start"); + }); let mut activation = make_activations(1).remove(0); activation.processing_attempts = 1; @@ -401,9 +289,9 @@ async fn push_pool_start_marks_activation_processing_on_retry() { let id = activation.id.clone(); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); timeout(Duration::from_secs(2), notify.notified()) .await @@ -411,38 +299,40 @@ async fn push_pool_start_marks_activation_processing_on_retry() { tokio::time::sleep(Duration::from_millis(50)).await; assert_eq!( - store.marked_ids(), + store.marked(), vec![id], - "mark_activation_processing should be called after a successful retry push" + "mark_processing should be called after a successful retry push" ); } -/// When the worker fails to deliver an activation, mark_activation_processing must NOT be called. #[tokio::test] -async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { +async fn push_pool_start_does_not_mark_processing_on_push_failure() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { - worker_map: [("sentry".into(), "unused".into())].into(), push_threads: 1, push_queue_size: 10, ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = Arc::new(PushPool::new_with_factory( - config, - store.clone(), - notifying_factory(true, notify.clone()), - )); + let pool = Arc::new(PushPool::new(config)); + + let workers = vec![test_worker_map(true, notify.clone())]; + let updater = test_eager_updater(store.clone()); let pool_start = pool.clone(); - tokio::spawn(async move { pool_start.start().await }); + tokio::spawn(async move { + pool_start + .start(workers, updater) + .await + .expect("push pool start"); + }); let activation = make_activations(1).remove(0); let time = Instant::now(); - pool.submit(activation, time) + pool.push_task(activation, time) .await - .expect("submit should succeed"); + .expect("push_task should succeed"); timeout(Duration::from_secs(2), notify.notified()) .await @@ -450,7 +340,7 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { tokio::time::sleep(Duration::from_millis(50)).await; assert!( - store.marked_ids().is_empty(), - "mark_activation_processing should not be called when push fails" + store.marked().is_empty(), + "mark_processing should not be called when push fails" ); } diff --git a/src/push/thread.rs b/src/push/thread.rs new file mode 100644 index 00000000..e5b1304c --- /dev/null +++ b/src/push/thread.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use elegant_departure::get_shutdown_guard; +use flume::Receiver; + +use crate::push::updater::Updater; +use crate::store::activation::InflightActivation; +use crate::worker::WorkerMap; + +/// Alias for ergonomics. +pub type Submission = (InflightActivation, Instant); + +/// Abstraction for a single push thread. +pub struct PushThread { + /// Maps every application to its worker service. + pub(super) workers: WorkerMap, + + /// Channel containing claimed activations to be pushed. + pub(super) receiver: Receiver, + + /// Entity that marks tasks as processing. + pub(super) updater: Arc, +} + +impl PushThread { + pub async fn start(&mut self) -> Result<()> { + let guard = get_shutdown_guard(); + + loop { + // We cannot exit before every fetch thread has exited, so don't exit on `guard.wait()` here + tokio::select! { + message = self.receiver.recv_async() => { + let (activation, time) = match message { + // Received activation from fetch thread + Ok(a) => a, + + // We only exit when the push queue is closed, which happens when the fetch pool has shut down + Err(_) => break, + }; + + metrics::histogram!("push.queue.latency").record(time.elapsed()); + + // Push the task and mark it as processing + if let Err(_) = self.push_task(activation).await { + todo!() + } + } + } + } + + // Drain channel before exiting + let activations: Vec<_> = self.receiver.drain().collect(); + + for (activation, time) in activations { + metrics::histogram!("push.queue.latency").record(time.elapsed()); + + // Push the task and mark it as processing + if let Err(_) = self.push_task(activation).await { + todo!() + } + } + + drop(guard); + Ok(()) + } + + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + // Store the ID for later since `push_task` claims ownership over `activation` + let id = activation.id.clone(); + + // First, determine the correct worker service + let Some(worker) = self.workers.get_mut(&activation.application) else { + // Missing application to worker mapping + todo!() + }; + + // Then, push the task to that service + if let Err(__) = worker.push_task(activation).await { + todo!() + } + + // Finally, mark the activation as processing + self.updater.update(id).await + } +} diff --git a/src/push/updater.rs b/src/push/updater.rs new file mode 100644 index 00000000..69cff54c --- /dev/null +++ b/src/push/updater.rs @@ -0,0 +1,174 @@ +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use chrono::{DateTime, Utc}; +use elegant_departure::get_shutdown_guard; +use tokio::sync::{Mutex, MutexGuard, Notify}; +use tonic::async_trait; +use tracing::{info, warn}; + +use crate::config::Config; +use crate::store::traits::InflightActivationStore; + +/// Represents an entity that can update tasks in some way. Meant to abstract away +/// the update logic, which varies between delivery modes, batching configurations, and so on. +#[async_trait] +pub trait Updater: Send + Sync { + /// Start the updater. Useful for updaters that run a background task. + async fn start(&self) -> Result<()> { + Ok(()) + } + + /// Update activation in some way given its ID. + async fn update(&self, id: String) -> Result<()>; + + /// Stop the updater. Useful for updaters that run a background task. + fn stop(&self) {} +} + +/// Used by push threads to update sent activations from "claimed" to "processing" in batches. +pub struct LazyUpdater { + /// The taskbroker configuration. + config: Arc, + + /// The activation store. + store: Arc, + + /// Last time the buffer was flushed. + last_flush: DateTime, + + /// Sent activations that need to be updated. + buffer: Arc>>, + + /// Signals the background task to stop. + stop: Notify, +} + +impl LazyUpdater { + pub fn new(config: Arc, store: Arc) -> Self { + let buffer = Arc::new(Mutex::new(vec![])); + let last_flush = Utc::now(); + let stop = Notify::new(); + + Self { + config, + store, + buffer, + last_flush, + stop, + } + } + + /// Flush buffered activations to the store. Empties the buffer on success, refills on failure. + async fn flush(&self, buffer: &mut MutexGuard<'_, Vec>) -> Result<()> { + let ids: Vec<_> = buffer.drain(..).collect(); + let expected = ids.len() as u64; + + match self.store.mark_processing_batch(&ids).await { + Ok(actual) => { + if actual < expected { + // This may happen if tasks are reverted back to pending OR completed too quickly + warn!( + "Push thread update batch contained {expected} records, but only {actual} were updated" + ); + } + + Ok(()) + } + + Err(e) => { + // Flush failed, return IDs to buffer + buffer.extend(ids); + Err(e) + } + } + } +} + +#[async_trait] +impl Updater for LazyUpdater { + async fn start(&self) -> Result<()> { + // Hold guard until the updater has stopped to delay shutdown + let guard = get_shutdown_guard(); + + // Flush every `interval` milliseconds + let period = Duration::from_millis(self.config.push_update_interval_ms as u64); + let mut interval = tokio::time::interval(period); + + loop { + tokio::select! { + _ = self.stop.notified() => { + info!("Stopping lazy updater..."); + break; + } + + _ = interval.tick(), if self.config.batch_push_updates => { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + // Make sure we aren't flushing too soon + let now = Utc::now().timestamp_millis(); + let elapsed = self.last_flush.timestamp_millis() - now; + + if elapsed < (self.config.push_update_interval_ms as i64) { + // Too soon! + continue; + } + + // We can propagate the error upwards here if desired + if let Err(_) = self.flush(&mut buffer).await { + todo!() + } + } + } + } + + drop(guard); + Ok(()) + } + + async fn update(&self, id: String) -> Result<()> { + // Lock the ID buffer + let mut buffer = self.buffer.lock().await; + + if buffer.len() >= self.config.push_update_batch_size { + // Flush first + if let Err(_) = self.flush(&mut buffer).await { + todo!() + } + } + + buffer.push(id); + Ok(()) + } + + fn stop(&self) { + self.stop.notify_one(); + } +} + +/// Used by push threads to update sent activations from "claimed" to "processing" individually. +pub struct EagerUpdater { + /// The activation store. + store: Arc, +} + +impl EagerUpdater { + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl Updater for EagerUpdater { + async fn update(&self, id: String) -> Result<()> { + self.store.mark_processing(&id).await + } +} + +#[cfg(test)] +pub fn test_eager_updater(store: Arc) -> Arc { + let eager = EagerUpdater::new(store); + Arc::new(eager) as Arc +} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 27ce276f..08e3bfed 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -15,6 +15,7 @@ use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use tracing::{instrument, warn}; use crate::config::Config; +use crate::push::compute_claim_lease_ms; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, DepthCounts, FailedTasksForwarder}; @@ -135,7 +136,6 @@ pub struct PostgresActivationStoreConfig { pub run_migrations: bool, pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool, @@ -148,12 +148,14 @@ impl PostgresActivationStoreConfig { .password(&config.pg_password) .host(&config.pg_host) .port(config.pg_port); + if let Some(extra_query_params) = config.pg_extra_query_params.as_ref() { let url = conn_opts.to_url_lossy(); let new_url = url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); } + Self { pg_connection: conn_opts, pg_database_name: config.pg_database_name.clone(), @@ -161,9 +163,9 @@ impl PostgresActivationStoreConfig { run_migrations: config.run_migrations, max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, - processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + processing_deadline_grace_sec: config.processing_deadline_grace_sec, + claim_lease_ms: compute_claim_lease_ms(config), } } } @@ -421,9 +423,6 @@ impl InflightActivationStore for PostgresActivationStore { ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; - let claim_lease_ms = self.config.claim_lease_ms as i64; - let mut query_builder = QueryBuilder::::new( "WITH selected_activations AS ( SELECT id @@ -469,6 +468,8 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push(" FOR UPDATE SKIP LOCKED)"); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "UPDATE inflight_taskactivations SET processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), @@ -478,9 +479,11 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push_bind(InflightActivationStatus::Processing.to_string()); } else { + let claim_lease = self.config.claim_lease_ms as i64; + query_builder.push(format!( "UPDATE inflight_taskactivations - SET claim_expires_at = now() + ({claim_lease_ms} * interval '1 millisecond') + (interval '{grace_period} seconds'), + SET claim_expires_at = now() + ({claim_lease} * interval '1 millisecond'), processing_deadline = NULL, status = " )); @@ -503,11 +506,8 @@ impl InflightActivationStore for PostgresActivationStore { #[instrument(skip_all)] #[framed] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; - + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( "UPDATE inflight_taskactivations SET @@ -523,20 +523,47 @@ impl InflightActivationStore for PostgresActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as processing, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + #[framed] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let result = sqlx::query(&format!( + "UPDATE inflight_taskactivations SET + status = $1, + processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), + claim_expires_at = NULL + WHERE id = ANY($2) AND status = $3", + )) + .bind(InflightActivationStatus::Processing.to_string()) + .bind(ids) + .bind(InflightActivationStatus::Claimed.to_string()) + .execute(&mut *conn) + .await?; + + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index de457de4..f558e9c6 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -24,6 +24,7 @@ use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use tracing::{instrument, warn}; use crate::config::Config; +use crate::push::compute_claim_lease_ms; use crate::store::activation::{InflightActivation, InflightActivationStatus}; use crate::store::traits::InflightActivationStore; use crate::store::types::{BucketRange, FailedTasksForwarder}; @@ -139,7 +140,6 @@ pub async fn create_sqlite_pool(url: &str) -> Result<(Pool, Pool pub struct InflightActivationStoreConfig { pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool, @@ -151,8 +151,8 @@ impl InflightActivationStoreConfig { max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + claim_lease_ms: compute_claim_lease_ms(config), } } } @@ -535,20 +535,22 @@ impl InflightActivationStore for SqliteActivationStore { mark_processing: bool, ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET "); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Processing); } else { + let claim_lease = self.config.claim_lease_ms as f64 / 1000.0; + query_builder.push(format!( - "claim_expires_at = unixepoch('now', '+' || {:.3} || ' seconds', '+' || {grace_period} || ' seconds'), processing_deadline = NULL, status = ", - self.config.claim_lease_ms as f64 / 1000.0, + "claim_expires_at = unixepoch('now', '+' || {claim_lease:.3} || ' seconds'), processing_deadline = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Claimed); @@ -597,10 +599,8 @@ impl InflightActivationStore for SqliteActivationStore { } #[instrument(skip_all)] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -617,20 +617,48 @@ impl InflightActivationStore for SqliteActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as sent, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET status = "); + query_builder.push_bind(InflightActivationStatus::Processing); + query_builder.push(format!( + ", processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL WHERE status = ", + )); + query_builder.push_bind(InflightActivationStatus::Claimed); + query_builder.push(" AND id IN ("); + + let mut separated = query_builder.separated(", "); + for id in ids.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/traits.rs b/src/store/traits.rs index d8bdb5e0..287fc12f 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -27,7 +27,7 @@ pub trait InflightActivationStore: Send + Sync { mark_processing: bool, ) -> Result, Error>; - /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_activation_processing` moves to `Processing`. + /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_processing` moves to `Processing`. async fn claim_activations_for_push( &self, limit: Option, @@ -69,7 +69,10 @@ pub trait InflightActivationStore: Send + Sync { } /// Record successful push. - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error>; + async fn mark_processing(&self, id: &str) -> Result<(), Error>; + + /// Record a batch of successful pushes. + async fn mark_processing_batch(&self, ids: &[String]) -> Result; /// Update the status of a specific activation async fn set_status( diff --git a/src/tokio.rs b/src/tokio.rs index 54c593a9..50fe2ed8 100644 --- a/src/tokio.rs +++ b/src/tokio.rs @@ -3,9 +3,9 @@ use tokio::task::JoinSet; /// Spawns `max(n, 1)` tasks, each running the future produced by `f` with the task's index. /// Returns a [`JoinSet`] containing all spawned tasks. -pub fn spawn_pool(n: usize, f: F) -> JoinSet +pub fn spawn_pool(n: usize, mut f: F) -> JoinSet where - F: Fn(usize) -> Fut, + F: FnMut(usize) -> Fut, Fut: Future + Send + 'static, Fut::Output: Send, { diff --git a/src/worker.rs b/src/worker.rs new file mode 100644 index 00000000..a6409306 --- /dev/null +++ b/src/worker.rs @@ -0,0 +1,237 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{Context, Result, anyhow}; +use async_backtrace::framed; +use hmac::{Hmac, Mac}; +use prost::Message; +use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; +use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use sha2::Sha256; +#[cfg(test)] +use tokio::sync::Notify; +use tonic::metadata::MetadataValue; +use tonic::transport::Channel; +use tonic::{Request, async_trait}; + +use crate::config::Config; +use crate::store::activation::InflightActivation; + +// Alias for ergonomics. +pub type WorkerMap = HashMap>; + +/// gRPC path for `WorkerService::PushTask` that should be kept in sync with `sentry_protos` generated client. +pub const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; + +/// Helper to compute HMAC-SHA256(secret, grpc_path + ":" + message) in hexadecimal. Matches Python `RequestSignatureInterceptor`. +fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { + let mut mac = + Hmac::::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); + mac.update(grpc_path.as_bytes()); + mac.update(b":"); + mac.update(message); + hex::encode(mac.finalize().into_bytes()) +} + +/// Thin interface for the worker client. It mostly serves to enable proper unit testing, +/// but it also decouples the actual client implementation from our pushing logic. +#[async_trait] +pub trait WorkerClient: Send + Sync { + /// Send a single activation to the worker service. + async fn push_task(&mut self, activation: InflightActivation) -> Result<()>; +} + +/// Wrapper around worker connection that provides authentication and timeouts. +pub struct Worker { + /// Connection to the worker service. + client: WorkerServiceClient, + + /// List of secrets shared with the worker. + secrets: Vec, + + /// Wait this much time before giving up on a push. + timeout: Duration, +} + +impl Worker { + pub async fn connect(config: Arc, endpoint: String) -> Result { + let client = WorkerServiceClient::connect(endpoint).await?; + + let secrets = config.grpc_shared_secret.clone(); + let timeout = Duration::from_millis(config.push_timeout_ms); + + Ok(Self { + client, + secrets, + timeout, + }) + } +} + +#[async_trait] +impl WorkerClient for Worker { + #[framed] + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + // Try to decode activation + let task = + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema + let request = PushTaskRequest { + task: Some(task), + callback_url: "".into(), + }; + + // Wrap inside a Tonic request + let mut request = Request::new(request); + + // Sign if secrets are present + if let Some(secret) = self.secrets.first() { + let body = request.get_ref().encode_to_vec(); + let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); + let value = MetadataValue::try_from(signature.as_str()) + .context("sentry-signature metadata value must be valid ASCII")?; + + request.metadata_mut().insert("sentry-signature", value); + } + + // Push with timeout + let future = self.client.push_task(request); + tokio::time::timeout(self.timeout, future).await??; + + Ok(()) + } +} + +/// Fake worker client that records pushed activation IDs and optionally fails. +#[cfg(test)] +struct MockWorkerClient { + /// Pushed activation IDs. + pushed: Vec, + + /// Should `push_task` fail? + fail: bool, +} + +#[cfg(test)] +impl MockWorkerClient { + fn new(fail: bool) -> Self { + Self { + pushed: vec![], + fail, + } + } +} + +#[cfg(test)] +#[async_trait] +impl WorkerClient for MockWorkerClient { + async fn push_task(&mut self, activation: InflightActivation) -> Result<()> { + TaskActivation::decode(&activation.activation as &[u8]).map_err(|e| anyhow!(e))?; + self.pushed.push(activation.id); + + if self.fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +/// Fake worker client that fires a `Notify` when `push_task` is called. +#[cfg(test)] +struct NotifyingWorkerClient { + /// Fire off notification when `push_task` is called + notify: Arc, + + /// Should `push_task` fail? + fail: bool, +} + +#[cfg(test)] +#[async_trait] +impl WorkerClient for NotifyingWorkerClient { + async fn push_task(&mut self, _activation: InflightActivation) -> Result<()> { + self.notify.notify_one(); + + if self.fail { + return Err(anyhow!("mock send failure")); + } + + Ok(()) + } +} + +/// Create a map of notifying worker clients for tests. +#[cfg(test)] +pub fn test_worker_map(fail: bool, notify: Arc) -> WorkerMap { + let mut workers = HashMap::new(); + + let client = NotifyingWorkerClient { fail, notify }; + + workers.insert("sentry".into(), Box::new(client) as Box); + workers +} + +#[cfg(test)] +mod tests { + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + use crate::test_utils::make_activations; + + use super::MockWorkerClient; + use super::WORKER_PUSH_TASK_PATH; + use super::WorkerClient; + + #[tokio::test] + async fn worker_push_task_returns_ok_on_client_success() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(false); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_ok(), "push_task should succeed"); + assert_eq!(worker.pushed, vec![activation.id]); + } + + #[tokio::test] + async fn worker_push_task_returns_err_on_invalid_payload() { + let mut activation = make_activations(1).remove(0); + activation.activation = vec![1, 2, 3, 4]; + + let mut worker = MockWorkerClient::new(false); + let result = worker.push_task(activation).await; + + assert!(result.is_err(), "invalid payload should fail decoding"); + assert!( + worker.pushed.is_empty(), + "worker should not record a push if decode fails" + ); + } + + #[tokio::test] + async fn worker_push_task_propagates_client_error() { + let activation = make_activations(1).remove(0); + let mut worker = MockWorkerClient::new(true); + + let result = worker.push_task(activation.clone()).await; + assert!(result.is_err(), "worker push errors should propagate"); + assert_eq!(worker.pushed, vec![activation.id]); + } + + #[test] + fn worker_sentry_signature_hex_matches_hmac_contract() { + let mut mac = Hmac::::new_from_slice(b"super secret") + .expect("HMAC accepts keys of any length"); + mac.update(WORKER_PUSH_TASK_PATH.as_bytes()); + mac.update(b":"); + mac.update(b"hello"); + let digest = hex::encode(mac.finalize().into_bytes()); + + assert_eq!( + digest, + "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" + ); + } +}