Skip to content

Commit f2f722f

Browse files
committed
Code review feedback
1 parent 6b51c52 commit f2f722f

3 files changed

Lines changed: 212 additions & 29 deletions

File tree

pygit2/_libgit2/ffi.pyi

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,35 @@ class GitRepositoryC:
8787
# def _from_c(cls, ptr: 'GitRepositoryC', owned: bool) -> 'Repository': ...
8888
pass
8989

90+
class GitRemoteCallbacksC:
91+
# TODO: Several Anys need filling in
92+
version: int
93+
sideband_progress: Any
94+
completion: Any
95+
credentials: Any
96+
certificate_check: Any
97+
transfer_progress: Any
98+
update_tips: Any
99+
pack_progress: Any
100+
push_transfer_progress: Any
101+
push_update_reference: Any
102+
push_negotiation: Any
103+
transport: Any
104+
remote_ready: Any
105+
payload: Any
106+
resolve_url: Any
107+
update_refs: Any
108+
90109
class GitFetchOptionsC:
91110
# TODO: FetchOptions exist in _pygit2.pyi
92111
# incomplete
93112
depth: int
113+
callbacks: GitRemoteCallbacksC
94114
custom_headers: GitStrrayC
95115

96116
class GitPushOptionsC:
117+
# TODO incomplete
118+
callbacks: GitRemoteCallbacksC
97119
custom_headers: GitStrrayC
98120

99121
class GitSubmoduleC:
@@ -229,7 +251,11 @@ class GitRepositoryInitOptionsC:
229251
origin_url: ArrayC[char]
230252

231253
class GitCloneOptionsC:
232-
pass
254+
# TODO: Several Anys need filling in
255+
repository_cb: Any
256+
repository_cb_payload: Any
257+
remote_cb: Any
258+
remote_cb_payload: Any
233259

234260
class GitPackbuilderC:
235261
pass
@@ -260,6 +286,8 @@ def new(a: Literal['git_repository **']) -> _Pointer[GitRepositoryC]: ...
260286
@overload
261287
def new(a: Literal['git_remote **']) -> _Pointer[GitRemoteC]: ...
262288
@overload
289+
def new(a: Literal['git_remote_callbacks *']) -> GitRemoteCallbacksC: ...
290+
@overload
263291
def new(a: Literal['git_transaction **']) -> _Pointer[GitTransactionC]: ...
264292
@overload
265293
def new(a: Literal['git_repository_init_options *']) -> GitRepositoryInitOptionsC: ...
@@ -280,8 +308,12 @@ def new(a: Literal['git_blob **']) -> _Pointer[GitBlobC]: ...
280308
@overload
281309
def new(a: Literal['git_clone_options *']) -> GitCloneOptionsC: ...
282310
@overload
311+
def new(a: Literal['git_fetch_options *']) -> GitFetchOptionsC: ...
312+
@overload
283313
def new(a: Literal['git_merge_options *']) -> GitMergeOptionsC: ...
284314
@overload
315+
def new(a: Literal['git_push_options *']) -> GitPushOptionsC: ...
316+
@overload
285317
def new(a: Literal['git_blame_options *']) -> GitBlameOptionsC: ...
286318
@overload
287319
def new(a: Literal['git_annotated_commit **']) -> _Pointer[GitAnnotatedCommitC]: ...
@@ -368,6 +400,7 @@ def new(
368400
a: Literal['char *[]'], b: list[Any]
369401
) -> ArrayC[char_pointer]: ... # For string arrays
370402
def addressof(a: object, attribute: str) -> _Pointer[object]: ...
403+
def new_handle(a: T) -> _Pointer[T]: ...
371404

372405
class buffer(bytes):
373406
def __init__(self, a: object) -> None: ...

pygit2/callbacks.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@
7979
_Credentials = Username | UserPass | Keypair
8080

8181
if TYPE_CHECKING:
82-
from pygit2._libgit2.ffi import GitProxyOptionsC, GitStrrayC
82+
from pygit2._libgit2.ffi import (
83+
GitCloneOptionsC,
84+
GitFetchOptionsC,
85+
GitProxyOptionsC,
86+
GitPushOptionsC,
87+
GitStrrayC,
88+
)
8389

84-
from .remotes import PushUpdate, TransferProgress
90+
from .remotes import PushUpdate, Remote, TransferProgress
8591
#
8692
# The payload is the way to pass information from the pygit2 API, through
8793
# libgit2, to the Python callbacks. And back.
@@ -92,6 +98,9 @@ class Payload:
9298
repository: Callable | None
9399
remote: Callable | None
94100
clone_options: Any
101+
fetch_options: Any
102+
push_options: Any
103+
remote_callbacks: Any
95104

96105
def __init__(self, **kw: object) -> None:
97106
for key, value in kw.items():
@@ -124,8 +133,6 @@ class RemoteCallbacks(Payload):
124133
RemoteCallbacks(certificate_check=certificate_check).
125134
"""
126135

127-
push_options: Any
128-
129136
def __init__(
130137
self,
131138
credentials: _Credentials | None = None,
@@ -369,20 +376,18 @@ def git_custom_headers(
369376
payload: RemoteCallbacks,
370377
opts_custom_headers: Optional['GitStrrayC'] = None,
371378
) -> Generator[StrArray, Any, None]:
372-
custom_headers = payload.custom_headers()
373-
if custom_headers:
374-
with StrArray(custom_headers) as headers_array:
375-
if opts_custom_headers is not None:
376-
headers_array.assign_to(opts_custom_headers)
377-
yield headers_array
378-
379-
else:
380-
with StrArray(None) as null_array:
381-
yield null_array
379+
custom_headers = payload.custom_headers() or None
380+
with StrArray(custom_headers) as headers_array:
381+
if opts_custom_headers is not None:
382+
headers_array.assign_to(opts_custom_headers)
383+
yield headers_array
382384

383385

384386
@contextmanager
385-
def git_clone_options(payload, opts=None):
387+
def git_clone_options(
388+
payload: RemoteCallbacks,
389+
opts: Optional['GitCloneOptionsC'] = None,
390+
) -> Generator[RemoteCallbacks, Any, None]:
386391
if opts is None:
387392
opts = ffi.new('git_clone_options *')
388393
C.git_clone_options_init(opts, C.GIT_CLONE_OPTIONS_VERSION)
@@ -404,7 +409,10 @@ def git_clone_options(payload, opts=None):
404409

405410

406411
@contextmanager
407-
def git_fetch_options(payload, opts=None):
412+
def git_fetch_options(
413+
payload: RemoteCallbacks | None,
414+
opts: Optional['GitFetchOptionsC'] = None,
415+
) -> Generator[RemoteCallbacks, Any, None]:
408416
if payload is None:
409417
payload = RemoteCallbacks()
410418

@@ -431,7 +439,7 @@ def git_fetch_options(payload, opts=None):
431439

432440
@contextmanager
433441
def git_proxy_options(
434-
payload: object,
442+
payload: 'Remote | RemoteCallbacks',
435443
opts: Optional['GitProxyOptionsC'] = None,
436444
proxy: None | bool | str = None,
437445
) -> Generator['GitProxyOptionsC', None, None]:
@@ -445,20 +453,24 @@ def git_proxy_options(
445453
elif type(proxy) is str:
446454
opts.type = C.GIT_PROXY_SPECIFIED
447455
# Keep url in memory, otherwise memory is freed and bad things happen
448-
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[attr-defined]
449-
opts.url = payload.__proxy_url # type: ignore[attr-defined]
456+
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[union-attr]
457+
opts.url = payload.__proxy_url # type: ignore[union-attr]
450458
else:
451459
raise TypeError('Proxy must be None, True, or a string')
452460
yield opts
453461

454462

455463
@contextmanager
456-
def git_push_options(payload, opts=None):
464+
def git_push_options(
465+
payload: RemoteCallbacks | None,
466+
opts: Optional['GitPushOptionsC'] = None,
467+
) -> Generator[RemoteCallbacks, Any, None]:
457468
if payload is None:
458469
payload = RemoteCallbacks()
459470

460-
opts = ffi.new('git_push_options *')
461-
C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION)
471+
if opts is None:
472+
opts = ffi.new('git_push_options *')
473+
C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION)
462474

463475
# Plug callbacks
464476
opts.callbacks.sideband_progress = C._sideband_progress_cb
@@ -487,7 +499,9 @@ def git_push_options(payload, opts=None):
487499

488500

489501
@contextmanager
490-
def git_remote_callbacks(payload):
502+
def git_remote_callbacks(
503+
payload: RemoteCallbacks | None,
504+
) -> Generator[RemoteCallbacks, Any, None]:
491505
if payload is None:
492506
payload = RemoteCallbacks()
493507

test/test_remote.py

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
import pytest
3131

3232
import pygit2
33-
from pygit2 import Remote, Repository
33+
from pygit2 import Remote, RemoteCallbacks, Repository
34+
from pygit2.ffi import ffi
3435
from pygit2.remotes import PushUpdate, TransferProgress
3536

3637
from . import utils
@@ -485,8 +486,6 @@ def test_push_non_fast_forward_commits_to_remote_fails(
485486

486487

487488
def test_push_options(origin: Repository, clone: Repository, remote: Remote) -> None:
488-
from pygit2 import RemoteCallbacks
489-
490489
callbacks = RemoteCallbacks()
491490
remote.push(['refs/heads/master'], callbacks)
492491
remote_push_options = callbacks.push_options.remote_push_options
@@ -515,8 +514,6 @@ def test_push_options(origin: Repository, clone: Repository, remote: Remote) ->
515514

516515

517516
def test_push_threads(origin: Repository, clone: Repository, remote: Remote) -> None:
518-
from pygit2 import RemoteCallbacks
519-
520517
callbacks = RemoteCallbacks()
521518
remote.push(['refs/heads/master'], callbacks)
522519
assert callbacks.push_options.pb_parallelism == 1
@@ -562,3 +559,142 @@ def push_negotiation(self, updates: list[PushUpdate]) -> None:
562559
assert the_updates[0].dst == new_tip_id
563560

564561
assert origin.branches['master'].target == new_tip_id
562+
563+
564+
class HeaderCallbacks(RemoteCallbacks):
565+
def custom_headers(self) -> list[str] | None:
566+
return ['X-Other-One: foo', 'X-Other-Two: bar']
567+
568+
569+
def test_git_custom_headers_context_manager(
570+
origin: Repository,
571+
clone: Repository,
572+
remote: Remote,
573+
) -> None:
574+
from pygit2.callbacks import git_custom_headers, git_fetch_options, git_push_options
575+
576+
class EmptyHeaderCallbacks(RemoteCallbacks):
577+
def custom_headers(self) -> list[str] | None:
578+
return []
579+
580+
callbacks = RemoteCallbacks()
581+
with git_custom_headers(callbacks) as headers:
582+
assert headers.ptr == ffi.NULL
583+
584+
callbacks = EmptyHeaderCallbacks()
585+
with git_custom_headers(callbacks) as headers:
586+
assert headers.ptr == ffi.NULL
587+
588+
callbacks = HeaderCallbacks()
589+
with git_custom_headers(callbacks) as headers:
590+
ptr = headers.ptr
591+
assert ptr != ffi.NULL
592+
assert ptr.count == 2 # type: ignore[union-attr]
593+
assert ffi.string(ptr.strings[0]) == b'X-Other-One: foo' # type: ignore[union-attr,index]
594+
assert ffi.string(ptr.strings[1]) == b'X-Other-Two: bar' # type: ignore[union-attr,index]
595+
596+
callbacks = RemoteCallbacks()
597+
with git_fetch_options(callbacks) as payload:
598+
assert payload.fetch_options.custom_headers.count == 0
599+
assert payload.fetch_options.custom_headers.strings == ffi.NULL
600+
601+
callbacks = EmptyHeaderCallbacks()
602+
with git_fetch_options(callbacks) as payload:
603+
assert payload.fetch_options.custom_headers.count == 0
604+
assert payload.fetch_options.custom_headers.strings == ffi.NULL
605+
606+
callbacks = HeaderCallbacks()
607+
with git_fetch_options(callbacks) as payload:
608+
assert payload.fetch_options.custom_headers.count == 2
609+
assert (
610+
ffi.string(payload.fetch_options.custom_headers.strings[0])
611+
== b'X-Other-One: foo'
612+
)
613+
assert (
614+
ffi.string(payload.fetch_options.custom_headers.strings[1])
615+
== b'X-Other-Two: bar'
616+
)
617+
618+
callbacks = RemoteCallbacks()
619+
with git_push_options(callbacks) as payload:
620+
assert payload.push_options.custom_headers.count == 0
621+
assert payload.push_options.custom_headers.strings == ffi.NULL
622+
623+
callbacks = EmptyHeaderCallbacks()
624+
with git_push_options(callbacks) as payload:
625+
assert payload.push_options.custom_headers.count == 0
626+
assert payload.push_options.custom_headers.strings == ffi.NULL
627+
628+
callbacks = HeaderCallbacks()
629+
with git_push_options(callbacks) as payload:
630+
assert payload.push_options.custom_headers.count == 2
631+
assert (
632+
ffi.string(payload.push_options.custom_headers.strings[0])
633+
== b'X-Other-One: foo'
634+
)
635+
assert (
636+
ffi.string(payload.push_options.custom_headers.strings[1])
637+
== b'X-Other-Two: bar'
638+
)
639+
640+
641+
def test_push_headers(origin: Repository, clone: Repository, remote: Remote) -> None:
642+
callbacks = RemoteCallbacks()
643+
remote.push(['refs/heads/master'], callbacks=callbacks)
644+
assert callbacks.push_options.custom_headers.count == 0
645+
assert callbacks.push_options.custom_headers.strings == ffi.NULL
646+
647+
callbacks = HeaderCallbacks()
648+
remote.push(['refs/heads/master'], callbacks=callbacks)
649+
assert callbacks.push_options.custom_headers.count == 2
650+
assert callbacks.push_options.custom_headers.strings != ffi.NULL
651+
# strings pointed to by callbacks.push_options.custom_headers.strings[] are already freed
652+
653+
# make sure the custom headers don't "stick around"
654+
callbacks = RemoteCallbacks()
655+
remote.push(['refs/heads/master'], callbacks=callbacks)
656+
assert callbacks.push_options.custom_headers.count == 0
657+
assert callbacks.push_options.custom_headers.strings == ffi.NULL
658+
659+
660+
def test_fetch_headers(origin: Repository, clone: Repository, remote: Remote) -> None:
661+
callbacks = RemoteCallbacks()
662+
remote.fetch(['refs/heads/master'], callbacks=callbacks)
663+
assert callbacks.fetch_options.custom_headers.count == 0
664+
assert callbacks.fetch_options.custom_headers.strings == ffi.NULL
665+
666+
callbacks = HeaderCallbacks()
667+
remote.fetch(['refs/heads/master'], callbacks=callbacks)
668+
assert callbacks.fetch_options.custom_headers.count == 2
669+
assert callbacks.fetch_options.custom_headers.strings != ffi.NULL
670+
# strings pointed to by callbacks.fetch_options.custom_headers.strings[] are already freed
671+
672+
# make sure the custom headers don't "stick around"
673+
callbacks = RemoteCallbacks()
674+
remote.fetch(['refs/heads/master'], callbacks=callbacks)
675+
assert callbacks.fetch_options.custom_headers.count == 0
676+
assert callbacks.fetch_options.custom_headers.strings == ffi.NULL
677+
678+
679+
@utils.requires_network
680+
def test_connect_headers(testrepo: Repository) -> None:
681+
# This is just a check that having custom headers doesn't cause errors. As far as I can tell,
682+
# there's no way to assert that C.git_remote_connect was called with the headers except for
683+
# having a remote server that expects the headers and fails without them.
684+
685+
assert 1 == len(testrepo.remotes)
686+
remote = testrepo.remotes[0]
687+
688+
callbacks = RemoteCallbacks()
689+
remote.connect(callbacks=callbacks)
690+
refs = remote.list_heads(connect=False)
691+
assert refs
692+
# Check that a known ref is returned.
693+
assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2'))
694+
695+
callbacks = HeaderCallbacks()
696+
remote.connect(callbacks=callbacks)
697+
refs = remote.list_heads(connect=False)
698+
assert refs
699+
# Check that a known ref is returned.
700+
assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2'))

0 commit comments

Comments
 (0)