Skip to content

Commit 64aee9d

Browse files
authored
Merge pull request #784 from tiran/normalize-sdist-metadata
fix: normalize name in default_get_install_dependencies_of_sdist
2 parents 534d40d + bb9ee8e commit 64aee9d

2 files changed

Lines changed: 95 additions & 8 deletions

File tree

src/fromager/dependencies.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tomlkit
1313
from packaging.metadata import Metadata
1414
from packaging.requirements import Requirement
15-
from packaging.utils import canonicalize_name
15+
from packaging.utils import NormalizedName, canonicalize_name
1616
from packaging.version import Version
1717

1818
from . import (
@@ -318,7 +318,8 @@ def default_get_install_dependencies_of_sdist(
318318
req=req,
319319
version=version,
320320
what="sdist metadata",
321-
dist_name=metadata.name,
321+
# Metadata name is a non-normalized string
322+
dist_name=canonicalize_name(metadata.name),
322323
dist_version=metadata.version,
323324
)
324325
if not metadata.requires_dist:
@@ -371,17 +372,30 @@ def pep517_metadata_of_sdist(
371372

372373

373374
def validate_dist_name_version(
374-
req: Requirement, version: Version, what: str, dist_name: str, dist_version: Version
375+
req: Requirement,
376+
version: Version,
377+
what: str,
378+
dist_name: NormalizedName,
379+
dist_version: Version,
375380
) -> None:
376381
"""Validate that dist name and version matches expected values"""
377382
req_name = canonicalize_name(req.name)
378383
if dist_name != req_name:
379-
raise ValueError(f"{what} does not match requirement {req_name!r}")
380-
if dist_version != version:
381-
if dist_version.public != version.public:
382-
raise ValueError(f"{what} does not match public version {version!r}")
384+
if dist_name != canonicalize_name(dist_name):
385+
# API misuse
386+
raise RuntimeError("dist_name argument {dist_name!r} is not normalized")
383387
else:
384-
logger.warning(f"{what} has different local version than {version!r}")
388+
raise ValueError(
389+
f"{what} {dist_name!r} does not match requirement {req_name!r}"
390+
)
391+
if dist_version.public != version.public:
392+
raise ValueError(
393+
f"{what} {dist_version.public!r} does not match public version {version.public!r}"
394+
)
395+
if dist_version.local != version.local:
396+
logger.warning(
397+
f"{what} {dist_version.local!r} has different local version than {version.local!r}"
398+
)
385399

386400

387401
def get_install_dependencies_of_wheel(

tests/test_dependencies.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
import itertools
33
import pathlib
44
import shutil
5+
import textwrap
56
import typing
67
from unittest.mock import Mock, patch
78

89
import pytest
10+
from packaging.metadata import Metadata
911
from packaging.requirements import Requirement
12+
from packaging.utils import NormalizedName
13+
from packaging.version import Version
1014

1115
from fromager import build_environment, context, dependencies
1216

@@ -206,3 +210,72 @@ def test_get_build_sdist_dependencies_cached(
206210
build_env=build_env,
207211
)
208212
assert results == set([Requirement("foo==1.0")])
213+
214+
215+
@patch("fromager.dependencies.pep517_metadata_of_sdist")
216+
def test_default_get_install_dependencies_of_sdist(
217+
m_pep517_metadata_of_sdist: Mock,
218+
tmp_context: context.WorkContext,
219+
tmp_path: pathlib.Path,
220+
) -> None:
221+
req = Requirement("huggingface-hub")
222+
version = Version("1.2.3")
223+
# sdist metadata name may not be normalized
224+
metadata_txt = textwrap.dedent(
225+
"""\
226+
Metadata-Version: 2.3
227+
Name: HuggingFace_Hub
228+
Version: 1.2.3
229+
Requires-Dist: filelock
230+
Requires-Dist: requests
231+
"""
232+
)
233+
metadata = Metadata.from_email(metadata_txt)
234+
m_pep517_metadata_of_sdist.return_value = metadata
235+
236+
requirements = dependencies.default_get_install_dependencies_of_sdist(
237+
ctx=tmp_context,
238+
req=req,
239+
version=version,
240+
sdist_root_dir=tmp_path,
241+
build_env=Mock(),
242+
extra_environ={},
243+
build_dir=tmp_path,
244+
config_settings={},
245+
)
246+
assert requirements == {Requirement("filelock"), Requirement("requests")}
247+
248+
249+
@pytest.mark.parametrize(
250+
"req_str,version_str,dist_name_str,dist_version_str,exc",
251+
[
252+
("mypkg", "1.0", "mypkg", "1.0", None),
253+
("MyPKG", "1.0", "mypkg", "1.0", None),
254+
("mypkg", "1.0", "MyPKG", "1.0", RuntimeError),
255+
("mypkg", "1.0", "otherpkg", "1.0", ValueError),
256+
("mypkg", "1.0", "mypkg", "1.1", ValueError),
257+
("mypkg", "1.0+local", "mypkg", "1.0+local", None),
258+
("mypkg", "1.0", "mypkg", "1.0+local", None),
259+
("mypkg", "1.0+local", "mypkg", "1.0", None),
260+
],
261+
)
262+
def test_validate_dist_name_version(
263+
req_str: str,
264+
version_str: str,
265+
dist_name_str: str,
266+
dist_version_str: str,
267+
exc: type[Exception] | None,
268+
) -> None:
269+
validate = functools.partial(
270+
dependencies.validate_dist_name_version,
271+
req=Requirement(req_str),
272+
version=Version(version_str),
273+
what="test",
274+
dist_name=typing.cast(NormalizedName, dist_name_str),
275+
dist_version=Version(dist_version_str),
276+
)
277+
if exc is None:
278+
validate()
279+
else:
280+
with pytest.raises(exc):
281+
validate()

0 commit comments

Comments
 (0)