Skip to content

Commit 5b3c27e

Browse files
Merge pull request #991 from tiran/override-resolver-url
feat: add override_download_url to resolvers
2 parents 5f5519a + 35798ae commit 5b3c27e

File tree

2 files changed

+150
-15
lines changed

2 files changed

+150
-15
lines changed

src/fromager/resolver.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections.abc import Iterable
1515
from operator import attrgetter
1616
from platform import python_version
17-
from urllib.parse import quote, unquote, urljoin, urlparse
17+
from urllib.parse import quote, unquote, urlparse
1818

1919
import pypi_simple
2020
import resolvelib
@@ -192,6 +192,8 @@ def get_project_from_pypi(
192192
extras: typing.Iterable[str],
193193
sdist_server_url: str,
194194
ignore_platform: bool = False,
195+
*,
196+
override_download_url: str | None = None,
195197
) -> Candidates:
196198
"""Return candidates created from the project name and extras."""
197199
found_candidates: set[str] = set()
@@ -352,14 +354,19 @@ def get_project_from_pypi(
352354
ignored_candidates.add(dp.filename)
353355
continue
354356

357+
if override_download_url is None:
358+
url = dp.url
359+
else:
360+
url = override_download_url.format(version=version)
361+
355362
upload_time = dp.upload_time
356363
if upload_time is not None:
357364
upload_time = upload_time.astimezone(datetime.UTC)
358365

359366
c = Candidate(
360367
name=name,
361368
version=version,
362-
url=dp.url,
369+
url=url,
363370
extras=tuple(sorted(extras)),
364371
is_sdist=is_sdist,
365372
build_tag=build_tag,
@@ -592,7 +599,11 @@ def find_matches(
592599

593600

594601
class PyPIProvider(BaseProvider):
595-
"""Lookup package and versions from a simple Python index (PyPI)"""
602+
"""Lookup package and versions from a simple Python index (PyPI)
603+
604+
The ``override_download_url`` parameter supports the string template variable:
605+
* version (Version object)
606+
"""
596607

597608
provider_description: typing.ClassVar[str] = (
598609
"PyPI resolver (searching at {self.sdist_server_url})"
@@ -608,6 +619,7 @@ def __init__(
608619
ignore_platform: bool = False,
609620
*,
610621
use_resolver_cache: bool = True,
622+
override_download_url: str | None = None,
611623
):
612624
super().__init__(
613625
constraints=constraints,
@@ -618,21 +630,25 @@ def __init__(
618630
self.include_wheels = include_wheels
619631
self.sdist_server_url = sdist_server_url
620632
self.ignore_platform = ignore_platform
633+
self.override_download_url = override_download_url
621634

622635
@property
623636
def cache_key(self) -> str:
624637
# ignore platform parameter changes behavior of find_candidates()
638+
key = self.sdist_server_url
639+
if self.override_download_url is not None:
640+
key = f"{key}+{self.override_download_url}"
625641
if self.ignore_platform:
626-
return f"{self.sdist_server_url}+ignore_platform"
627-
else:
628-
return self.sdist_server_url
642+
key = f"{key}+ignore_platform"
643+
return key
629644

630645
def find_candidates(self, identifier: str) -> Candidates:
631646
return get_project_from_pypi(
632647
identifier,
633-
set(),
634-
self.sdist_server_url,
635-
self.ignore_platform,
648+
extras=set(),
649+
sdist_server_url=self.sdist_server_url,
650+
ignore_platform=self.ignore_platform,
651+
override_download_url=self.override_download_url,
636652
)
637653

638654
def validate_candidate(
@@ -791,6 +807,12 @@ class GitHubTagProvider(GenericProvider):
791807
"""Lookup tarball and version from GitHub git tags
792808
793809
Assumes that upstream uses version tags `1.2.3` or `v1.2.3`.
810+
811+
The ``override_download_url`` parameter supports the string template variable:
812+
* organization
813+
* repo
814+
* tagname
815+
* version (Version object)
794816
"""
795817

796818
provider_description: typing.ClassVar[str] = (
@@ -808,6 +830,7 @@ def __init__(
808830
*,
809831
req_type: RequirementType | None = None,
810832
use_resolver_cache: bool = True,
833+
override_download_url: str | None = None,
811834
):
812835
super().__init__(
813836
constraints=constraints,
@@ -818,10 +841,14 @@ def __init__(
818841
)
819842
self.organization = organization
820843
self.repo = repo
844+
self.override_download_url = override_download_url
821845

822846
@property
823847
def cache_key(self) -> str:
824-
return f"{self.organization}/{self.repo}"
848+
key = f"{self.organization}/{self.repo}"
849+
if self.override_download_url is not None:
850+
key = f"{key}+{self.override_download_url}"
851+
return key
825852

826853
@retry_on_exception(
827854
exceptions=RETRYABLE_EXCEPTIONS,
@@ -852,7 +879,16 @@ def _find_tags(
852879
logger.debug(f"{identifier}: match function ignores {tagname}")
853880
continue
854881
assert isinstance(version, Version)
855-
url = entry["tarball_url"]
882+
883+
if self.override_download_url is None:
884+
url = entry["tarball_url"]
885+
else:
886+
url = self.override_download_url.format(
887+
organization=self.organization,
888+
repo=self.repo,
889+
tagname=tagname,
890+
version=version,
891+
)
856892

857893
# Github tag API endpoint does not include commit date information.
858894
# It would be too expensive to query every commit API endpoint.
@@ -870,7 +906,15 @@ def _find_tags(
870906

871907

872908
class GitLabTagProvider(GenericProvider):
873-
"""Lookup tarball and version from GitLab git tags"""
909+
"""Lookup tarball and version from GitLab git tags
910+
911+
The ``override_download_url`` parameter supports the string template variable:
912+
* hostname
913+
* project_path
914+
* project_name (last component of project_path)
915+
* tagname
916+
* version (Version object)
917+
"""
874918

875919
provider_description: typing.ClassVar[str] = (
876920
"GitLab tag resolver (project: {self.server_url}/{self.project_path})"
@@ -885,6 +929,7 @@ def __init__(
885929
*,
886930
req_type: RequirementType | None = None,
887931
use_resolver_cache: bool = True,
932+
override_download_url: str | None = None,
888933
) -> None:
889934
super().__init__(
890935
constraints=constraints,
@@ -894,6 +939,9 @@ def __init__(
894939
matcher=matcher,
895940
)
896941
self.server_url = server_url.rstrip("/")
942+
self.server_hostname = urlparse(server_url).hostname
943+
if not self.server_hostname:
944+
raise ValueError(f"invalid {server_url=}")
897945
self.project_path = project_path.lstrip("/")
898946
# URL-encode the project path as required by GitLab API.
899947
# The safe="" parameter tells quote() to encode ALL characters,
@@ -904,10 +952,14 @@ def __init__(
904952
self.api_url = (
905953
f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags"
906954
)
955+
self.override_download_url = override_download_url
907956

908957
@property
909958
def cache_key(self) -> str:
910-
return f"{self.server_url}/{self.project_path}"
959+
key = f"{self.server_url}/{self.project_path}"
960+
if self.override_download_url is not None:
961+
key = f"{key}+{self.override_download_url}"
962+
return key
911963

912964
@retry_on_exception(
913965
exceptions=RETRYABLE_EXCEPTIONS,
@@ -921,6 +973,14 @@ def _find_tags(
921973
) -> Iterable[Candidate]:
922974
nexturl: str = self.api_url
923975
created_at: datetime.datetime | None
976+
project_name = self.project_path.split("/")[-1]
977+
if self.override_download_url is None:
978+
download_template = (
979+
self.server_url
980+
+ "/{project_path}/-/archive/{tagname}/{project_name}-{tagname}.tar.gz"
981+
)
982+
else:
983+
download_template = self.override_download_url
924984
while nexturl:
925985
resp: Response = session.get(nexturl)
926986
resp.raise_for_status()
@@ -932,8 +992,13 @@ def _find_tags(
932992
continue
933993
assert isinstance(version, Version)
934994

935-
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
936-
url = urljoin(self.server_url, archive_path)
995+
url = download_template.format(
996+
hostname=self.server_hostname,
997+
project_path=self.project_path,
998+
project_name=project_name,
999+
tagname=tagname,
1000+
version=version,
1001+
)
9371002

9381003
# get tag creation time, fall back to commit creation time
9391004
created_at_str: str | None = entry.get("created_at")

tests/test_resolver.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,26 @@ def test_provider_constraint_match() -> None:
372372
assert str(candidate.version) == "1.2.2"
373373

374374

375+
def test_pypi_provider_override_download_url() -> None:
376+
with requests_mock.Mocker() as r:
377+
r.get(
378+
"https://pypi.org/simple/hydra-core/",
379+
text=_hydra_core_simple_response,
380+
)
381+
382+
provider = resolver.PyPIProvider(
383+
override_download_url="https://server.test/hydr_core-{version}.tar.gz"
384+
)
385+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
386+
rslvr = resolvelib.Resolver(provider, reporter)
387+
388+
result = rslvr.resolve([Requirement("hydra-core")])
389+
assert "hydra-core" in result.mapping
390+
391+
candidate = result.mapping["hydra-core"]
392+
assert candidate.url == "https://server.test/hydr_core-1.3.2.tar.gz"
393+
394+
375395
_ignore_platform_simple_response = """
376396
<!DOCTYPE html>
377397
<html>
@@ -717,6 +737,33 @@ def test_resolve_github() -> None:
717737
)
718738

719739

740+
def test_resolve_github_override_download_url() -> None:
741+
with requests_mock.Mocker() as r:
742+
r.get(
743+
"https://api.github.com:443/repos/python-wheel-build/fromager",
744+
text=_github_fromager_repo_response,
745+
)
746+
r.get(
747+
"https://api.github.com:443/repos/python-wheel-build/fromager/tags",
748+
text=_github_fromager_tag_response,
749+
)
750+
751+
provider = resolver.GitHubTagProvider(
752+
organization="python-wheel-build",
753+
repo="fromager",
754+
override_download_url="git+https://github.com/{organization}/{repo}.git@{tagname}",
755+
)
756+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
757+
rslvr = resolvelib.Resolver(provider, reporter)
758+
759+
result = rslvr.resolve([Requirement("fromager")])
760+
candidate = result.mapping["fromager"]
761+
assert (
762+
str(candidate.url)
763+
== "git+https://github.com/python-wheel-build/fromager.git@0.9.0"
764+
)
765+
766+
720767
def test_github_constraint_mismatch() -> None:
721768
constraint = constraints.Constraints()
722769
constraint.add_constraint("fromager>=1.0")
@@ -1005,6 +1052,29 @@ def test_resolve_gitlab() -> None:
10051052
)
10061053

10071054

1055+
def test_resolve_gitlab_override_download_url() -> None:
1056+
with requests_mock.Mocker() as r:
1057+
r.get(
1058+
"https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags",
1059+
text=_gitlab_submodlib_repo_response,
1060+
)
1061+
1062+
provider = resolver.GitLabTagProvider(
1063+
project_path="mirrors/github/decile-team/submodlib",
1064+
server_url="https://gitlab.com",
1065+
matcher=re.compile("v(.*)"), # with match object
1066+
override_download_url="git+https://{hostname}/{project_path}.git@{tagname}",
1067+
)
1068+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
1069+
rslvr = resolvelib.Resolver(provider, reporter)
1070+
result = rslvr.resolve([Requirement("submodlib")])
1071+
candidate = result.mapping["submodlib"]
1072+
assert (
1073+
str(candidate.url)
1074+
== "git+https://gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3"
1075+
)
1076+
1077+
10081078
def test_gitlab_constraint_mismatch() -> None:
10091079
constraint = constraints.Constraints()
10101080
constraint.add_constraint("submodlib>=1.0")

0 commit comments

Comments
 (0)