Skip to content

Commit 3c37b80

Browse files
committed
WIP: Async support in pulp-glue
Replaces requests with aiohttp and changes the api.
1 parent 89ef700 commit 3c37b80

File tree

9 files changed

+118
-98
lines changed

9 files changed

+118
-98
lines changed

CHANGES/pulp-glue/+aiohttp.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WIP: Added async api to Pulp glue.

CHANGES/pulp-glue/+aiohttp.removal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Replaced requests with aiohttp.
2+
Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library.

lint_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ mypy~=1.19.1
44
shellcheck-py~=0.11.0.1
55

66
# Type annotation stubs
7+
types-aiofiles
78
types-pygments
89
types-PyYAML
9-
types-requests
1010
types-setuptools
1111
types-toml
1212

lower_bounds_constraints.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
aiofiles==25.1.0
2+
aiohttp==3.12.0
13
click==8.0.0
24
packaging==22.0
35
PyYAML==5.3

pulp-glue/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ classifiers = [
2222
"Typing :: Typed",
2323
]
2424
dependencies = [
25+
"aiofiles>=25.1.0,<25.2",
26+
"aiohttp>=3.12.0,<3.14",
2527
"multidict>=6.0.5,<6.8",
2628
"packaging>=22.0,<=26.0", # CalVer
27-
"requests>=2.24.0,<2.33",
2829
"tomli>=2.0.0,<2.1;python_version<'3.11'",
2930
]
3031

pulp-glue/src/pulp_glue/common/openapi.py

Lines changed: 88 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from io import BufferedReader
1313
from urllib.parse import urlencode, urljoin
1414

15-
import requests
16-
import urllib3
15+
import aiofiles
16+
import aiofiles.os
17+
import aiohttp
1718
from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping
1819

1920
from pulp_glue.common import __version__
@@ -136,38 +137,13 @@ def __init__(
136137
if cid:
137138
self._headers["Correlation-Id"] = cid
138139

139-
self._setup_session()
140-
141140
self._oauth2_lock = asyncio.Lock()
142141
self._oauth2_token: str | None = None
143142
self._oauth2_expires: datetime = datetime.now()
144143

145144
self._patch_api_hook: t.Callable[[t.Any], t.Any] = patch_api_hook or (lambda data: data)
146145
self.load_api(refresh_cache=refresh_cache)
147146

148-
def _setup_session(self) -> None:
149-
# This is specific requests library.
150-
151-
if self._verify_ssl is False:
152-
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
153-
154-
self._session: requests.Session = requests.session()
155-
# Don't redirect, because carrying auth accross redirects is unsafe.
156-
self._session.max_redirects = 0
157-
self._session.headers.update(self._headers)
158-
session_settings = self._session.merge_environment_settings(
159-
self._base_url, {}, None, self._verify_ssl, None
160-
)
161-
self._session.verify = session_settings["verify"]
162-
self._session.proxies = session_settings["proxies"]
163-
164-
if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS():
165-
cert, key = self._auth_provider.tls_credentials()
166-
if key is not None:
167-
self._session.cert = (cert, key)
168-
else:
169-
self._session.cert = cert
170-
171147
@property
172148
def base_url(self) -> str:
173149
return self._base_url
@@ -191,7 +167,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
191167
return _ssl_context
192168

193169
def load_api(self, refresh_cache: bool = False) -> None:
194-
# TODO: Find a way to invalidate caches on upstream change
170+
asyncio.run(self._load_api(refresh_cache=refresh_cache))
171+
172+
async def _load_api(self, refresh_cache: bool = False) -> None:
173+
# TODO: Find a way to invalidate caches on upstream change.
195174
xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
196175
apidoc_cache: str = os.path.join(
197176
os.path.expanduser(xdg_cache_home),
@@ -203,17 +182,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
203182
if refresh_cache:
204183
# Fake that we did not find the cache.
205184
raise OSError()
206-
with open(apidoc_cache, "rb") as f:
207-
data: bytes = f.read()
185+
async with aiofiles.open(apidoc_cache, mode="rb") as f:
186+
data: bytes = await f.read()
208187
self._parse_api(data)
209188
except Exception:
210-
# Try again with a freshly downloaded version
211-
data = self._download_api()
189+
# Try again with a freshly downloaded version.
190+
data = await self._download_api()
212191
self._parse_api(data)
213-
# Write to cache as it seems to be valid
214-
os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
215-
with open(apidoc_cache, "bw") as f:
216-
f.write(data)
192+
# Write to cache as it seems to be valid.
193+
await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
194+
async with aiofiles.open(apidoc_cache, mode="bw") as f:
195+
await f.write(data)
217196

218197
def _parse_api(self, data: bytes) -> None:
219198
raw_spec = self._patch_api_hook(json.loads(data))
@@ -229,15 +208,18 @@ def _parse_api(self, data: bytes) -> None:
229208
if method in METHODS
230209
}
231210

232-
def _download_api(self) -> bytes:
233-
try:
234-
response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path))
235-
except requests.RequestException as e:
236-
raise OpenAPIError(str(e))
237-
response.raise_for_status()
238-
if "Correlation-Id" in response.headers:
239-
self._set_correlation_id(response.headers["Correlation-Id"])
240-
return response.content
211+
async def _download_api(self) -> bytes:
212+
response = await self._send_request(
213+
_Request(
214+
operation_id="",
215+
method="get",
216+
url=urljoin(self._base_url, self._doc_path),
217+
headers=self._headers,
218+
)
219+
)
220+
if response.status_code != 200:
221+
raise OpenAPIError(_("Failed to find api docs."))
222+
return response.body
241223

242224
def _set_correlation_id(self, correlation_id: str) -> None:
243225
if "Correlation-Id" in self._headers:
@@ -249,8 +231,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
249231
)
250232
else:
251233
self._headers["Correlation-Id"] = correlation_id
252-
# Do it for requests too...
253-
self._session.headers["Correlation-Id"] = correlation_id
254234

255235
def param_spec(
256236
self, operation_id: str, param_type: str, required: bool = False
@@ -467,7 +447,7 @@ def _render_request(
467447
security=security,
468448
)
469449

470-
def _log_request(self, request: _Request) -> None:
450+
async def _log_request(self, request: _Request) -> None:
471451
if request.params:
472452
qs = urlencode(request.params)
473453
self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}")
@@ -493,7 +473,6 @@ def _select_proposal(
493473
if (
494474
request.security
495475
and "Authorization" not in request.headers
496-
and "Authorization" not in self._session.headers
497476
and self._auth_provider is not None
498477
):
499478
security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][
@@ -565,7 +544,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
565544
headers={"Authorization": f"Basic {secret.decode()}"},
566545
data=data,
567546
)
568-
response = self._send_request(request)
547+
response = await self._send_request(request)
569548
if response.status_code < 200 or response.status_code >= 300:
570549
raise OpenAPIError("Failed to fetch OAuth2 token")
571550
result = json.loads(response.body)
@@ -574,38 +553,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
574553
new_token = True
575554
return new_token
576555

577-
def _send_request(
556+
async def _send_request(
578557
self,
579558
request: _Request,
580559
) -> _Response:
581-
# This function uses requests to translate the _Request into a _Response.
560+
# This function uses aiohttp to translate the _Request into a _Response.
561+
data: aiohttp.FormData | dict[str, t.Any] | str | None
562+
if request.files:
563+
assert isinstance(request.data, dict)
564+
# Maybe assert on the content type header.
565+
data = aiohttp.FormData(default_to_multipart=True)
566+
for key, value in request.data.items():
567+
data.add_field(key, encode_param(value))
568+
for key, (name, value, content_type) in request.files.items():
569+
data.add_field(key, value, filename=name, content_type=content_type)
570+
else:
571+
data = request.data
582572
try:
583-
r = self._session.request(
584-
request.method,
585-
request.url,
586-
params=request.params,
587-
headers=request.headers,
588-
data=request.data,
589-
files=request.files,
590-
)
591-
response = _Response(status_code=r.status_code, headers=r.headers, body=r.content)
592-
except requests.TooManyRedirects as e:
593-
assert e.response is not None
573+
async with aiohttp.ClientSession() as session:
574+
async with session.request(
575+
request.method,
576+
request.url,
577+
params=request.params,
578+
headers=request.headers,
579+
data=data,
580+
ssl=self.ssl_context,
581+
max_redirects=0,
582+
) as r:
583+
response_body = await r.read()
584+
response = _Response(
585+
status_code=r.status, headers=r.headers, body=response_body
586+
)
587+
except aiohttp.TooManyRedirects as e:
588+
# We could handle that in the middleware...
589+
assert e.history[-1] is not None
594590
raise OpenAPIError(
595591
_(
596592
"Received redirect to '{new_url} from {old_url}'."
597593
" Please check your configuration."
598594
).format(
599-
new_url=e.response.headers["location"],
595+
new_url=e.history[-1].headers["location"],
600596
old_url=request.url,
601597
)
602598
)
603-
except requests.RequestException as e:
599+
except aiohttp.ClientResponseError as e:
604600
raise OpenAPIError(str(e))
605601

606602
return response
607603

608-
def _log_response(self, response: _Response) -> None:
604+
async def _log_response(self, response: _Response) -> None:
609605
self._debug_callback(
610606
1, _("Response: {status_code}").format(status_code=response.status_code)
611607
)
@@ -652,6 +648,22 @@ def call(
652648
parameters: dict[str, t.Any] | None = None,
653649
body: dict[str, t.Any] | None = None,
654650
validate_body: bool = True,
651+
) -> t.Any:
652+
return asyncio.run(
653+
self.async_call(
654+
operation_id=operation_id,
655+
parameters=parameters,
656+
body=body,
657+
validate_body=validate_body,
658+
)
659+
)
660+
661+
async def async_call(
662+
self,
663+
operation_id: str,
664+
parameters: dict[str, t.Any] | None = None,
665+
body: dict[str, t.Any] | None = None,
666+
validate_body: bool = True,
655667
) -> t.Any:
656668
"""
657669
Make a call to the server.
@@ -706,37 +718,33 @@ def call(
706718
body,
707719
validate_body=validate_body,
708720
)
709-
self._log_request(request)
721+
await self._log_request(request)
710722

711723
if self._dry_run and request.method.upper() not in SAFE_METHODS:
712724
raise UnsafeCallError(_("Call aborted due to safe mode"))
713725

714726
may_retry = False
715727
if proposal := self._select_proposal(request):
716728
assert len(proposal) == 1, "More complex security proposals are not implemented."
717-
may_retry = asyncio.run(self._authenticate_request(request, proposal))
729+
may_retry = await self._authenticate_request(request, proposal)
718730

719-
response = self._send_request(request)
731+
response = await self._send_request(request)
720732

721733
if proposal is not None:
722734
assert self._auth_provider is not None
723735
if may_retry and response.status_code == 401:
724736
self._oauth2_token = None
725-
asyncio.run(self._authenticate_request(request, proposal))
726-
response = self._send_request(request)
737+
await self._authenticate_request(request, proposal)
738+
response = await self._send_request(request)
727739

728740
if response.status_code >= 200 and response.status_code < 300:
729-
asyncio.run(
730-
self._auth_provider.auth_success_hook(
731-
proposal, self.api_spec["components"]["securitySchemes"]
732-
)
741+
await self._auth_provider.auth_success_hook(
742+
proposal, self.api_spec["components"]["securitySchemes"]
733743
)
734744
elif response.status_code == 401:
735-
asyncio.run(
736-
self._auth_provider.auth_failure_hook(
737-
proposal, self.api_spec["components"]["securitySchemes"]
738-
)
745+
await self._auth_provider.auth_failure_hook(
746+
proposal, self.api_spec["components"]["securitySchemes"]
739747
)
740748

741-
self._log_response(response)
749+
await self._log_response(response)
742750
return self._parse_response(method_spec, response)

pulp-glue/tests/test_auth_provider.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,7 @@ def test_can_complete_basic(self, provider: AuthProviderBase) -> None:
6363
assert provider.can_complete_http_basic()
6464

6565
def test_provides_username_and_password(self, provider: AuthProviderBase) -> None:
66-
assert asyncio.run(provider.http_basic_credentials()) == (
67-
b"user1",
68-
b"password1",
69-
)
66+
assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1")
7067

7168
def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None:
7269
assert not provider.can_complete_mutualTLS()
@@ -104,10 +101,7 @@ def test_client_id_needs_client_secret(self) -> None:
104101
def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None:
105102
provider = GlueAuthProvider(client_id="client1", client_secret="secret1")
106103
assert provider.can_complete_oauth2_client_credentials([]) is True
107-
assert asyncio.run(provider.oauth2_client_credentials()) == (
108-
b"client1",
109-
b"secret1",
110-
)
104+
assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1")
111105

112106
def test_can_complete_mutualTLS_and_provide_cert(self) -> None:
113107
provider = GlueAuthProvider(cert="FAKECERTIFICATE")

pulp-glue/tests/test_openapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
).encode()
9999

100100

101-
def mock_send_request(request: _Request) -> _Response:
101+
async def mock_send_request(request: _Request) -> _Response:
102102
if request.url.endswith("oauth/token"):
103103
assert request.method.lower() == "post"
104104
# $ echo -n "client1:secret1" | base64

0 commit comments

Comments
 (0)