1+ use std:: collections:: HashMap ;
12use std:: sync:: Arc ;
3+ use std:: time:: Instant ;
24
35use axum:: extract:: { Path , Query , State } ;
4- use axum:: http:: HeaderMap ;
6+ use axum:: http:: { HeaderMap , HeaderValue , Method } ;
57use axum:: { Json , Router , http:: StatusCode , response:: IntoResponse , routing:: { get, post} } ;
68use serde:: Deserialize ;
79use tokio:: sync:: Mutex ;
10+ use tower_http:: cors:: { Any , CorsLayer } ;
811
912use crate :: config:: { ApiKeyConfig , HttpConfig } ;
1013use crate :: context:: { self , ContextParams } ;
@@ -34,6 +37,56 @@ pub struct ApiState {
3437 pub http_config : Arc < HttpConfig > ,
3538 pub no_auth : bool ,
3639 pub recent_writes : RecentWrites ,
40+ pub rate_limiter : Arc < RateLimiter > ,
41+ }
42+
43+ // ---------------------------------------------------------------------------
44+ // Rate limiter (in-memory token bucket)
45+ // ---------------------------------------------------------------------------
46+
47+ pub struct RateLimiter {
48+ buckets : std:: sync:: Mutex < HashMap < String , RateBucket > > ,
49+ limit : u32 , // requests per minute, 0 = unlimited
50+ }
51+
52+ struct RateBucket {
53+ tokens : u32 ,
54+ last_refill : Instant ,
55+ }
56+
57+ impl RateLimiter {
58+ pub fn new ( limit : u32 ) -> Self {
59+ Self {
60+ buckets : std:: sync:: Mutex :: new ( HashMap :: new ( ) ) ,
61+ limit,
62+ }
63+ }
64+
65+ /// Check if a request is allowed. Returns Ok(()) or Err with retry-after seconds.
66+ pub fn check ( & self , key : & str ) -> Result < ( ) , u64 > {
67+ if self . limit == 0 {
68+ return Ok ( ( ) ) ;
69+ }
70+ let mut buckets = self . buckets . lock ( ) . unwrap ( ) ;
71+ let bucket = buckets. entry ( key. to_string ( ) ) . or_insert ( RateBucket {
72+ tokens : self . limit ,
73+ last_refill : Instant :: now ( ) ,
74+ } ) ;
75+ // Refill tokens based on elapsed time
76+ let elapsed = bucket. last_refill . elapsed ( ) . as_secs_f64 ( ) ;
77+ let refill = ( elapsed * self . limit as f64 / 60.0 ) as u32 ;
78+ if refill > 0 {
79+ bucket. tokens = ( bucket. tokens + refill) . min ( self . limit ) ;
80+ bucket. last_refill = Instant :: now ( ) ;
81+ }
82+ if bucket. tokens > 0 {
83+ bucket. tokens -= 1 ;
84+ Ok ( ( ) )
85+ } else {
86+ let retry_after = ( 60.0 / self . limit as f64 ) . ceil ( ) as u64 ;
87+ Err ( retry_after)
88+ }
89+ }
3790}
3891
3992// ---------------------------------------------------------------------------
@@ -43,12 +96,22 @@ pub struct ApiState {
4396pub struct ApiError {
4497 pub status : StatusCode ,
4598 pub message : String ,
99+ pub headers : Vec < ( String , String ) > ,
46100}
47101
48102impl IntoResponse for ApiError {
49103 fn into_response ( self ) -> axum:: response:: Response {
50104 let body = serde_json:: json!( { "error" : self . message } ) ;
51- ( self . status , Json ( body) ) . into_response ( )
105+ let mut response = ( self . status , Json ( body) ) . into_response ( ) ;
106+ for ( name, value) in & self . headers {
107+ if let ( Ok ( n) , Ok ( v) ) = (
108+ axum:: http:: header:: HeaderName :: from_bytes ( name. as_bytes ( ) ) ,
109+ HeaderValue :: from_str ( value) ,
110+ ) {
111+ response. headers_mut ( ) . insert ( n, v) ;
112+ }
113+ }
114+ response
52115 }
53116}
54117
@@ -57,30 +120,42 @@ impl ApiError {
57120 Self {
58121 status : StatusCode :: UNAUTHORIZED ,
59122 message : msg. to_string ( ) ,
123+ headers : vec ! [ ] ,
60124 }
61125 }
62126 pub fn forbidden ( msg : & str ) -> Self {
63127 Self {
64128 status : StatusCode :: FORBIDDEN ,
65129 message : msg. to_string ( ) ,
130+ headers : vec ! [ ] ,
66131 }
67132 }
68133 pub fn bad_request ( msg : & str ) -> Self {
69134 Self {
70135 status : StatusCode :: BAD_REQUEST ,
71136 message : msg. to_string ( ) ,
137+ headers : vec ! [ ] ,
72138 }
73139 }
74140 pub fn not_found ( msg : & str ) -> Self {
75141 Self {
76142 status : StatusCode :: NOT_FOUND ,
77143 message : msg. to_string ( ) ,
144+ headers : vec ! [ ] ,
78145 }
79146 }
80147 pub fn internal ( msg : & str ) -> Self {
81148 Self {
82149 status : StatusCode :: INTERNAL_SERVER_ERROR ,
83150 message : msg. to_string ( ) ,
151+ headers : vec ! [ ] ,
152+ }
153+ }
154+ pub fn rate_limited ( retry_after : u64 ) -> Self {
155+ Self {
156+ status : StatusCode :: TOO_MANY_REQUESTS ,
157+ message : format ! ( "Rate limit exceeded. Retry after {retry_after}s" ) ,
158+ headers : vec ! [ ( "retry-after" . to_string( ) , retry_after. to_string( ) ) ] ,
84159 }
85160 }
86161}
@@ -102,13 +177,17 @@ pub fn check_permission(permission: &str, is_write: bool) -> bool {
102177 permission == "write"
103178}
104179
105- /// Extract and validate auth from request headers.
180+ /// Extract and validate auth from request headers, then check rate limit .
106181pub fn authorize (
107182 headers : & axum:: http:: HeaderMap ,
108183 state : & ApiState ,
109184 is_write : bool ,
110185) -> Result < ( ) , ApiError > {
111186 if state. no_auth {
187+ state
188+ . rate_limiter
189+ . check ( "no_auth" )
190+ . map_err ( ApiError :: rate_limited) ?;
112191 return Ok ( ( ) ) ;
113192 }
114193 let auth = headers
@@ -125,6 +204,10 @@ pub fn authorize(
125204 "Insufficient permissions: write access required" ,
126205 ) ) ;
127206 }
207+ state
208+ . rate_limiter
209+ . check ( key)
210+ . map_err ( ApiError :: rate_limited) ?;
128211 Ok ( ( ) )
129212}
130213
@@ -237,12 +320,38 @@ struct DeleteBody {
237320 mode : Option < String > ,
238321}
239322
323+ // ---------------------------------------------------------------------------
324+ // CORS
325+ // ---------------------------------------------------------------------------
326+
327+ fn cors_layer ( origins : & [ String ] ) -> CorsLayer {
328+ if origins. is_empty ( ) {
329+ return CorsLayer :: new ( ) ;
330+ }
331+ if origins. iter ( ) . any ( |o| o == "*" ) {
332+ return CorsLayer :: new ( )
333+ . allow_origin ( Any )
334+ . allow_methods ( Any )
335+ . allow_headers ( Any ) ;
336+ }
337+ let origins: Vec < HeaderValue > = origins
338+ . iter ( )
339+ . filter_map ( |o| o. parse ( ) . ok ( ) )
340+ . collect ( ) ;
341+ CorsLayer :: new ( )
342+ . allow_origin ( origins)
343+ . allow_methods ( [ Method :: GET , Method :: POST , Method :: OPTIONS ] )
344+ . allow_headers ( Any )
345+ . allow_credentials ( true )
346+ }
347+
240348// ---------------------------------------------------------------------------
241349// Router
242350// ---------------------------------------------------------------------------
243351
244352/// Build the axum router with all API endpoints.
245353pub fn build_router ( state : ApiState ) -> Router {
354+ let cors = cors_layer ( & state. http_config . cors_origins ) ;
246355 Router :: new ( )
247356 . route ( "/api/health-check" , get ( health_check) )
248357 . route ( "/api/search" , post ( handle_search) )
@@ -265,6 +374,7 @@ pub fn build_router(state: ApiState) -> Router {
265374 . route ( "/api/unarchive" , post ( handle_unarchive) )
266375 . route ( "/api/update-metadata" , post ( handle_update_metadata) )
267376 . route ( "/api/delete" , post ( handle_delete) )
377+ . layer ( cors)
268378 . with_state ( state)
269379}
270380
@@ -746,7 +856,6 @@ async fn handle_delete(
746856#[ cfg( test) ]
747857mod tests {
748858 use super :: * ;
749- use std:: collections:: HashMap ;
750859 use std:: path:: PathBuf ;
751860 use std:: time:: SystemTime ;
752861
@@ -792,16 +901,19 @@ mod tests {
792901
793902 fn test_api_state ( ) -> ApiState {
794903 let store = Store :: open_memory ( ) . expect ( "in-memory store" ) ;
904+ let config = test_http_config ( ) ;
905+ let rate_limiter = Arc :: new ( RateLimiter :: new ( config. rate_limit ) ) ;
795906 ApiState {
796907 store : Arc :: new ( Mutex :: new ( store) ) ,
797908 embedder : Arc :: new ( Mutex :: new ( Box :: new ( DummyEmbedder ) as Box < dyn EmbedModel + Send > ) ) ,
798909 vault_path : Arc :: new ( PathBuf :: from ( "/tmp/test-vault" ) ) ,
799910 profile : Arc :: new ( None ) ,
800911 orchestrator : None ,
801912 reranker : None ,
802- http_config : Arc :: new ( test_http_config ( ) ) ,
913+ http_config : Arc :: new ( config ) ,
803914 no_auth : false ,
804915 recent_writes : Arc :: new ( Mutex :: new ( HashMap :: < PathBuf , SystemTime > :: new ( ) ) ) ,
916+ rate_limiter,
805917 }
806918 }
807919
@@ -1029,4 +1141,73 @@ mod tests {
10291141 // Should be 500 (file not found via store) but NOT 403
10301142 assert_ne ! ( response. status( ) , StatusCode :: FORBIDDEN ) ;
10311143 }
1144+
1145+ // -----------------------------------------------------------------------
1146+ // Rate limiter unit tests
1147+ // -----------------------------------------------------------------------
1148+
1149+ #[ test]
1150+ fn test_rate_limiter_allows_under_limit ( ) {
1151+ let limiter = RateLimiter :: new ( 5 ) ;
1152+ for _ in 0 ..5 {
1153+ assert ! ( limiter. check( "key1" ) . is_ok( ) ) ;
1154+ }
1155+ }
1156+
1157+ #[ test]
1158+ fn test_rate_limiter_rejects_over_limit ( ) {
1159+ let limiter = RateLimiter :: new ( 2 ) ;
1160+ assert ! ( limiter. check( "key1" ) . is_ok( ) ) ;
1161+ assert ! ( limiter. check( "key1" ) . is_ok( ) ) ;
1162+ assert ! ( limiter. check( "key1" ) . is_err( ) ) ;
1163+ }
1164+
1165+ #[ test]
1166+ fn test_rate_limiter_unlimited ( ) {
1167+ let limiter = RateLimiter :: new ( 0 ) ;
1168+ for _ in 0 ..1000 {
1169+ assert ! ( limiter. check( "key1" ) . is_ok( ) ) ;
1170+ }
1171+ }
1172+
1173+ #[ test]
1174+ fn test_rate_limiter_separate_keys ( ) {
1175+ let limiter = RateLimiter :: new ( 1 ) ;
1176+ assert ! ( limiter. check( "key1" ) . is_ok( ) ) ;
1177+ assert ! ( limiter. check( "key2" ) . is_ok( ) ) ; // different key, separate bucket
1178+ assert ! ( limiter. check( "key1" ) . is_err( ) ) ; // key1 exhausted
1179+ }
1180+
1181+ #[ tokio:: test]
1182+ async fn test_rate_limit_returns_429 ( ) {
1183+ let mut state = test_api_state ( ) ;
1184+ state. rate_limiter = Arc :: new ( RateLimiter :: new ( 1 ) ) ;
1185+ let app = build_router ( state) ;
1186+ // First request passes (consumes the single token)
1187+ let response = app
1188+ . clone ( )
1189+ . oneshot (
1190+ axum:: http:: Request :: builder ( )
1191+ . uri ( "/api/vault-map" )
1192+ . header ( "authorization" , "Bearer eg_readkey" )
1193+ . body ( Body :: empty ( ) )
1194+ . unwrap ( ) ,
1195+ )
1196+ . await
1197+ . unwrap ( ) ;
1198+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
1199+ // Second request gets 429
1200+ let response = app
1201+ . oneshot (
1202+ axum:: http:: Request :: builder ( )
1203+ . uri ( "/api/vault-map" )
1204+ . header ( "authorization" , "Bearer eg_readkey" )
1205+ . body ( Body :: empty ( ) )
1206+ . unwrap ( ) ,
1207+ )
1208+ . await
1209+ . unwrap ( ) ;
1210+ assert_eq ! ( response. status( ) , StatusCode :: TOO_MANY_REQUESTS ) ;
1211+ assert ! ( response. headers( ) . get( "retry-after" ) . is_some( ) ) ;
1212+ }
10321213}
0 commit comments