Skip to content

Commit 0635c76

Browse files
authored
Merge pull request #981 from rd4398/wrap-version-map
feat(resolver): add VersionMapProvider for custom version resolution
2 parents 6011d21 + a5fafd9 commit 0635c76

4 files changed

Lines changed: 174 additions & 1 deletion

File tree

src/fromager/resolver.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .http_retry import RETRYABLE_EXCEPTIONS, retry_on_exception
3838
from .request_session import session
3939
from .requirements_file import RequirementType
40+
from .versionmap import VersionMap
4041

4142
if typing.TYPE_CHECKING:
4243
from . import context
@@ -109,7 +110,13 @@ def default_resolver_provider(
109110
include_wheels: bool,
110111
req_type: RequirementType | None = None,
111112
ignore_platform: bool = False,
112-
) -> PyPIProvider | GenericProvider | GitHubTagProvider:
113+
) -> (
114+
PyPIProvider
115+
| GenericProvider
116+
| GitHubTagProvider
117+
| GitLabTagProvider
118+
| VersionMapProvider
119+
):
113120
"""Lookup resolver provider to resolve package versions"""
114121
return PyPIProvider(
115122
include_sdists=include_sdists,
@@ -951,3 +958,55 @@ def _find_tags(
951958

952959
# GitLab API uses Link headers for pagination
953960
nexturl = resp.links.get("next", {}).get("url")
961+
962+
963+
class VersionMapProvider(BaseProvider):
964+
"""Lookup package versions from a VersionMap
965+
966+
This provider wraps a VersionMap instance to provide versions and URLs
967+
for package resolution. The VersionMap should contain Version keys mapped
968+
to URL strings.
969+
"""
970+
971+
provider_description: typing.ClassVar[str] = (
972+
"VersionMap resolver (package: {self.package_name})"
973+
)
974+
975+
def __init__(
976+
self,
977+
version_map: VersionMap,
978+
package_name: str,
979+
constraints: Constraints | None = None,
980+
*,
981+
req_type: RequirementType | None = None,
982+
use_resolver_cache: bool = True,
983+
) -> None:
984+
super().__init__(
985+
constraints=constraints,
986+
req_type=req_type,
987+
use_resolver_cache=use_resolver_cache,
988+
)
989+
self.version_map = version_map
990+
self.package_name = package_name
991+
992+
@property
993+
def cache_key(self) -> str:
994+
return f"versionmap:{self.package_name}"
995+
996+
def find_candidates(self, identifier: str) -> Candidates:
997+
"""Find candidates from the VersionMap
998+
999+
Iterates through all versions in the VersionMap and creates Candidate
1000+
objects with the associated URLs.
1001+
"""
1002+
candidates: list[Candidate] = []
1003+
for version in self.version_map.versions():
1004+
url = self.version_map[version]
1005+
candidate = Candidate(
1006+
name=identifier,
1007+
version=version,
1008+
url=url,
1009+
)
1010+
candidates.append(candidate)
1011+
1012+
return candidates

src/fromager/versionmap.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def add(self, key: Version | str, value: typing.Any) -> None:
3030
key = Version(key)
3131
self._content[key] = value
3232

33+
def __getitem__(self, key: Version | str) -> typing.Any:
34+
"""Get the value associated with a version
35+
36+
String keys are converted to Version instances. Raises KeyError if the
37+
version is not found.
38+
"""
39+
if not isinstance(key, Version):
40+
key = Version(key)
41+
return self._content[key]
42+
3343
def versions(self) -> typing.Iterable[Version]:
3444
"""Return the known versions, sorted in descending order."""
3545
return reversed(sorted(self._content.keys()))

tests/test_resolver.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,87 @@ def _versions(*args: typing.Any, **kwds: typing.Any) -> list[tuple[str, str]]:
792792
assert provider.cache_key
793793

794794

795+
def test_resolve_versionmap() -> None:
796+
from fromager.versionmap import VersionMap
797+
798+
version_map = VersionMap(
799+
{
800+
"1.2": "https://example.com/pkg-1.2.tar.gz",
801+
"1.3": "https://example.com/pkg-1.3.tar.gz",
802+
"1.4.1": "https://example.com/pkg-1.4.1.tar.gz",
803+
}
804+
)
805+
806+
provider = resolver.VersionMapProvider(
807+
version_map=version_map, package_name="testpkg"
808+
)
809+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
810+
rslvr = resolvelib.Resolver(provider, reporter)
811+
812+
result = rslvr.resolve([Requirement("testpkg")])
813+
assert "testpkg" in result.mapping
814+
815+
candidate = result.mapping["testpkg"]
816+
assert str(candidate.version) == "1.4.1"
817+
assert candidate.url == "https://example.com/pkg-1.4.1.tar.gz"
818+
819+
# VersionMapProvider uses resolver cache by default
820+
cache = resolver.BaseProvider.resolver_cache
821+
assert "testpkg" in cache
822+
cached_candidates = cache["testpkg"][
823+
(resolver.VersionMapProvider, "versionmap:testpkg")
824+
]
825+
assert len(cached_candidates) == 3
826+
827+
828+
def test_resolve_versionmap_with_constraint() -> None:
829+
from fromager.versionmap import VersionMap
830+
831+
version_map = VersionMap(
832+
{
833+
"1.2": "https://example.com/pkg-1.2.tar.gz",
834+
"1.3": "https://example.com/pkg-1.3.tar.gz",
835+
"1.4.1": "https://example.com/pkg-1.4.1.tar.gz",
836+
}
837+
)
838+
839+
c = constraints.Constraints()
840+
c.add_constraint("testpkg<1.4")
841+
842+
provider = resolver.VersionMapProvider(
843+
version_map=version_map, package_name="testpkg", constraints=c
844+
)
845+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
846+
rslvr = resolvelib.Resolver(provider, reporter)
847+
848+
result = rslvr.resolve([Requirement("testpkg")])
849+
assert "testpkg" in result.mapping
850+
851+
candidate = result.mapping["testpkg"]
852+
assert str(candidate.version) == "1.3"
853+
assert candidate.url == "https://example.com/pkg-1.3.tar.gz"
854+
855+
856+
def test_resolve_versionmap_no_match() -> None:
857+
from fromager.versionmap import VersionMap
858+
859+
version_map = VersionMap(
860+
{
861+
"1.2": "https://example.com/pkg-1.2.tar.gz",
862+
"1.3": "https://example.com/pkg-1.3.tar.gz",
863+
}
864+
)
865+
866+
provider = resolver.VersionMapProvider(
867+
version_map=version_map, package_name="testpkg"
868+
)
869+
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
870+
rslvr = resolvelib.Resolver(provider, reporter)
871+
872+
with pytest.raises(resolvelib.resolvers.ResolverException):
873+
rslvr.resolve([Requirement("testpkg>=2.0")])
874+
875+
795876
_gitlab_submodlib_repo_response = """
796877
[
797878
{

tests/test_versionmap.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,26 @@ def test_no_match() -> None:
100100
m.lookup(Requirement("pkg"), Requirement("pkg<1.0"))
101101
with pytest.raises(ValueError):
102102
m.lookup(Requirement("pkg>1.0"), Requirement("pkg<1.0"))
103+
104+
105+
def test_getitem() -> None:
106+
m = VersionMap(
107+
{
108+
"1.2": "value for 1.2",
109+
Version("1.3"): "value for 1.3",
110+
"1.0": "value for 1.0",
111+
}
112+
)
113+
# Access by Version object
114+
assert m[Version("1.2")] == "value for 1.2"
115+
assert m[Version("1.3")] == "value for 1.3"
116+
117+
# Access by string (auto-converted to Version)
118+
assert m["1.2"] == "value for 1.2"
119+
assert m["1.0"] == "value for 1.0"
120+
121+
# Non-existent version raises KeyError
122+
with pytest.raises(KeyError):
123+
m[Version("2.0")]
124+
with pytest.raises(KeyError):
125+
m["2.0"]

0 commit comments

Comments
 (0)