Skip to content

Commit aee4858

Browse files
committed
feat(http): add rate limiting and CORS middleware
1 parent 97c7c86 commit aee4858

1 file changed

Lines changed: 186 additions & 5 deletions

File tree

src/http.rs

Lines changed: 186 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
use std::collections::HashMap;
12
use std::sync::Arc;
3+
use std::time::Instant;
24

35
use axum::extract::{Path, Query, State};
4-
use axum::http::HeaderMap;
6+
use axum::http::{HeaderMap, HeaderValue, Method};
57
use axum::{Json, Router, http::StatusCode, response::IntoResponse, routing::{get, post}};
68
use serde::Deserialize;
79
use tokio::sync::Mutex;
10+
use tower_http::cors::{Any, CorsLayer};
811

912
use crate::config::{ApiKeyConfig, HttpConfig};
1013
use crate::context::{self, ContextParams};
@@ -34,6 +37,56 @@ pub struct ApiState {
3437
pub http_config: Arc<HttpConfig>,
3538
pub no_auth: bool,
3639
pub recent_writes: RecentWrites,
40+
pub rate_limiter: Arc<RateLimiter>,
41+
}
42+
43+
// ---------------------------------------------------------------------------
44+
// Rate limiter (in-memory token bucket)
45+
// ---------------------------------------------------------------------------
46+
47+
pub struct RateLimiter {
48+
buckets: std::sync::Mutex<HashMap<String, RateBucket>>,
49+
limit: u32, // requests per minute, 0 = unlimited
50+
}
51+
52+
struct RateBucket {
53+
tokens: u32,
54+
last_refill: Instant,
55+
}
56+
57+
impl RateLimiter {
58+
pub fn new(limit: u32) -> Self {
59+
Self {
60+
buckets: std::sync::Mutex::new(HashMap::new()),
61+
limit,
62+
}
63+
}
64+
65+
/// Check if a request is allowed. Returns Ok(()) or Err with retry-after seconds.
66+
pub fn check(&self, key: &str) -> Result<(), u64> {
67+
if self.limit == 0 {
68+
return Ok(());
69+
}
70+
let mut buckets = self.buckets.lock().unwrap();
71+
let bucket = buckets.entry(key.to_string()).or_insert(RateBucket {
72+
tokens: self.limit,
73+
last_refill: Instant::now(),
74+
});
75+
// Refill tokens based on elapsed time
76+
let elapsed = bucket.last_refill.elapsed().as_secs_f64();
77+
let refill = (elapsed * self.limit as f64 / 60.0) as u32;
78+
if refill > 0 {
79+
bucket.tokens = (bucket.tokens + refill).min(self.limit);
80+
bucket.last_refill = Instant::now();
81+
}
82+
if bucket.tokens > 0 {
83+
bucket.tokens -= 1;
84+
Ok(())
85+
} else {
86+
let retry_after = (60.0 / self.limit as f64).ceil() as u64;
87+
Err(retry_after)
88+
}
89+
}
3790
}
3891

3992
// ---------------------------------------------------------------------------
@@ -43,12 +96,22 @@ pub struct ApiState {
4396
pub struct ApiError {
4497
pub status: StatusCode,
4598
pub message: String,
99+
pub headers: Vec<(String, String)>,
46100
}
47101

48102
impl IntoResponse for ApiError {
49103
fn into_response(self) -> axum::response::Response {
50104
let body = serde_json::json!({ "error": self.message });
51-
(self.status, Json(body)).into_response()
105+
let mut response = (self.status, Json(body)).into_response();
106+
for (name, value) in &self.headers {
107+
if let (Ok(n), Ok(v)) = (
108+
axum::http::header::HeaderName::from_bytes(name.as_bytes()),
109+
HeaderValue::from_str(value),
110+
) {
111+
response.headers_mut().insert(n, v);
112+
}
113+
}
114+
response
52115
}
53116
}
54117

@@ -57,30 +120,42 @@ impl ApiError {
57120
Self {
58121
status: StatusCode::UNAUTHORIZED,
59122
message: msg.to_string(),
123+
headers: vec![],
60124
}
61125
}
62126
pub fn forbidden(msg: &str) -> Self {
63127
Self {
64128
status: StatusCode::FORBIDDEN,
65129
message: msg.to_string(),
130+
headers: vec![],
66131
}
67132
}
68133
pub fn bad_request(msg: &str) -> Self {
69134
Self {
70135
status: StatusCode::BAD_REQUEST,
71136
message: msg.to_string(),
137+
headers: vec![],
72138
}
73139
}
74140
pub fn not_found(msg: &str) -> Self {
75141
Self {
76142
status: StatusCode::NOT_FOUND,
77143
message: msg.to_string(),
144+
headers: vec![],
78145
}
79146
}
80147
pub fn internal(msg: &str) -> Self {
81148
Self {
82149
status: StatusCode::INTERNAL_SERVER_ERROR,
83150
message: msg.to_string(),
151+
headers: vec![],
152+
}
153+
}
154+
pub fn rate_limited(retry_after: u64) -> Self {
155+
Self {
156+
status: StatusCode::TOO_MANY_REQUESTS,
157+
message: format!("Rate limit exceeded. Retry after {retry_after}s"),
158+
headers: vec![("retry-after".to_string(), retry_after.to_string())],
84159
}
85160
}
86161
}
@@ -102,13 +177,17 @@ pub fn check_permission(permission: &str, is_write: bool) -> bool {
102177
permission == "write"
103178
}
104179

105-
/// Extract and validate auth from request headers.
180+
/// Extract and validate auth from request headers, then check rate limit.
106181
pub fn authorize(
107182
headers: &axum::http::HeaderMap,
108183
state: &ApiState,
109184
is_write: bool,
110185
) -> Result<(), ApiError> {
111186
if state.no_auth {
187+
state
188+
.rate_limiter
189+
.check("no_auth")
190+
.map_err(ApiError::rate_limited)?;
112191
return Ok(());
113192
}
114193
let auth = headers
@@ -125,6 +204,10 @@ pub fn authorize(
125204
"Insufficient permissions: write access required",
126205
));
127206
}
207+
state
208+
.rate_limiter
209+
.check(key)
210+
.map_err(ApiError::rate_limited)?;
128211
Ok(())
129212
}
130213

@@ -237,12 +320,38 @@ struct DeleteBody {
237320
mode: Option<String>,
238321
}
239322

323+
// ---------------------------------------------------------------------------
324+
// CORS
325+
// ---------------------------------------------------------------------------
326+
327+
fn cors_layer(origins: &[String]) -> CorsLayer {
328+
if origins.is_empty() {
329+
return CorsLayer::new();
330+
}
331+
if origins.iter().any(|o| o == "*") {
332+
return CorsLayer::new()
333+
.allow_origin(Any)
334+
.allow_methods(Any)
335+
.allow_headers(Any);
336+
}
337+
let origins: Vec<HeaderValue> = origins
338+
.iter()
339+
.filter_map(|o| o.parse().ok())
340+
.collect();
341+
CorsLayer::new()
342+
.allow_origin(origins)
343+
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
344+
.allow_headers(Any)
345+
.allow_credentials(true)
346+
}
347+
240348
// ---------------------------------------------------------------------------
241349
// Router
242350
// ---------------------------------------------------------------------------
243351

244352
/// Build the axum router with all API endpoints.
245353
pub fn build_router(state: ApiState) -> Router {
354+
let cors = cors_layer(&state.http_config.cors_origins);
246355
Router::new()
247356
.route("/api/health-check", get(health_check))
248357
.route("/api/search", post(handle_search))
@@ -265,6 +374,7 @@ pub fn build_router(state: ApiState) -> Router {
265374
.route("/api/unarchive", post(handle_unarchive))
266375
.route("/api/update-metadata", post(handle_update_metadata))
267376
.route("/api/delete", post(handle_delete))
377+
.layer(cors)
268378
.with_state(state)
269379
}
270380

@@ -746,7 +856,6 @@ async fn handle_delete(
746856
#[cfg(test)]
747857
mod tests {
748858
use super::*;
749-
use std::collections::HashMap;
750859
use std::path::PathBuf;
751860
use std::time::SystemTime;
752861

@@ -792,16 +901,19 @@ mod tests {
792901

793902
fn test_api_state() -> ApiState {
794903
let store = Store::open_memory().expect("in-memory store");
904+
let config = test_http_config();
905+
let rate_limiter = Arc::new(RateLimiter::new(config.rate_limit));
795906
ApiState {
796907
store: Arc::new(Mutex::new(store)),
797908
embedder: Arc::new(Mutex::new(Box::new(DummyEmbedder) as Box<dyn EmbedModel + Send>)),
798909
vault_path: Arc::new(PathBuf::from("/tmp/test-vault")),
799910
profile: Arc::new(None),
800911
orchestrator: None,
801912
reranker: None,
802-
http_config: Arc::new(test_http_config()),
913+
http_config: Arc::new(config),
803914
no_auth: false,
804915
recent_writes: Arc::new(Mutex::new(HashMap::<PathBuf, SystemTime>::new())),
916+
rate_limiter,
805917
}
806918
}
807919

@@ -1029,4 +1141,73 @@ mod tests {
10291141
// Should be 500 (file not found via store) but NOT 403
10301142
assert_ne!(response.status(), StatusCode::FORBIDDEN);
10311143
}
1144+
1145+
// -----------------------------------------------------------------------
1146+
// Rate limiter unit tests
1147+
// -----------------------------------------------------------------------
1148+
1149+
#[test]
1150+
fn test_rate_limiter_allows_under_limit() {
1151+
let limiter = RateLimiter::new(5);
1152+
for _ in 0..5 {
1153+
assert!(limiter.check("key1").is_ok());
1154+
}
1155+
}
1156+
1157+
#[test]
1158+
fn test_rate_limiter_rejects_over_limit() {
1159+
let limiter = RateLimiter::new(2);
1160+
assert!(limiter.check("key1").is_ok());
1161+
assert!(limiter.check("key1").is_ok());
1162+
assert!(limiter.check("key1").is_err());
1163+
}
1164+
1165+
#[test]
1166+
fn test_rate_limiter_unlimited() {
1167+
let limiter = RateLimiter::new(0);
1168+
for _ in 0..1000 {
1169+
assert!(limiter.check("key1").is_ok());
1170+
}
1171+
}
1172+
1173+
#[test]
1174+
fn test_rate_limiter_separate_keys() {
1175+
let limiter = RateLimiter::new(1);
1176+
assert!(limiter.check("key1").is_ok());
1177+
assert!(limiter.check("key2").is_ok()); // different key, separate bucket
1178+
assert!(limiter.check("key1").is_err()); // key1 exhausted
1179+
}
1180+
1181+
#[tokio::test]
1182+
async fn test_rate_limit_returns_429() {
1183+
let mut state = test_api_state();
1184+
state.rate_limiter = Arc::new(RateLimiter::new(1));
1185+
let app = build_router(state);
1186+
// First request passes (consumes the single token)
1187+
let response = app
1188+
.clone()
1189+
.oneshot(
1190+
axum::http::Request::builder()
1191+
.uri("/api/vault-map")
1192+
.header("authorization", "Bearer eg_readkey")
1193+
.body(Body::empty())
1194+
.unwrap(),
1195+
)
1196+
.await
1197+
.unwrap();
1198+
assert_eq!(response.status(), StatusCode::OK);
1199+
// Second request gets 429
1200+
let response = app
1201+
.oneshot(
1202+
axum::http::Request::builder()
1203+
.uri("/api/vault-map")
1204+
.header("authorization", "Bearer eg_readkey")
1205+
.body(Body::empty())
1206+
.unwrap(),
1207+
)
1208+
.await
1209+
.unwrap();
1210+
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
1211+
assert!(response.headers().get("retry-after").is_some());
1212+
}
10321213
}

0 commit comments

Comments
 (0)