1- import asyncio
21import json
32import logging
4- from queue import Queue
5- import threading
6- import time
73from http import HTTPStatus
8- from http .server import HTTPServer , SimpleHTTPRequestHandler
9- from typing import Type , Optional
10- from unittest import TestCase
11- from urllib .parse import urlparse , parse_qs , ParseResult
4+ from http .server import SimpleHTTPRequestHandler
5+ from typing import Optional
6+ from urllib .parse import ParseResult , parse_qs , urlparse
127
138INVALID_AUTH = json .dumps (
149 {
@@ -93,7 +88,6 @@ class MockHandler(SimpleHTTPRequestHandler):
9388 protocol_version = "HTTP/1.1"
9489 default_request_version = "HTTP/1.1"
9590 logger = logging .getLogger (__name__ )
96- received_requests = {}
9791
9892 def is_valid_token (self ):
9993 return "Authorization" in self .headers and str (self .headers ["Authorization" ]).startswith ("Bearer xoxb-" )
@@ -109,8 +103,8 @@ def set_common_headers(self, content_length: int = 0):
109103 def _handle (self ):
110104 parsed_path : ParseResult = urlparse (self .path )
111105 path = parsed_path .path
112- self . server . queue . put ( path )
113- self .received_requests [ path ] = self . received_requests . get (path , 0 ) + 1
106+ # put_nowait is common between Queue & asyncio.Queue, it does not need to be awaited
107+ self .server . queue . put_nowait (path )
114108 try :
115109 if path == "/webhook" :
116110 self .send_response (200 )
@@ -208,95 +202,3 @@ def _parse_request_body(self, parsed_path: str, content_len: int) -> Optional[di
208202 if parsed_path and parsed_path .query :
209203 request_body = {k : v [0 ] for k , v in parse_qs (parsed_path .query ).items ()}
210204 return request_body
211-
212-
213- class MockServerThread (threading .Thread ):
214- def __init__ (self , queue : Queue , test : TestCase , handler : Type [SimpleHTTPRequestHandler ] = MockHandler ):
215- threading .Thread .__init__ (self )
216- self .handler = handler
217- self .test = test
218- self .queue = queue
219-
220- def run (self ):
221- self .server = HTTPServer (("localhost" , 8888 ), self .handler )
222- self .server .queue = self .queue
223- self .test .mock_received_requests = self .handler .received_requests
224- self .test .server_url = "http://localhost:8888"
225- self .test .host , self .test .port = self .server .socket .getsockname ()
226- self .test .server_started .set () # threading.Event()
227-
228- self .test = None
229- try :
230- self .server .serve_forever (0.05 )
231- finally :
232- self .server .server_close ()
233-
234- def stop (self ):
235- self .handler .received_requests = {}
236- with self .server .queue .mutex :
237- del self .server .queue
238- self .server .shutdown ()
239- self .join ()
240-
241-
242- class ReceivedRequests :
243- def __init__ (self , queue : Queue ):
244- self .queue = queue
245- self .received_requests = {}
246-
247- def get (self , key : str , default : Optional [int ] = None ) -> Optional [int ]:
248- while not self .queue .empty ():
249- path = self .queue .get ()
250- self .received_requests [path ] = self .received_requests .get (path , 0 ) + 1
251- return self .received_requests .get (key , default )
252-
253-
254- def setup_mock_web_api_server (test : TestCase ):
255- test .server_started = threading .Event ()
256- test .received_requests = ReceivedRequests (Queue ())
257- test .thread = MockServerThread (test .received_requests .queue , test )
258- test .thread .start ()
259- test .server_started .wait ()
260-
261-
262- def cleanup_mock_web_api_server (test : TestCase ):
263- test .thread .stop ()
264- test .thread = None
265-
266-
267- def assert_received_request_count (test : TestCase , path : str , min_count : int , timeout : float = 1 ):
268- start_time = time .time ()
269- error = None
270- while time .time () - start_time < timeout :
271- try :
272- assert test .received_requests .get (path , 0 ) == min_count
273- return
274- except Exception as e :
275- error = e
276- # waiting for some requests to be received
277- time .sleep (0.05 )
278-
279- if error is not None :
280- raise error
281-
282-
283- def assert_auth_test_count (test : TestCase , expected_count : int ):
284- assert_received_request_count (test , "/auth.test" , expected_count , 0.5 )
285-
286-
287- async def assert_auth_test_count_async (test : TestCase , expected_count : int ):
288- await asyncio .sleep (0.1 )
289- retry_count = 0
290- error = None
291- while retry_count < 3 :
292- try :
293- test .mock_received_requests .get ("/auth.test" , 0 ) == expected_count
294- break
295- except Exception as e :
296- error = e
297- retry_count += 1
298- # waiting for mock_received_requests updates
299- await asyncio .sleep (0.1 )
300-
301- if error is not None :
302- raise error
0 commit comments