|
2 | 2 | import itertools |
3 | 3 | import pathlib |
4 | 4 | import shutil |
| 5 | +import textwrap |
5 | 6 | import typing |
6 | 7 | from unittest.mock import Mock, patch |
7 | 8 |
|
8 | 9 | import pytest |
| 10 | +from packaging.metadata import Metadata |
9 | 11 | from packaging.requirements import Requirement |
| 12 | +from packaging.utils import NormalizedName |
| 13 | +from packaging.version import Version |
10 | 14 |
|
11 | 15 | from fromager import build_environment, context, dependencies |
12 | 16 |
|
@@ -206,3 +210,72 @@ def test_get_build_sdist_dependencies_cached( |
206 | 210 | build_env=build_env, |
207 | 211 | ) |
208 | 212 | assert results == set([Requirement("foo==1.0")]) |
| 213 | + |
| 214 | + |
| 215 | +@patch("fromager.dependencies.pep517_metadata_of_sdist") |
| 216 | +def test_default_get_install_dependencies_of_sdist( |
| 217 | + m_pep517_metadata_of_sdist: Mock, |
| 218 | + tmp_context: context.WorkContext, |
| 219 | + tmp_path: pathlib.Path, |
| 220 | +) -> None: |
| 221 | + req = Requirement("huggingface-hub") |
| 222 | + version = Version("1.2.3") |
| 223 | + # sdist metadata name may not be normalized |
| 224 | + metadata_txt = textwrap.dedent( |
| 225 | + """\ |
| 226 | + Metadata-Version: 2.3 |
| 227 | + Name: HuggingFace_Hub |
| 228 | + Version: 1.2.3 |
| 229 | + Requires-Dist: filelock |
| 230 | + Requires-Dist: requests |
| 231 | + """ |
| 232 | + ) |
| 233 | + metadata = Metadata.from_email(metadata_txt) |
| 234 | + m_pep517_metadata_of_sdist.return_value = metadata |
| 235 | + |
| 236 | + requirements = dependencies.default_get_install_dependencies_of_sdist( |
| 237 | + ctx=tmp_context, |
| 238 | + req=req, |
| 239 | + version=version, |
| 240 | + sdist_root_dir=tmp_path, |
| 241 | + build_env=Mock(), |
| 242 | + extra_environ={}, |
| 243 | + build_dir=tmp_path, |
| 244 | + config_settings={}, |
| 245 | + ) |
| 246 | + assert requirements == {Requirement("filelock"), Requirement("requests")} |
| 247 | + |
| 248 | + |
| 249 | +@pytest.mark.parametrize( |
| 250 | + "req_str,version_str,dist_name_str,dist_version_str,exc", |
| 251 | + [ |
| 252 | + ("mypkg", "1.0", "mypkg", "1.0", None), |
| 253 | + ("MyPKG", "1.0", "mypkg", "1.0", None), |
| 254 | + ("mypkg", "1.0", "MyPKG", "1.0", RuntimeError), |
| 255 | + ("mypkg", "1.0", "otherpkg", "1.0", ValueError), |
| 256 | + ("mypkg", "1.0", "mypkg", "1.1", ValueError), |
| 257 | + ("mypkg", "1.0+local", "mypkg", "1.0+local", None), |
| 258 | + ("mypkg", "1.0", "mypkg", "1.0+local", None), |
| 259 | + ("mypkg", "1.0+local", "mypkg", "1.0", None), |
| 260 | + ], |
| 261 | +) |
| 262 | +def test_validate_dist_name_version( |
| 263 | + req_str: str, |
| 264 | + version_str: str, |
| 265 | + dist_name_str: str, |
| 266 | + dist_version_str: str, |
| 267 | + exc: type[Exception] | None, |
| 268 | +) -> None: |
| 269 | + validate = functools.partial( |
| 270 | + dependencies.validate_dist_name_version, |
| 271 | + req=Requirement(req_str), |
| 272 | + version=Version(version_str), |
| 273 | + what="test", |
| 274 | + dist_name=typing.cast(NormalizedName, dist_name_str), |
| 275 | + dist_version=Version(dist_version_str), |
| 276 | + ) |
| 277 | + if exc is None: |
| 278 | + validate() |
| 279 | + else: |
| 280 | + with pytest.raises(exc): |
| 281 | + validate() |
0 commit comments