Skip to content

Commit f0b83d3

Browse files
committed
cli/sync(refactor[typing]): Type repo update payloads
why: Remove remaining Any usages while keeping repo sync behavior unchanged. what: - Add typed payload helpers for libvcs create_project calls - Tighten PrivatePath constructor argument types
1 parent 89f3f7b commit f0b83d3

2 files changed

Lines changed: 67 additions & 17 deletions

File tree

src/vcspull/_internal/private_path.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ class PrivatePath(PrivatePathBase):
3434
'~/notes.txt'
3535
"""
3636

37-
def __new__(cls, *args: t.Any, **kwargs: t.Any) -> PrivatePath:
37+
def __new__(
38+
cls,
39+
*args: str | os.PathLike[str],
40+
**kwargs: object,
41+
) -> PrivatePath:
3842
return super().__new__(cls, *args, **kwargs)
3943

4044
@classmethod

src/vcspull/cli/sync.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
import subprocess
1414
import sys
1515
import typing as t
16-
from collections.abc import Callable
1716
from copy import deepcopy
1817
from dataclasses import dataclass
1918
from datetime import datetime
2019
from io import StringIO
2120
from time import perf_counter
2221

22+
from libvcs._internal.run import ProgressCallbackProtocol
2323
from libvcs._internal.shortcuts import create_project
2424
from libvcs._internal.types import VCSLiteral
2525
from libvcs.sync.git import GitSync
@@ -48,7 +48,42 @@
4848

4949
log = logging.getLogger(__name__)
5050

51-
ProgressCallback = Callable[[str, datetime], None]
51+
ProgressCallback: t.TypeAlias = ProgressCallbackProtocol
52+
53+
54+
class RepoPayloadBase(t.TypedDict):
55+
"""Keyword arguments used to create a repo via libvcs."""
56+
57+
url: str
58+
path: str | os.PathLike[str]
59+
progress_callback: ProgressCallback | None
60+
61+
62+
class GitRepoPayload(RepoPayloadBase):
63+
"""Keyword arguments for git repositories."""
64+
65+
vcs: t.Literal["git"]
66+
67+
68+
class HgRepoPayload(RepoPayloadBase):
69+
"""Keyword arguments for Mercurial repositories."""
70+
71+
vcs: t.Literal["hg"]
72+
73+
74+
class SvnRepoPayload(RepoPayloadBase):
75+
"""Keyword arguments for Subversion repositories."""
76+
77+
vcs: t.Literal["svn"]
78+
79+
80+
class RepoPayload(t.TypedDict):
81+
"""Keyword arguments used to create a repo via libvcs."""
82+
83+
url: str
84+
path: str | os.PathLike[str]
85+
vcs: VCSLiteral | None
86+
progress_callback: ProgressCallback | None
5287

5388

5489
PLAN_SYMBOLS: dict[PlanAction, str] = {
@@ -836,28 +871,39 @@ def __init__(self, repo_url: str) -> None:
836871

837872

838873
def update_repo(
839-
repo_dict: t.Any,
874+
repo_dict: ConfigDict,
840875
progress_callback: ProgressCallback | None = None,
841876
# repo_dict: Dict[str, Union[str, Dict[str, GitRemote], pathlib.Path]]
842877
) -> GitSync:
843878
"""Synchronize a single repository."""
844-
repo_dict = deepcopy(repo_dict)
845-
if "pip_url" not in repo_dict:
846-
repo_dict["pip_url"] = repo_dict.pop("url")
847-
if "url" not in repo_dict:
848-
repo_dict["url"] = repo_dict.pop("pip_url")
879+
repo_payload = t.cast("dict[str, object]", deepcopy(repo_dict))
880+
if "pip_url" not in repo_payload:
881+
repo_payload["pip_url"] = repo_payload.pop("url")
882+
if "url" not in repo_payload:
883+
repo_payload["url"] = repo_payload.pop("pip_url")
884+
885+
repo_payload["progress_callback"] = progress_callback or progress_cb
886+
887+
repo_url = t.cast("str", repo_payload["url"])
888+
repo_vcs = t.cast("VCSLiteral | None", repo_payload.get("vcs"))
889+
if repo_vcs is None:
890+
vcs = guess_vcs(url=repo_url)
891+
if vcs is None:
892+
raise CouldNotGuessVCSFromURL(repo_url=repo_url)
849893

850-
repo_dict["progress_callback"] = progress_callback or progress_cb
894+
repo_payload["vcs"] = vcs
895+
repo_vcs = vcs
851896

852-
if repo_dict.get("vcs") is None:
853-
vcs = guess_vcs(url=repo_dict["url"])
854-
if vcs is None:
855-
raise CouldNotGuessVCSFromURL(repo_url=repo_dict["url"])
897+
assert repo_vcs is not None
856898

857-
repo_dict["vcs"] = vcs
899+
if repo_vcs == "git":
900+
r = create_project(**t.cast("GitRepoPayload", repo_payload))
901+
elif repo_vcs == "svn":
902+
r = t.cast("GitSync", create_project(**t.cast("SvnRepoPayload", repo_payload)))
903+
else:
904+
r = t.cast("GitSync", create_project(**t.cast("HgRepoPayload", repo_payload)))
858905

859-
r = create_project(**repo_dict) # Creates the repo object
860906
r.update_repo(set_remotes=True) # Creates repo if not exists and fetches
861907

862908
# TODO: Fix this
863-
return r # type:ignore
909+
return r

0 commit comments

Comments
 (0)