1212from io import BufferedReader
1313from urllib .parse import urlencode , urljoin
1414
15- import requests
16- import urllib3
15+ import aiofiles
16+ import aiofiles .os
17+ import aiohttp
1718from multidict import CIMultiDict , CIMultiDictProxy , MutableMultiMapping
1819
1920from pulp_glue .common import __version__
@@ -136,38 +137,13 @@ def __init__(
136137 if cid :
137138 self ._headers ["Correlation-Id" ] = cid
138139
139- self ._setup_session ()
140-
141140 self ._oauth2_lock = asyncio .Lock ()
142141 self ._oauth2_token : str | None = None
143142 self ._oauth2_expires : datetime = datetime .now ()
144143
145144 self ._patch_api_hook : t .Callable [[t .Any ], t .Any ] = patch_api_hook or (lambda data : data )
146145 self .load_api (refresh_cache = refresh_cache )
147146
148- def _setup_session (self ) -> None :
149- # This is specific requests library.
150-
151- if self ._verify_ssl is False :
152- urllib3 .disable_warnings (urllib3 .exceptions .InsecureRequestWarning )
153-
154- self ._session : requests .Session = requests .session ()
155- # Don't redirect, because carrying auth accross redirects is unsafe.
156- self ._session .max_redirects = 0
157- self ._session .headers .update (self ._headers )
158- session_settings = self ._session .merge_environment_settings (
159- self ._base_url , {}, None , self ._verify_ssl , None
160- )
161- self ._session .verify = session_settings ["verify" ]
162- self ._session .proxies = session_settings ["proxies" ]
163-
164- if self ._auth_provider is not None and self ._auth_provider .can_complete_mutualTLS ():
165- cert , key = self ._auth_provider .tls_credentials ()
166- if key is not None :
167- self ._session .cert = (cert , key )
168- else :
169- self ._session .cert = cert
170-
171147 @property
172148 def base_url (self ) -> str :
173149 return self ._base_url
@@ -191,7 +167,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
191167 return _ssl_context
192168
193169 def load_api (self , refresh_cache : bool = False ) -> None :
194- # TODO: Find a way to invalidate caches on upstream change
170+ asyncio .run (self ._load_api (refresh_cache = refresh_cache ))
171+
172+ async def _load_api (self , refresh_cache : bool = False ) -> None :
173+ # TODO: Find a way to invalidate caches on upstream change.
195174 xdg_cache_home : str = os .environ .get ("XDG_CACHE_HOME" ) or "~/.cache"
196175 apidoc_cache : str = os .path .join (
197176 os .path .expanduser (xdg_cache_home ),
@@ -203,17 +182,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
203182 if refresh_cache :
204183 # Fake that we did not find the cache.
205184 raise OSError ()
206- with open (apidoc_cache , "rb" ) as f :
207- data : bytes = f .read ()
185+ async with aiofiles . open (apidoc_cache , mode = "rb" ) as f :
186+ data : bytes = await f .read ()
208187 self ._parse_api (data )
209188 except Exception :
210- # Try again with a freshly downloaded version
211- data = self ._download_api ()
189+ # Try again with a freshly downloaded version.
190+ data = await self ._download_api ()
212191 self ._parse_api (data )
213- # Write to cache as it seems to be valid
214- os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
215- with open (apidoc_cache , "bw" ) as f :
216- f .write (data )
192+ # Write to cache as it seems to be valid.
193+ await aiofiles . os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
194+ async with aiofiles . open (apidoc_cache , mode = "bw" ) as f :
195+ await f .write (data )
217196
218197 def _parse_api (self , data : bytes ) -> None :
219198 raw_spec = self ._patch_api_hook (json .loads (data ))
@@ -229,15 +208,18 @@ def _parse_api(self, data: bytes) -> None:
229208 if method in METHODS
230209 }
231210
232- def _download_api (self ) -> bytes :
233- try :
234- response : requests .Response = self ._session .get (urljoin (self ._base_url , self ._doc_path ))
235- except requests .RequestException as e :
236- raise OpenAPIError (str (e ))
237- response .raise_for_status ()
238- if "Correlation-Id" in response .headers :
239- self ._set_correlation_id (response .headers ["Correlation-Id" ])
240- return response .content
211+ async def _download_api (self ) -> bytes :
212+ response = await self ._send_request (
213+ _Request (
214+ operation_id = "" ,
215+ method = "get" ,
216+ url = urljoin (self ._base_url , self ._doc_path ),
217+ headers = self ._headers ,
218+ )
219+ )
220+ if response .status_code != 200 :
221+ raise OpenAPIError (_ ("Failed to find api docs." ))
222+ return response .body
241223
242224 def _set_correlation_id (self , correlation_id : str ) -> None :
243225 if "Correlation-Id" in self ._headers :
@@ -249,8 +231,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
249231 )
250232 else :
251233 self ._headers ["Correlation-Id" ] = correlation_id
252- # Do it for requests too...
253- self ._session .headers ["Correlation-Id" ] = correlation_id
254234
255235 def param_spec (
256236 self , operation_id : str , param_type : str , required : bool = False
@@ -467,7 +447,7 @@ def _render_request(
467447 security = security ,
468448 )
469449
470- def _log_request (self , request : _Request ) -> None :
450+ async def _log_request (self , request : _Request ) -> None :
471451 if request .params :
472452 qs = urlencode (request .params )
473453 self ._debug_callback (1 , f"{ request .operation_id } : { request .method } { request .url } ?{ qs } " )
@@ -493,7 +473,6 @@ def _select_proposal(
493473 if (
494474 request .security
495475 and "Authorization" not in request .headers
496- and "Authorization" not in self ._session .headers
497476 and self ._auth_provider is not None
498477 ):
499478 security_schemes : dict [str , dict [str , t .Any ]] = self .api_spec ["components" ][
@@ -565,7 +544,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
565544 headers = {"Authorization" : f"Basic { secret .decode ()} " },
566545 data = data ,
567546 )
568- response = self ._send_request (request )
547+ response = await self ._send_request (request )
569548 if response .status_code < 200 or response .status_code >= 300 :
570549 raise OpenAPIError ("Failed to fetch OAuth2 token" )
571550 result = json .loads (response .body )
@@ -574,38 +553,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
574553 new_token = True
575554 return new_token
576555
577- def _send_request (
556+ async def _send_request (
578557 self ,
579558 request : _Request ,
580559 ) -> _Response :
581- # This function uses requests to translate the _Request into a _Response.
560+ # This function uses aiohttp to translate the _Request into a _Response.
561+ data : aiohttp .FormData | dict [str , t .Any ] | str | None
562+ if request .files :
563+ assert isinstance (request .data , dict )
564+ # Maybe assert on the content type header.
565+ data = aiohttp .FormData (default_to_multipart = True )
566+ for key , value in request .data .items ():
567+ data .add_field (key , encode_param (value ))
568+ for key , (name , value , content_type ) in request .files .items ():
569+ data .add_field (key , value , filename = name , content_type = content_type )
570+ else :
571+ data = request .data
582572 try :
583- r = self ._session .request (
584- request .method ,
585- request .url ,
586- params = request .params ,
587- headers = request .headers ,
588- data = request .data ,
589- files = request .files ,
590- )
591- response = _Response (status_code = r .status_code , headers = r .headers , body = r .content )
592- except requests .TooManyRedirects as e :
593- assert e .response is not None
573+ async with aiohttp .ClientSession () as session :
574+ async with session .request (
575+ request .method ,
576+ request .url ,
577+ params = request .params ,
578+ headers = request .headers ,
579+ data = data ,
580+ ssl = self .ssl_context ,
581+ max_redirects = 0 ,
582+ ) as r :
583+ response_body = await r .read ()
584+ response = _Response (
585+ status_code = r .status , headers = r .headers , body = response_body
586+ )
587+ except aiohttp .TooManyRedirects as e :
588+ # We could handle that in the middleware...
589+ assert e .history [- 1 ] is not None
594590 raise OpenAPIError (
595591 _ (
596592 "Received redirect to '{new_url} from {old_url}'."
597593 " Please check your configuration."
598594 ).format (
599- new_url = e .response .headers ["location" ],
595+ new_url = e .history [ - 1 ] .headers ["location" ],
600596 old_url = request .url ,
601597 )
602598 )
603- except requests . RequestException as e :
599+ except aiohttp . ClientResponseError as e :
604600 raise OpenAPIError (str (e ))
605601
606602 return response
607603
608- def _log_response (self , response : _Response ) -> None :
604+ async def _log_response (self , response : _Response ) -> None :
609605 self ._debug_callback (
610606 1 , _ ("Response: {status_code}" ).format (status_code = response .status_code )
611607 )
@@ -652,6 +648,22 @@ def call(
652648 parameters : dict [str , t .Any ] | None = None ,
653649 body : dict [str , t .Any ] | None = None ,
654650 validate_body : bool = True ,
651+ ) -> t .Any :
652+ return asyncio .run (
653+ self .async_call (
654+ operation_id = operation_id ,
655+ parameters = parameters ,
656+ body = body ,
657+ validate_body = validate_body ,
658+ )
659+ )
660+
661+ async def async_call (
662+ self ,
663+ operation_id : str ,
664+ parameters : dict [str , t .Any ] | None = None ,
665+ body : dict [str , t .Any ] | None = None ,
666+ validate_body : bool = True ,
655667 ) -> t .Any :
656668 """
657669 Make a call to the server.
@@ -706,37 +718,33 @@ def call(
706718 body ,
707719 validate_body = validate_body ,
708720 )
709- self ._log_request (request )
721+ await self ._log_request (request )
710722
711723 if self ._dry_run and request .method .upper () not in SAFE_METHODS :
712724 raise UnsafeCallError (_ ("Call aborted due to safe mode" ))
713725
714726 may_retry = False
715727 if proposal := self ._select_proposal (request ):
716728 assert len (proposal ) == 1 , "More complex security proposals are not implemented."
717- may_retry = asyncio . run ( self ._authenticate_request (request , proposal ) )
729+ may_retry = await self ._authenticate_request (request , proposal )
718730
719- response = self ._send_request (request )
731+ response = await self ._send_request (request )
720732
721733 if proposal is not None :
722734 assert self ._auth_provider is not None
723735 if may_retry and response .status_code == 401 :
724736 self ._oauth2_token = None
725- asyncio . run ( self ._authenticate_request (request , proposal ) )
726- response = self ._send_request (request )
737+ await self ._authenticate_request (request , proposal )
738+ response = await self ._send_request (request )
727739
728740 if response .status_code >= 200 and response .status_code < 300 :
729- asyncio .run (
730- self ._auth_provider .auth_success_hook (
731- proposal , self .api_spec ["components" ]["securitySchemes" ]
732- )
741+ await self ._auth_provider .auth_success_hook (
742+ proposal , self .api_spec ["components" ]["securitySchemes" ]
733743 )
734744 elif response .status_code == 401 :
735- asyncio .run (
736- self ._auth_provider .auth_failure_hook (
737- proposal , self .api_spec ["components" ]["securitySchemes" ]
738- )
745+ await self ._auth_provider .auth_failure_hook (
746+ proposal , self .api_spec ["components" ]["securitySchemes" ]
739747 )
740748
741- self ._log_response (response )
749+ await self ._log_response (response )
742750 return self ._parse_response (method_spec , response )
0 commit comments