diff --git a/CLAUDE.md b/CLAUDE.md index 635e913c..19c888af 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -70,6 +70,35 @@ Expected usage patterns: - Additional minor intentional deviations may be documented directly in the codebase. Such intentional deviations should be marked with `REFERENCE PARITY` comments in the code. +## Review team + +After every change or milestone, or when explicitly prompted, dispatch several fresh-context review agents set to the +MAXIMUM THINKING EFFORT to review your work: + +- A subagent focusing on the FUNCTIONAL CORRECTNESS and ROBUSTNESS of the implementation. +- A subagent focusing on the ARCHITECTURAL CLEANLINESS, DESIGN PRACTICES, and CODE QUALITY. +- Distinct tools -- Codex, Antigravity (`agy`), Claude (check what's available, exclude yourself) -- + focusing on CORRECTNESS only. + +It is important that we use all available distinct tools to maximize the diversity of perspectives +and minimize blind spots. When all are done, review and consolidate their findings and act accordingly. +If behavioral defects are found, ensure extensive regression tests are introduced. + +Repeat the review/refine loop until the agents return only trivial feedback (or none) for three (sic!) consecutive turns. +Here, "trivial feedback" means stylistic/inconsequential issues such as wording, formatting, trivial parameter +validation, or anything else that does not materially affect the correctness or maintainability of the codebase. +Iterating until there is no feedback at all has been attempted in the past, but it is not practical: in the absence of significant +issues, the review agents tend to degrade to nitpicking. +Hence, we stop iteration earlier, as soon as the feedback ceases to contain significant findings. + +The requirement of multiple consecutive reviews with no significant findings is intended to improve review coverage. 
+We have seen in the past how a single review turn would come up blank while the next round (with zero code changes in +between) would dig up a critical defect. Hence, we repeat turns generously across distinct agents for maximum assurance. + +Review agents in maximum thinking mode may go silent for a long time; set a generous timeout (at least 1 hour). +Some agents expect input from stdin when launched headless and may hang if no input is given; +in those cases consider redirecting stdin from `/dev/null` or a similar workaround; read each tool's docs to figure out the correct usage. + ## Documentation The documentation must be concise and to the point, with a strong focus on "how to use" rather than "how it works". diff --git a/src/pycyphal2/__init__.py b/src/pycyphal2/__init__.py index 1079b2bf..3881714b 100644 --- a/src/pycyphal2/__init__.py +++ b/src/pycyphal2/__init__.py @@ -155,7 +155,7 @@ async def main(): from ._transport import Transport as Transport from ._transport import TransportArrival as TransportArrival -__version__ = "2.0.0.dev3" +__version__ = "2.0.0.dev4" # pdoc needs __all__ to display re-exported members. 
__all__ = [ diff --git a/src/pycyphal2/_node.py b/src/pycyphal2/_node.py index b43c5db8..f3045223 100644 --- a/src/pycyphal2/_node.py +++ b/src/pycyphal2/_node.py @@ -2,6 +2,7 @@ import asyncio from collections import OrderedDict +from collections.abc import Coroutine import logging import math import os @@ -28,7 +29,7 @@ deserialize_header, ) from ._transport import SubjectWriter, Transport, TransportArrival -from ._api import Topic, Node, Publisher, Subscriber, Breadcrumb, Closable, Instant, Priority, SendError +from ._api import Topic, Node, Publisher, Subscriber, Breadcrumb, Closable, ClosedError, Instant, Priority, SendError from ._api import SUBJECT_ID_PINNED_MAX if TYPE_CHECKING: @@ -58,6 +59,22 @@ U64_MASK = (1 << 64) - 1 +def ack_is_last_attempt(current_ack_deadline_ns: int, current_ack_timeout: float, total_deadline_ns: int) -> bool: + """True if doubling the ACK timeout would overrun the total deadline, so this is the last retry.""" + next_ack_timeout_ns = round(current_ack_timeout * 2 * 1e9) + remaining_budget_ns = total_deadline_ns - current_ack_deadline_ns + return remaining_budget_ns < next_ack_timeout_ns + + +def ack_window(deadline_ns: int, ack_timeout: float) -> tuple[int, bool] | None: + """Next reliable-delivery ACK window: (ack_deadline_ns, is_last_attempt), or None if past the deadline.""" + now_ns = Instant.now().ns + if now_ns >= deadline_ns: + return None + ack_deadline_ns = min(deadline_ns, now_ns + round(ack_timeout * 1e9)) + return ack_deadline_ns, ack_is_last_attempt(ack_deadline_ns, ack_timeout, deadline_ns) + + class GossipScope(Enum): UNICAST = auto() BROADCAST = auto() @@ -115,6 +132,24 @@ def _name_is_homeful(name: str) -> bool: return name == "~" or name.startswith("~/") +def _is_valid_wire_name(name: str) -> bool: + """True if `name` is a well-formed *resolved* wire topic name, as required of names received in gossip: + nonempty, length-bounded, printable ASCII (33-126), already normalized (no leading/trailing/duplicate + '/'), 
verbatim (no '*'/'>' pattern tokens), not homeful ('~'/'~/...'), and pin-free (no '#' suffix). + The last two are stripped/expanded by resolve_name before a name reaches the wire, so their presence + means the gossip is unresolved/non-canonical and must not create a local topic.""" + return ( + bool(name) + and len(name) <= TOPIC_NAME_MAX + and "*" not in name + and ">" not in name + and not _name_is_homeful(name) + and _name_consume_pin_suffix(name)[1] is None + and all(33 <= ord(ch) <= 126 for ch in name) + and _name_normalize(name) == name + ) + + def resolve_name( name: str, home: str, namespace: str, remaps: dict[str, str] | None = None ) -> tuple[str, int | None, bool]: @@ -345,10 +380,8 @@ class PublishTracker: """Tracks a pending reliable publication awaiting ACKs.""" tag: int - deadline_ns: int ack_event: asyncio.Event acknowledged: bool = False - data: bytes | None = None ack_timeout: float = ACK_BASELINE_DEFAULT_TIMEOUT compromised: bool = False remaining: set[int] = field(default_factory=set) @@ -507,7 +540,6 @@ def __init__(self, transport: Transport, *, home: str, namespace: str) -> None: self._remaps: dict[str, str] = {} self._closed = False self.loop = asyncio.get_running_loop() - self._now_mono = time.monotonic() self._monitor_callbacks: dict[int, Callable[[Topic], None]] = {} self._next_monitor_callback_id = 0 @@ -573,7 +605,31 @@ def namespace(self) -> str: def transport(self) -> Transport: return self._transport + def _raise_if_closed(self) -> None: + if self._closed: + raise ClosedError(f"Node '{self._home}' is closed") + + def _spawn_detached(self, coro: Coroutine[Any, Any, None], what: str) -> None: + """Fire-and-forget a short-lived send: skip when closed, and never let its exception go unobserved.""" + if self._closed: + coro.close() + return + + def _done(task: asyncio.Task[None]) -> None: + if task.cancelled(): + return + ex = task.exception() + if ex is None: + return + if isinstance(ex, (SendError, OSError)): + _logger.debug("%s send 
failed: %s", what, ex) + else: + _logger.error("%s task crashed: %s", what, ex, exc_info=ex) + + self.loop.create_task(coro).add_done_callback(_done) + def remap(self, spec: str | dict[str, str]) -> None: + self._raise_if_closed() if isinstance(spec, str): spec = dict(x.split("=", 1) for x in spec.split() if "=" in x) assert isinstance(spec, dict) @@ -584,6 +640,7 @@ def remap(self, spec: str | dict[str, str]) -> None: def advertise(self, name: str) -> Publisher: from ._publisher import PublisherImpl + self._raise_if_closed() resolved, pin, verbatim = resolve_name(name, self._home, self._namespace, self._remaps) if not verbatim: raise ValueError("Cannot advertise on a pattern name") @@ -602,6 +659,7 @@ def advertise(self, name: str) -> Publisher: def subscribe(self, name: str, *, reordering_window: float | None = None) -> Subscriber: from ._subscriber import SubscriberImpl + self._raise_if_closed() resolved, pin, verbatim = resolve_name(name, self._home, self._namespace, self._remaps) if pin is not None and not verbatim: raise ValueError("Pattern names cannot be pinned") @@ -637,6 +695,7 @@ def subscribe(self, name: str, *, reordering_window: float | None = None) -> Sub return subscriber def monitor(self, callback: Callable[[Topic], None]) -> Closable: + self._raise_if_closed() callback_id = self._next_monitor_callback_id self._next_monitor_callback_id += 1 self._monitor_callbacks[callback_id] = callback @@ -653,6 +712,7 @@ def _notify_monitors(self, topic: Topic) -> None: _logger.exception("monitor() callback failed for %s", topic) async def scout(self, pattern: str) -> None: + self._raise_if_closed() resolved, pin, _ = resolve_name(pattern, self._home, self._namespace, self._remaps) if pin is not None: raise ValueError("Cannot scout a pinned name/pattern") @@ -800,12 +860,10 @@ def publish_tracker_release(topic: TopicImpl, tracker: PublishTracker) -> None: tracker.remaining.clear() @staticmethod - def prepare_publish_tracker(topic: TopicImpl, tag: int, 
deadline_ns: int, data: bytes) -> PublishTracker: + def prepare_publish_tracker(topic: TopicImpl, tag: int) -> PublishTracker: tracker = PublishTracker( tag=tag, - deadline_ns=deadline_ns, ack_event=asyncio.Event(), - data=data, ) tracker.ack_timeout = ACK_BASELINE_DEFAULT_TIMEOUT for assoc in sorted(topic.associations.values(), key=lambda x: x.remote_id): @@ -1036,14 +1094,6 @@ async def do_send() -> None: root.scout_task = self.loop.create_task(do_send()) - def send_scout(self, pattern: str) -> None: - """Send a scout message to discover topics matching a pattern.""" - - async def do_send() -> None: - await self._send_scout_once(pattern) - - self.loop.create_task(do_send()) - # -- Message Dispatch -- def on_subject_arrival(self, subject_id: int, arrival: TransportArrival) -> None: @@ -1055,6 +1105,8 @@ def on_unicast_arrival(self, arrival: TransportArrival) -> None: self.dispatch_arrival(arrival, subject_id=None, unicast=True) def dispatch_arrival(self, arrival: TransportArrival, *, subject_id: int | None, unicast: bool) -> None: + if self._closed: + return # Drop late arrivals after close instead of mutating state / spawning sends. 
msg = arrival.message if len(msg) < HEADER_SIZE: _logger.debug("Drop short msg len=%d", len(msg)) @@ -1207,7 +1259,7 @@ async def do_send() -> None: except (SendError, OSError) as e: _logger.debug("ACK send failed: %s", e) - self.loop.create_task(do_send()) + self._spawn_detached(do_send(), "ACK") def on_msg_ack(self, arrival: TransportArrival, hdr: MsgAckHeader | MsgNackHeader) -> None: topic = self.topics_by_hash.get(hdr.topic_hash) @@ -1295,7 +1347,7 @@ async def do_send() -> None: except (SendError, OSError) as e: _logger.debug("RSP ACK send failed: %s", e) - self.loop.create_task(do_send()) + self._spawn_detached(do_send(), "RSP ACK") def on_gossip( self, @@ -1306,6 +1358,8 @@ def on_gossip( ) -> None: name = "" if hdr.name_len > 0: + # Best-effort decode for diagnostics/monitoring; an invalid name cannot create a topic because + # topic_subscribe_if_matching validates the character set before creating one. name = payload[: hdr.name_len].decode("utf-8", errors="replace") topic = self.topics_by_hash.get(hdr.topic_hash) @@ -1374,6 +1428,14 @@ def topic_subscribe_if_matching( now: float, ) -> TopicImpl | None: """Create an implicit topic if any pattern subscriber matches the name.""" + # REFERENCE PARITY: the reference does not (yet) validate gossip-name characters here -- it trusts + # the hash. We additionally reject names that are not valid resolved wire names (non-normalized, + # non-verbatim by this implementation's rule, homeful, or pinned) so untrusted wire input cannot + # create a local topic with such a name. Consistent with the documented whitespace-strip deviation; + # the reference may adopt the same validation later. + if not _is_valid_wire_name(name): + _logger.debug("Gossip drop invalid wire name hash=%016x", topic_hash) + return None # Validate that the hash matches the name to prevent corrupt gossip from creating inconsistencies. 
if rapidhash(name) != topic_hash: _logger.debug("Gossip hash mismatch for '%s': got %016x, expected %016x", name, topic_hash, rapidhash(name)) @@ -1398,12 +1460,15 @@ def topic_subscribe_if_matching( def on_scout(self, arrival: TransportArrival, hdr: ScoutHeader, payload: bytes) -> None: if hdr.pattern_len == 0 or hdr.pattern_len > TOPIC_NAME_MAX or len(payload) < hdr.pattern_len: return + # Best-effort decode; an invalid pattern simply matches no local topic names. pattern = payload[: hdr.pattern_len].decode("utf-8", errors="replace") _logger.debug("Scout received pattern='%s' from %016x", pattern, arrival.remote_id) for topic in list(self.topics_by_name.values()): subs = match_pattern(pattern, topic.name) if subs is not None: - self.loop.create_task(self.send_gossip_unicast(topic, arrival.remote_id, arrival.priority)) + self._spawn_detached( + self.send_gossip_unicast(topic, arrival.remote_id, arrival.priority), "gossip unicast" + ) # -- Implicit Topic GC -- @@ -1474,6 +1539,12 @@ def close(self) -> None: return self._closed = True _logger.info("Node closing home='%s'", self._home) + # Unblock anything awaiting on a subscriber (`async for`): closing each enqueues StopAsyncIteration, + # otherwise a default (no-liveness-timeout) subscriber would wait on its queue forever. (Reliable + # publishes / response streams are deadline-bounded and resolve on their own.) 
+ for root in list(self.sub_roots_verbatim.values()) + list(self.sub_roots_pattern.values()): + for sub in list(root.subscribers): + sub.close() self._gc_task.cancel() for root in list(self.sub_roots_pattern.values()): if root.scout_task is not None: diff --git a/src/pycyphal2/_publisher.py b/src/pycyphal2/_publisher.py index a10268ad..f1de71af 100644 --- a/src/pycyphal2/_publisher.py +++ b/src/pycyphal2/_publisher.py @@ -8,7 +8,7 @@ from ._api import DeliveryError, Instant, LivenessError, Priority, SendError from ._api import Publisher, Topic, ResponseStream, Response from ._header import MsgBeHeader, MsgRelHeader, RspBeHeader, RspRelHeader -from ._node import ACK_BASELINE_DEFAULT_TIMEOUT, NodeImpl, PublishTracker, SESSION_LIFETIME, TopicImpl +from ._node import ACK_BASELINE_DEFAULT_TIMEOUT, NodeImpl, PublishTracker, SESSION_LIFETIME, TopicImpl, ack_window from ._transport import TransportArrival _logger = logging.getLogger(__name__) @@ -124,7 +124,7 @@ async def request( ) self._topic.request_futures[tag] = stream - tracker = self._prepare_reliable_publish_tracker(tag, delivery_deadline.ns, payload) + tracker = self._prepare_reliable_publish_tracker(tag) try: initial_window = await self._reliable_publish_start(delivery_deadline, tag, payload, tracker) except asyncio.CancelledError: @@ -169,12 +169,6 @@ async def _request_publish( finally: self._release_reliable_publish_tracker(tag, tracker) - @staticmethod - def _ack_is_last_attempt(current_ack_deadline_ns: int, current_ack_timeout: float, total_deadline_ns: int) -> bool: - next_ack_timeout_ns = round(current_ack_timeout * 2 * 1e9) - remaining_budget_ns = total_deadline_ns - current_ack_deadline_ns - return remaining_budget_ns < next_ack_timeout_ns - @staticmethod def _ack_window_is_compromised(deadline_ns: int, current_ack_timeout: float) -> bool: return Instant.now().ns >= (deadline_ns - round(current_ack_timeout * 1e9)) @@ -189,16 +183,8 @@ def _serialize_message(self, tag: int, payload: bytes, *, reliable: 
bool) -> byt ) return hdr.serialize() + payload - @staticmethod - def _reliable_publish_window(deadline_ns: int, ack_timeout: float) -> tuple[int, bool] | None: - now_ns = Instant.now().ns - if now_ns >= deadline_ns: - return None - ack_deadline_ns = min(deadline_ns, now_ns + round(ack_timeout * 1e9)) - return ack_deadline_ns, PublisherImpl._ack_is_last_attempt(ack_deadline_ns, ack_timeout, deadline_ns) - - def _prepare_reliable_publish_tracker(self, tag: int, deadline_ns: int, payload: bytes) -> PublishTracker: - tracker = self._node.prepare_publish_tracker(self._topic, tag, deadline_ns, payload) + def _prepare_reliable_publish_tracker(self, tag: int) -> PublishTracker: + tracker = self._node.prepare_publish_tracker(self._topic, tag) tracker.ack_timeout = self.ack_timeout self._topic.publish_futures[tag] = tracker return tracker @@ -231,7 +217,7 @@ async def _reliable_publish_start( payload: bytes, tracker: PublishTracker, ) -> tuple[int, bool]: - initial_window = self._reliable_publish_window(deadline.ns, tracker.ack_timeout) + initial_window = ack_window(deadline.ns, tracker.ack_timeout) if initial_window is None: raise DeliveryError("Reliable publish not acknowledged before deadline") ack_deadline_ns, _ = initial_window @@ -277,7 +263,7 @@ async def _reliable_publish_continue( if last_attempt: break tracker.ack_timeout *= 2 - next_window = self._reliable_publish_window(deadline.ns, tracker.ack_timeout) + next_window = ack_window(deadline.ns, tracker.ack_timeout) if next_window is None: break ack_deadline_ns, last_attempt = next_window @@ -292,7 +278,7 @@ async def _reliable_publish_continue( raise DeliveryError("Reliable publish not acknowledged before deadline") async def _reliable_publish(self, deadline: Instant, tag: int, payload: bytes) -> None: - tracker = self._prepare_reliable_publish_tracker(tag, deadline.ns, payload) + tracker = self._prepare_reliable_publish_tracker(tag) try: initial_window = await self._reliable_publish_start(deadline, tag, payload, 
tracker) await self._reliable_publish_continue(deadline, tag, payload, tracker, initial_window) diff --git a/src/pycyphal2/_subscriber.py b/src/pycyphal2/_subscriber.py index 5099763e..ded55efb 100644 --- a/src/pycyphal2/_subscriber.py +++ b/src/pycyphal2/_subscriber.py @@ -15,6 +15,7 @@ NodeImpl, SubscriberRoot, TopicImpl, + ack_window, match_pattern, ) @@ -220,9 +221,6 @@ def on_timeout() -> None: state.timeout_handle = loop.call_later(delay, on_timeout) - def _arm_reorder_timeout(self, state: ReorderingState) -> None: - self._rearm_reorder_timeout(state) - def _drop_stale_reordering(self, now: float) -> None: stale = [key for key, state in self._reordering.items() if (state.last_active_at + SESSION_LIFETIME) < now] for key in stale: @@ -337,7 +335,7 @@ async def __call__( ack_timeout = ACK_BASELINE_DEFAULT_TIMEOUT * (1 << int(self._priority)) try: - initial_window = _ack_window(deadline.ns, ack_timeout) + initial_window = ack_window(deadline.ns, ack_timeout) if initial_window is None: raise DeliveryError("Reliable response not acknowledged before deadline") @@ -371,7 +369,7 @@ async def __call__( if last_attempt: break ack_timeout *= 2 - next_window = _ack_window(deadline.ns, ack_timeout) + next_window = ack_window(deadline.ns, ack_timeout) if next_window is None: break ack_deadline_ns, last_attempt = next_window @@ -414,17 +412,3 @@ def on_ack(self, positive: bool) -> None: self.done = True self.nacked = not positive self.ack_event.set() - - -def _ack_is_last_attempt(current_ack_deadline_ns: int, current_ack_timeout: float, total_deadline_ns: int) -> bool: - next_ack_timeout_ns = round(current_ack_timeout * 2 * 1e9) - remaining_budget_ns = total_deadline_ns - current_ack_deadline_ns - return remaining_budget_ns < next_ack_timeout_ns - - -def _ack_window(deadline_ns: int, ack_timeout: float) -> tuple[int, bool] | None: - now_ns = Instant.now().ns - if now_ns >= deadline_ns: - return None - ack_deadline_ns = min(deadline_ns, now_ns + round(ack_timeout * 1e9)) - 
return ack_deadline_ns, _ack_is_last_attempt(ack_deadline_ns, ack_timeout, deadline_ns) diff --git a/src/pycyphal2/can/_interface.py b/src/pycyphal2/can/_interface.py index 5833e689..1b7e6230 100644 --- a/src/pycyphal2/can/_interface.py +++ b/src/pycyphal2/can/_interface.py @@ -7,7 +7,8 @@ from .. import Closable, Instant -_CAN_EXT_ID_MASK = (1 << 29) - 1 +CAN_EXT_ID_MASK = (1 << 29) - 1 +CAN_STD_ID_MASK = (1 << 11) - 1 @dataclass(frozen=True) @@ -18,7 +19,7 @@ class Frame: data: bytes def __post_init__(self) -> None: - if not isinstance(self.id, int) or not (0 <= self.id <= _CAN_EXT_ID_MASK): + if not isinstance(self.id, int) or not (0 <= self.id <= CAN_EXT_ID_MASK): raise ValueError(f"Invalid CAN identifier: {self.id!r}") data = bytes(self.data) if len(data) > 64: @@ -39,9 +40,9 @@ class Filter: mask: int def __post_init__(self) -> None: - if not (0 <= self.id <= _CAN_EXT_ID_MASK): + if not (0 <= self.id <= CAN_EXT_ID_MASK): raise ValueError(f"Invalid CAN identifier: {self.id!r}") - if not (0 <= self.mask <= _CAN_EXT_ID_MASK): + if not (0 <= self.mask <= CAN_EXT_ID_MASK): raise ValueError(f"Invalid CAN mask: {self.mask!r}") @property diff --git a/src/pycyphal2/can/_media_slcan.py b/src/pycyphal2/can/_media_slcan.py index 16bbfae3..b9fd99ee 100644 --- a/src/pycyphal2/can/_media_slcan.py +++ b/src/pycyphal2/can/_media_slcan.py @@ -6,12 +6,11 @@ import logging -from ._interface import Frame -from ._wire import CAN_EXT_ID_MASK, DLC_TO_LENGTH, MTU_CAN_CLASSIC +from ._interface import CAN_EXT_ID_MASK, CAN_STD_ID_MASK, Frame +from ._wire import DLC_TO_LENGTH, MTU_CAN_CLASSIC _logger = logging.getLogger(__name__) -_CAN_STD_ID_MASK = (1 << 11) - 1 _CR = 0x0D # ACK / carriage return _LF = 0x0A _BEL = 0x07 # NACK / bell @@ -157,7 +156,7 @@ def _parse_data_frame(line: bytes, *, id_length: int, max_payload_length: int) - if len(line) < expected: _logger.debug("SLCAN drop data dlc mismatch len=%d expected=%d", len(line), expected) return None - if id_length == 3 and identifier 
> _CAN_STD_ID_MASK: + if id_length == 3 and identifier > CAN_STD_ID_MASK: _logger.debug("SLCAN drop invalid standard id=%x", identifier) return None data = _parse_hex_bytes(line[header_length:expected]) diff --git a/src/pycyphal2/can/_wire.py b/src/pycyphal2/can/_wire.py index 9a04e884..37be4b9f 100644 --- a/src/pycyphal2/can/_wire.py +++ b/src/pycyphal2/can/_wire.py @@ -10,9 +10,8 @@ CRC16CCITT_FALSE_RESIDUE, crc16ccitt_false_add, ) -from ._interface import Filter +from ._interface import CAN_EXT_ID_MASK, Filter -CAN_EXT_ID_MASK = (1 << 29) - 1 NODE_ID_MAX = 127 NODE_ID_ANONYMOUS = 0xFF NODE_ID_CAPACITY = NODE_ID_MAX + 1 diff --git a/src/pycyphal2/can/pythoncan.py b/src/pycyphal2/can/pythoncan.py index 1508f526..74941337 100644 --- a/src/pycyphal2/can/pythoncan.py +++ b/src/pycyphal2/can/pythoncan.py @@ -18,7 +18,7 @@ import threading from .._api import ClosedError, Instant -from ._interface import Filter, Interface, TimestampedFrame +from ._interface import CAN_EXT_ID_MASK, Filter, Interface, TimestampedFrame try: import can @@ -27,8 +27,11 @@ _logger = logging.getLogger(__name__) -_RX_POLL_TIMEOUT = 0.1 -_CAN_EXT_ID_MASK = (1 << 29) - 1 +# RX thread poll cadence. This also bounds how long filter()/close() may block the event loop while +# they quiesce the RX thread (it can only acknowledge a pause between recv() calls). Kept short so that +# admin stall stays small. ponytail: to remove the stall entirely, hand pending filters to the RX +# thread to apply between recv() calls instead of pausing it from the loop thread. 
+_RX_POLL_TIMEOUT = 0.02 class PythonCANInterface(Interface): @@ -96,6 +99,7 @@ def enqueue(self, id: int, data: Iterable[memoryview], deadline: Instant) -> Non self._raise_if_closed() if self._tx_task is None: self._tx_task = self._loop.create_task(self._tx_loop()) + self._tx_task.add_done_callback(self._on_task_done) for chunk in data: self._tx_seq += 1 self._tx_queue.put_nowait((id, self._tx_seq, deadline.ns, bytes(chunk))) @@ -151,15 +155,12 @@ async def _tx_loop(self) -> None: loop = asyncio.get_running_loop() while not self._closed: try: - identifier, _seq, deadline_ns, payload = await self._tx_queue.get() + identifier, seq, deadline_ns, payload = await self._tx_queue.get() except asyncio.CancelledError: raise if self._closed: return - if Instant.now().ns >= deadline_ns: - _logger.debug("PythonCAN tx drop expired iface=%s id=%08x", self._name, identifier) - continue - timeout = max(0.0, (deadline_ns - Instant.now().ns) * 1e-9) + timeout = (deadline_ns - Instant.now().ns) * 1e-9 if timeout <= 0.0: _logger.debug("PythonCAN tx drop expired iface=%s id=%08x", self._name, identifier) continue @@ -172,14 +173,10 @@ async def _tx_loop(self) -> None: ) try: await asyncio.wait_for(loop.run_in_executor(None, self._bus.send, msg, timeout), timeout=timeout) - except asyncio.TimeoutError: - self._tx_queue.put_nowait((identifier, self._tx_seq, deadline_ns, payload)) - self._tx_seq += 1 - await asyncio.sleep(0.001) - except can.CanError as ex: + except (asyncio.TimeoutError, can.CanError) as ex: + # Re-queue with the original seq so the frame keeps its place within its transfer. 
_logger.debug("PythonCAN tx retry iface=%s err=%s", self._name, ex) - self._tx_queue.put_nowait((identifier, self._tx_seq, deadline_ns, payload)) - self._tx_seq += 1 + self._tx_queue.put_nowait((identifier, seq, deadline_ns, payload)) await asyncio.sleep(0.001) except OSError as ex: self._fail(ex) @@ -229,6 +226,14 @@ def _fail(self, ex: BaseException) -> None: _logger.error("PythonCAN interface %s failed: %s", self._name, ex) self.close() + def _on_task_done(self, task: asyncio.Task[None]) -> None: + # Surface an unexpected TX-task crash as an interface failure instead of swallowing it. + if task.cancelled() or self._closed: + return + ex = task.exception() + if ex is not None: + self._fail(ex) + def _raise_if_closed(self) -> None: if self._closed: if self._failure is not None: @@ -258,4 +263,4 @@ def _parse_message(msg: can.Message) -> TimestampedFrame | None: if msg.is_remote_frame: _logger.debug("PythonCAN drop remote frame id=%08x", msg.arbitration_id) return None - return TimestampedFrame(id=msg.arbitration_id & _CAN_EXT_ID_MASK, data=bytes(msg.data), timestamp=Instant.now()) + return TimestampedFrame(id=msg.arbitration_id & CAN_EXT_ID_MASK, data=bytes(msg.data), timestamp=Instant.now()) diff --git a/src/pycyphal2/can/socketcan.py b/src/pycyphal2/can/socketcan.py index 738cc4e0..22dc842c 100644 --- a/src/pycyphal2/can/socketcan.py +++ b/src/pycyphal2/can/socketcan.py @@ -21,12 +21,14 @@ _CAN_FILTER_CAPACITY = 64 _CAN_INTERFACE_TYPE = 280 -_CAN_CLASSIC_MTU = 16 -_CAN_FD_MTU = 72 _CANFD_FDF = getattr(socket, "CANFD_FDF", 0) _CAN_FRAME_STRUCT = struct.Struct("=IB3x8s") _CANFD_FRAME_STRUCT = struct.Struct("=IBBBB64s") _CAN_FILTER_STRUCT = struct.Struct("=II") +# SocketCAN frame sizes — sizeof(struct can_frame)=16, sizeof(struct canfd_frame)=72 (NOT the 8/64 payload MTU). +# The kernel also reports these as the CAN netdev MTU, so they double as the FD-capability threshold. 
+_CLASSIC_FRAME_SIZE = _CAN_FRAME_STRUCT.size +_FD_FRAME_SIZE = _CANFD_FRAME_STRUCT.size _TRANSIENT_TX_ERRNO = {errno.EAGAIN, errno.EWOULDBLOCK, errno.ENOBUFS, errno.ENOMEM, errno.EBUSY} @@ -37,7 +39,7 @@ def __init__(self, name: str) -> None: self._sock.setblocking(False) self._sock.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK, 1) self._sock.bind((self._name,)) - self._fd = self._read_iface_mtu() >= _CAN_FD_MTU + self._fd = self._read_iface_mtu() >= _FD_FRAME_SIZE if self._fd: self._sock.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FD_FRAMES, 1) self._closed = False @@ -74,6 +76,7 @@ def enqueue(self, id: int, data: Iterable[memoryview], deadline: Instant) -> Non self._raise_if_closed() if self._tx_task is None: self._tx_task = asyncio.get_running_loop().create_task(self._tx_loop()) + self._tx_task.add_done_callback(self._on_task_done) for chunk in data: self._tx_seq += 1 self._tx_queue.put_nowait((id, self._tx_seq, deadline.ns, bytes(chunk))) @@ -94,7 +97,7 @@ def purge(self) -> None: async def receive(self) -> TimestampedFrame: self._raise_if_closed() loop = asyncio.get_running_loop() - recv_size = _CAN_FD_MTU if self._fd else _CAN_CLASSIC_MTU + recv_size = _FD_FRAME_SIZE if self._fd else _CLASSIC_FRAME_SIZE while True: try: raw = await loop.sock_recv(self._sock, recv_size) @@ -128,14 +131,17 @@ async def _tx_loop(self) -> None: raise if self._closed: return - if Instant.now().ns >= deadline_ns: - _logger.debug("SocketCAN tx drop expired iface=%s id=%08x", self._name, identifier) - continue - frame = self._encode(identifier, payload) - timeout = max(0.0, (deadline_ns - Instant.now().ns) * 1e-9) + timeout = (deadline_ns - Instant.now().ns) * 1e-9 if timeout <= 0.0: _logger.debug("SocketCAN tx drop expired iface=%s id=%08x", self._name, identifier) continue + try: + frame = self._encode(identifier, payload) + except ValueError as ex: + # An unencodable frame (e.g. 
oversized on a Classic-only interface) is a single bad + # frame, not an interface failure: drop it instead of letting it kill the TX task. + _logger.warning("SocketCAN tx drop unencodable iface=%s id=%08x: %s", self._name, identifier, ex) + continue try: await asyncio.wait_for(loop.sock_sendall(self._sock, frame), timeout=timeout) except asyncio.TimeoutError: @@ -159,6 +165,14 @@ def _fail(self, ex: BaseException) -> None: _logger.error("SocketCAN interface %s failed: %s", self._name, ex) self.close() + def _on_task_done(self, task: asyncio.Task[None]) -> None: + # Surface an unexpected TX-task crash as an interface failure instead of swallowing it. + if task.cancelled() or self._closed: + return + ex = task.exception() + if ex is not None: + self._fail(ex) + def _raise_if_closed(self) -> None: if self._closed: if self._failure is not None: @@ -172,7 +186,9 @@ def _is_transient_tx_error(ex: OSError) -> bool: def _encode(self, identifier: int, data: bytes) -> bytes: if len(data) > 8: if not self._fd: - raise ClosedError(f"SocketCAN interface {self._name} is not CAN FD-capable") + raise ValueError( + f"SocketCAN interface {self._name} cannot send a {len(data)}-byte frame on Classic CAN" + ) return _CANFD_FRAME_STRUCT.pack( socket.CAN_EFF_FLAG | (identifier & socket.CAN_EFF_MASK), len(data), @@ -189,14 +205,14 @@ def _encode(self, identifier: int, data: bytes) -> bytes: @staticmethod def _decode(raw: bytes) -> TimestampedFrame | None: - if len(raw) < _CAN_CLASSIC_MTU: + if len(raw) < _CLASSIC_FRAME_SIZE: _logger.debug("SocketCAN drop short len=%d", len(raw)) return None - if len(raw) >= _CAN_FD_MTU: - can_id, length, _flags, _reserved0, _reserved1, data = _CANFD_FRAME_STRUCT.unpack(raw[:_CAN_FD_MTU]) + if len(raw) >= _FD_FRAME_SIZE: + can_id, length, _flags, _reserved0, _reserved1, data = _CANFD_FRAME_STRUCT.unpack(raw[:_FD_FRAME_SIZE]) payload = data[: min(length, 64)] else: - can_id, length, data = _CAN_FRAME_STRUCT.unpack(raw[:_CAN_CLASSIC_MTU]) + can_id, length, 
data = _CAN_FRAME_STRUCT.unpack(raw[:_CLASSIC_FRAME_SIZE]) payload = data[: min(length, 8)] if (can_id & socket.CAN_EFF_FLAG) == 0 or (can_id & (socket.CAN_RTR_FLAG | socket.CAN_ERR_FLAG)) != 0: _logger.debug("SocketCAN drop non-extended or non-data id=%08x", can_id) diff --git a/src/pycyphal2/udp.py b/src/pycyphal2/udp.py index 5d0aa1bb..039b703f 100644 --- a/src/pycyphal2/udp.py +++ b/src/pycyphal2/udp.py @@ -172,7 +172,6 @@ def _frame_is_valid(header: _FrameHeader, payload_chunk: bytes | memoryview) -> class _Fragment: offset: int data: bytes - crc: int @property def end(self) -> int: @@ -210,7 +209,7 @@ def create(cls, header: _FrameHeader, timestamp_ns: int) -> _TransferSlot: ) def update(self, timestamp_ns: int, header: _FrameHeader, payload_chunk: bytes) -> bytes | None: - if self._accept_fragment(header.frame_payload_offset, payload_chunk, header.prefix_crc): + if self._accept_fragment(header.frame_payload_offset, payload_chunk): self.ts_max_ns = max(self.ts_max_ns, timestamp_ns) self.ts_min_ns = min(self.ts_min_ns, timestamp_ns) crc_end = header.frame_payload_offset + len(payload_chunk) @@ -221,7 +220,7 @@ def update(self, timestamp_ns: int, header: _FrameHeader, payload_chunk: bytes) return None return self._finalize_payload() - def _accept_fragment(self, offset: int, data: bytes, crc: int) -> bool: + def _accept_fragment(self, offset: int, data: bytes) -> bool: left = offset right = offset + len(data) for frag in self.fragments: @@ -244,7 +243,7 @@ def _accept_fragment(self, offset: int, data: bytes, crc: int) -> bool: v_left = min(left, left_neighbor.offset + 1) if left_neighbor is not None else left v_right = max(right, max(right_neighbor.end, 1) - 1) if right_neighbor is not None else right self.fragments = [frag for frag in self.fragments if not (frag.offset >= v_left and frag.end <= v_right)] - self.fragments.append(_Fragment(offset=offset, data=data, crc=crc)) + self.fragments.append(_Fragment(offset=offset, data=data)) 
self.fragments.sort(key=lambda frag: frag.offset) self.covered_prefix = self._compute_covered_prefix() return True @@ -474,10 +473,15 @@ class Interface: mtu_link: int """Link-layer MTU. E.g., 1500 for Ethernet, ~64K for loopback.""" + def __post_init__(self) -> None: + # Validate at construction (not via assert, which `python -O` strips) so a too-small MTU can + # never produce a negative mtu_cyphal that would make segmentation loop forever. + if self.mtu_link < _CYPHAL_MTU_LINK_MIN: + raise ValueError(f"mtu_link must be >= {_CYPHAL_MTU_LINK_MIN}, got {self.mtu_link}") + @property def mtu_cyphal(self) -> int: """Max Cyphal frame payload: mtu_link - 60 (IPv4 max) - 8 (UDP) - 32 (Cyphal header).""" - assert self.mtu_link >= _CYPHAL_MTU_LINK_MIN return self.mtu_link - _CYPHAL_OVERHEAD_MAX @@ -516,18 +520,22 @@ async def __call__(self, deadline: Instant, priority: Priority, message: bytes | except (OSError, SendError) as e: errors.append(e) + if errors and success_count == 0: + _logger.error("Send failed on all interfaces for subject %d", self._subject_id) + raise SendError("send failed on all interfaces") from ExceptionGroup( + "send failed on all interfaces", errors + ) if errors: - eg = ExceptionGroup("send failed on some interfaces", errors) - if success_count == 0: - _logger.error("Send failed on all interfaces for subject %d", self._subject_id) - raise SendError("send failed on all interfaces") from eg + # Redundant transport: delivery via at least one interface is a success. Warn but do not + # raise, otherwise the caller would treat a delivered transfer as failed and retry it, + # duplicating it on the interfaces that already succeeded. This mirrors the CAN transport + # (see _CANTransportImpl.send_transfer) and the reference cy_udp_posix push semantics. 
_logger.warning( - "Send failed on %d/%d interfaces for subject %d", - len(errors), + "Send succeeded on %d/%d interfaces for subject %d", + success_count, len(errors) + success_count, self._subject_id, ) - raise eg _logger.debug("Subject tx done sid=%d tid=%d", self._subject_id, transfer_id) @@ -843,8 +851,16 @@ async def unicast(self, deadline: Instant, priority: Priority, remote_id: int, m _logger.warning("No endpoint known for remote_id=0x%016x", remote_id) raise SendError("No endpoint known for remote_id") if errors: - raise ExceptionGroup("unicast send failed on some interfaces", errors) - _logger.debug("Unicast sent %d frames to remote_id=0x%016x", len(frames), remote_id) + # Redundant transport: delivery via at least one interface is a success. Warn but do not + # raise, otherwise a delivered transfer would be reported as failed and retried (mirrors + # _UDPSubjectWriter.__call__ and the reference cy_udp_posix unicast push semantics). + _logger.warning( + "Unicast succeeded on %d/%d interfaces for remote_id=0x%016x", + success_count, + len(errors) + success_count, + remote_id, + ) + _logger.debug("Unicast sent to remote_id=0x%016x", remote_id) def close(self) -> None: if self._closed: @@ -947,7 +963,11 @@ def _process_unicast_datagram( return if arrival is not None and self._unicast_handler is not None: _logger.debug("Unicast transfer complete from sender_uid=0x%016x", arrival.remote_id) - self._unicast_handler(arrival) + try: + self._unicast_handler(arrival) + except Exception: + # A raising handler must not kill the receive loop; drop the arrival and keep serving. 
+ _logger.exception("Unicast handler raised iface=%d", iface_idx) def _process_subject_datagram( self, @@ -997,4 +1017,8 @@ def _process_subject_datagram( if arrival is not None: _logger.debug("Subject %d transfer complete from sender_uid=0x%016x", subject_id, arrival.remote_id) if handler is not None: - handler(arrival) + try: + handler(arrival) + except Exception: + # A raising handler must not kill the receive loop; drop the arrival and keep serving. + _logger.exception("Subject %d handler raised", subject_id) diff --git a/tests/can/test_interface.py b/tests/can/test_interface.py index 997d0635..8621a05a 100644 --- a/tests/can/test_interface.py +++ b/tests/can/test_interface.py @@ -6,7 +6,7 @@ from pycyphal2 import Instant from pycyphal2.can import Filter, Frame, TimestampedFrame -from pycyphal2.can._interface import _CAN_EXT_ID_MASK +from pycyphal2.can._interface import CAN_EXT_ID_MASK def test_frame_validation_and_normalization() -> None: @@ -20,7 +20,7 @@ def test_frame_validation_and_normalization() -> None: Frame(id="bad", data=b"") # type: ignore[arg-type] with pytest.raises(ValueError, match="Invalid CAN identifier"): - Frame(id=_CAN_EXT_ID_MASK + 1, data=b"") + Frame(id=CAN_EXT_ID_MASK + 1, data=b"") with pytest.raises(ValueError, match="Invalid CAN data length"): Frame(id=1, data=bytes(65)) @@ -35,7 +35,7 @@ def test_filter_validation_and_helpers() -> None: Filter(id=-1, mask=0) with pytest.raises(ValueError, match="Invalid CAN mask"): - Filter(id=0, mask=_CAN_EXT_ID_MASK + 1) + Filter(id=0, mask=CAN_EXT_ID_MASK + 1) with pytest.raises(ValueError, match="target number of filters must be positive"): Filter.coalesce([], 0) diff --git a/tests/can/test_pythoncan.py b/tests/can/test_pythoncan.py index 5b229695..544e191e 100644 --- a/tests/can/test_pythoncan.py +++ b/tests/can/test_pythoncan.py @@ -6,6 +6,7 @@ from pathlib import Path import sys import threading +import time from typing import cast from unittest.mock import MagicMock @@ -827,6 +828,44 @@ 
async def test_unit_tx_loop_multiple_deadline_drops() -> None: _close_all(a, b) +async def test_unit_tx_retry_preserves_intra_transfer_order() -> None: + """A frame that fails once and is retried keeps its place ahead of later frames of the transfer. + + Regression: the retry path used to re-queue with the global tx_seq counter, which could sort a + retried frame after its successors. The two payloads below sort opposite to their enqueue order + by byte value, so only a preserved seq keeps them in order after the first one is retried. + """ + sent: list[bytes] = [] + failed_once = threading.Event() + + class _RetryOnceBus: + channel_info = "retry:0" + + def recv(self, timeout: float | None = None) -> _can.Message | None: + if timeout: + time.sleep(min(timeout, 0.02)) + return None + + def send(self, msg: _can.Message, timeout: float | None = None) -> None: + data = bytes(msg.data) + if data == b"\x02" and not failed_once.is_set(): + failed_once.set() + raise _can.CanError("transient") + sent.append(data) + + def shutdown(self) -> None: + pass + + itf = PythonCANInterface(cast(_can.BusABC, _RetryOnceBus()), fd=False) + try: + itf.enqueue(0x123, [memoryview(b"\x02"), memoryview(b"\x01")], Instant.now() + 5.0) + await wait_for(lambda: len(sent) == 2, timeout=5.0) + assert failed_once.is_set() + assert sent == [b"\x02", b"\x01"] # Enqueue order preserved despite the retry of the first frame. 
+ finally: + itf.close() + + async def test_unit_enqueue_after_purge_still_works() -> None: """After purge, new enqueue'd frames are still sent.""" a, b = _virtual_pair() diff --git a/tests/can/test_socketcan_unit.py b/tests/can/test_socketcan_unit.py index dea489f1..cf4b5d35 100644 --- a/tests/can/test_socketcan_unit.py +++ b/tests/can/test_socketcan_unit.py @@ -39,6 +39,9 @@ def __init__(self) -> None: def cancel(self) -> None: self.cancelled = True + def add_done_callback(self, _cb: object) -> None: + pass + class _FakeLoop: def __init__(self, *, recv: list[object] | None = None, send: list[object] | None = None) -> None: @@ -162,7 +165,7 @@ def test_socketcan_init_fd_and_classic_paths(monkeypatch: pytest.MonkeyPatch) -> fake_socket, created = _make_socket_module() module = _load_socketcan_module(monkeypatch, socket_module=fake_socket) - monkeypatch.setattr(module.SocketCANInterface, "_read_iface_mtu", lambda self: module._CAN_FD_MTU) + monkeypatch.setattr(module.SocketCANInterface, "_read_iface_mtu", lambda self: module._FD_FRAME_SIZE) fd_iface = module.SocketCANInterface("vcan0") fd_sock = created[-1] assert fd_iface.name == "vcan0" @@ -170,7 +173,7 @@ def test_socketcan_init_fd_and_classic_paths(monkeypatch: pytest.MonkeyPatch) -> assert ("setsockopt", fake_socket.SOL_CAN_RAW, fake_socket.CAN_RAW_FD_FRAMES, 1) in fd_sock.calls assert "vcan0" in repr(fd_iface) - monkeypatch.setattr(module.SocketCANInterface, "_read_iface_mtu", lambda self: module._CAN_CLASSIC_MTU) + monkeypatch.setattr(module.SocketCANInterface, "_read_iface_mtu", lambda self: module._CLASSIC_FRAME_SIZE) classic_iface = module.SocketCANInterface("vcan1") classic_sock = created[-1] assert classic_iface.fd is False @@ -277,15 +280,15 @@ def test_encode_and_decode_branches(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(module.Instant, "now", staticmethod(lambda: Instant(ns=123))) classic = _make_iface(module, fd=False) - with pytest.raises(ClosedError, match="not CAN FD-capable"): 
+ with pytest.raises(ValueError, match="Classic CAN"): classic._encode(123, b"012345678") encoded_classic = classic._encode(123, b"abc") - assert len(encoded_classic) == module._CAN_CLASSIC_MTU + assert len(encoded_classic) == module._CLASSIC_FRAME_SIZE fd_iface = _make_iface(module, fd=True) encoded_fd = fd_iface._encode(456, b"012345678") - assert len(encoded_fd) == module._CAN_FD_MTU + assert len(encoded_fd) == module._FD_FRAME_SIZE assert module.SocketCANInterface._decode(b"\x00") is None @@ -481,3 +484,26 @@ async def wait_permanent(_coro: object, timeout: float) -> None: def _close_then_return(iface: object, item: tuple[int, int, int, bytes]) -> tuple[int, int, int, bytes]: iface._closed = True # type: ignore[attr-defined] return item + + +async def test_tx_loop_drops_oversized_classic_frame(monkeypatch: pytest.MonkeyPatch) -> None: + """An oversized frame on a Classic-CAN interface is dropped in-loop, not allowed to kill the TX task.""" + fake_socket, _ = _make_socket_module() + module = _load_socketcan_module(monkeypatch, socket_module=fake_socket) + + iface = _make_iface(module) # fd=False -> Classic CAN, so a >8-byte payload is unencodable. + iface._tx_queue = _QueueScript(iface, [(20, 1, 100, b"0" * 9), (21, 2, 100, b"abc")]) + loop = _FakeLoop() + monkeypatch.setattr(module.asyncio, "get_running_loop", lambda: loop) + monkeypatch.setattr(module.Instant, "now", staticmethod(lambda: Instant(ns=0))) + + async def wait_ok(coro: object, timeout: float) -> None: + del timeout + await cast(Awaitable[object], coro) + + monkeypatch.setattr(module.asyncio, "wait_for", wait_ok) + await iface._tx_loop() + + assert len(loop.sent_frames) == 1 # Oversized frame dropped; the following valid frame still sent. + assert iface._tx_queue.requeued == [] # The oversized frame was not re-queued. + assert iface._failure is None # And it was not treated as an interface failure. 
diff --git a/tests/can/test_wire_edges.py b/tests/can/test_wire_edges.py index d2dce532..d81e6ea7 100644 --- a/tests/can/test_wire_edges.py +++ b/tests/can/test_wire_edges.py @@ -3,8 +3,8 @@ import pytest from pycyphal2.can import Filter +from pycyphal2.can._interface import CAN_EXT_ID_MASK from pycyphal2.can._wire import ( - CAN_EXT_ID_MASK, CRC_INITIAL, CRC_RESIDUE, HEARTBEAT_SUBJECT_ID, diff --git a/tests/test_monitor.py b/tests/test_monitor.py index dc9b7849..d96cf34e 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -164,6 +164,33 @@ async def test_monitor_unknown_topic_preserves_decoded_wire_name(name_bytes: byt node.close() +async def test_gossip_wire_path_drops_malformed_name() -> None: + """End-to-end wire path: a gossip with a malformed name must not create a local topic, even with a + match-all pattern subscriber present (a valid name over the same path IS created -- the control).""" + node = new_node(MockTransport(node_id=1), home="n1") + node.subscribe(">") # Match-all, so only the name validation can prevent topic creation. + try: + bad = "foo bar" # A space is not a valid topic-name character. + _deliver_gossip( + node, + _make_gossip_arrival(topic_hash=rapidhash(bad), evictions=0, name_bytes=bad.encode()), + "broadcast", + topic_hash=rapidhash(bad), + ) + assert bad not in node.topics_by_name # Malformed name dropped on the wire path. + + good = "sensor/temp" + _deliver_gossip( + node, + _make_gossip_arrival(topic_hash=rapidhash(good), evictions=0, name_bytes=good.encode()), + "broadcast", + topic_hash=rapidhash(good), + ) + assert good in node.topics_by_name # Control: a valid name over the same path IS created. 
+ finally: + node.close() + + async def test_monitor_is_not_invoked_for_inline_gossip_on_message_reception() -> None: node = new_node(MockTransport(node_id=1), home="n1") pub = node.advertise("/topic") diff --git a/tests/test_parity.py b/tests/test_parity.py index b29eb243..be45d755 100644 --- a/tests/test_parity.py +++ b/tests/test_parity.py @@ -95,7 +95,6 @@ async def test_association_slack_nack_capped(): tracker = PublishTracker( tag=tag, - deadline_ns=(pycyphal2.Instant.now() + 10.0).ns, remaining={42}, ack_event=asyncio.Event(), ) @@ -136,7 +135,6 @@ async def test_association_ack_resets_slack(): tracker = PublishTracker( tag=tag, - deadline_ns=(pycyphal2.Instant.now() + 10.0).ns, remaining={42}, ack_event=asyncio.Event(), ) diff --git a/tests/test_parity_coverage.py b/tests/test_parity_coverage.py index 9d21fa14..cd77ecc7 100644 --- a/tests/test_parity_coverage.py +++ b/tests/test_parity_coverage.py @@ -215,7 +215,7 @@ async def test_prepare_publish_tracker_skips_saturated_associations_and_release_ topic.associations = {10: live, 11: saturated} tag = topic.next_tag() - tracker = node.prepare_publish_tracker(topic, tag, (pycyphal2.Instant.now() + 1.0).ns, b"data") + tracker = node.prepare_publish_tracker(topic, tag) assert tracker.remaining == {10} assert tracker.associations == [live] @@ -243,7 +243,7 @@ async def test_publish_tracker_release_compromised_does_not_penalize_association assoc = Association(remote_id=10, last_seen=0.0, slack=ASSOC_SLACK_LIMIT - 1) topic.associations = {10: assoc} tag = topic.next_tag() - tracker = node.prepare_publish_tracker(topic, tag, (pycyphal2.Instant.now() + 1.0).ns, b"data") + tracker = node.prepare_publish_tracker(topic, tag) tracker.compromised = True node.publish_tracker_release(topic, tracker) @@ -266,7 +266,7 @@ async def test_reliable_publish_scheduler_lag_does_not_penalize_association() -> topic.associations = {10: assoc} tag = topic.next_tag() deadline = pycyphal2.Instant(ns=1_000_000_000) - tracker = 
pub._prepare_reliable_publish_tracker(tag, deadline.ns, b"data") + tracker = pub._prepare_reliable_publish_tracker(tag) tracker.ack_timeout = 0.2 now_ns = 0 diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index def06745..bfd35999 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -36,6 +36,66 @@ async def test_basic_best_effort_pubsub(): node.close() +async def test_node_operations_after_close_raise(): + """Public node operations reject use after close() instead of mutating a dead node.""" + net = MockNetwork() + tr = MockTransport(node_id=1, network=net) + node = new_node(tr, home="test_node") + node.close() + + with pytest.raises(pycyphal2.ClosedError): + node.advertise("my/topic") + with pytest.raises(pycyphal2.ClosedError): + node.subscribe("my/topic") + with pytest.raises(pycyphal2.ClosedError): + node.monitor(lambda _t: None) + with pytest.raises(pycyphal2.ClosedError): + node.remap("a=b") + with pytest.raises(pycyphal2.ClosedError): + await node.scout("pattern/*") + + +async def test_node_close_unblocks_pending_subscriber(): + """Closing the node ends a pending `async for` on a subscriber instead of hanging it forever.""" + net = MockNetwork() + tr = MockTransport(node_id=1, network=net) + node = new_node(tr, home="test_node") + sub = node.subscribe("my/topic") + task = asyncio.create_task(sub.__anext__()) + await asyncio.sleep(0) # Let the task start awaiting on the queue. + + node.close() + + with pytest.raises(StopAsyncIteration): + await asyncio.wait_for(task, timeout=1.0) + + +async def test_gossip_rejects_malformed_names(): + """A gossiped name that is not a normalized verbatim topic name must not become a local topic. + + A match-all '>' pattern subscriber is registered so that, WITHOUT the name guard, every name below + would be created -- this makes the test fail on the unfixed code rather than passing vacuously (the + hash matches the name in each case, so only the name guard can reject it). 
+ """ + from pycyphal2._hash import rapidhash + + net = MockNetwork() + tr = MockTransport(node_id=1, network=net) + node = new_node(tr, home="test_node") + try: + node.subscribe(">") # Match-all pattern: the name guard is then the only thing that can reject. + # Sanity: a valid normalized name IS created through this subscriber, proving the path is live. + good = node.topic_subscribe_if_matching("sensor/temp", rapidhash("sensor/temp"), 0, 0, 0.0) + assert good is not None + + for bad_name in ["foo bar", "foo//bar", "/foo", "foo/", "foo/*", "foo/>", "foo#123", "~/foo", "~", ""]: + result = node.topic_subscribe_if_matching(bad_name, rapidhash(bad_name), 0, 0, 0.0) + assert result is None, bad_name + assert bad_name not in node.topics_by_name + finally: + node.close() + + async def test_publish_multiple_messages(): """Multiple messages should arrive in order.""" net = MockNetwork() diff --git a/tests/test_reliable.py b/tests/test_reliable.py index b0bb0ce5..87abae55 100644 --- a/tests/test_reliable.py +++ b/tests/test_reliable.py @@ -532,7 +532,6 @@ async def test_msg_ack_dispatch(): tag = topic.next_tag() tracker = PublishTracker( tag=tag, - deadline_ns=(pycyphal2.Instant.now() + 10.0).ns, remaining={42}, ack_event=asyncio.Event(), ) @@ -573,7 +572,6 @@ async def test_msg_nack_dispatch(): tag = topic.next_tag() tracker = PublishTracker( tag=tag, - deadline_ns=(pycyphal2.Instant.now() + 10.0).ns, remaining={42}, ack_event=asyncio.Event(), ) @@ -1000,7 +998,6 @@ async def test_multicast_msg_ack_ignored(): tag = topic.next_tag() tracker = PublishTracker( tag=tag, - deadline_ns=(pycyphal2.Instant.now() + 10.0).ns, remaining={42}, ack_event=asyncio.Event(), ) diff --git a/tests/test_udp.py b/tests/test_udp.py index 97c5e7b4..faefada3 100644 --- a/tests/test_udp.py +++ b/tests/test_udp.py @@ -475,13 +475,13 @@ def test_coverage_tracking(self): ), 0, ) - assert slot._accept_fragment(0, b"a" * 30, 0) + assert slot._accept_fragment(0, b"a" * 30) assert slot.covered_prefix == 
30 - assert slot._accept_fragment(50, b"b" * 30, 0) + assert slot._accept_fragment(50, b"b" * 30) assert slot.covered_prefix == 30 - assert slot._accept_fragment(30, b"c" * 20, 0) + assert slot._accept_fragment(30, b"c" * 20) assert slot.covered_prefix == 80 - assert slot._accept_fragment(80, b"d" * 20, 0) + assert slot._accept_fragment(80, b"d" * 20) assert slot.covered_prefix == 100 def test_contained_fragment_rejected(self): @@ -491,8 +491,8 @@ def test_contained_fragment_rejected(self): ), 0, ) - assert slot._accept_fragment(0, b"A" * 4, 0) - assert not slot._accept_fragment(1, b"B" * 2, 0) + assert slot._accept_fragment(0, b"A" * 4) + assert not slot._accept_fragment(1, b"B" * 2) assert [(frag.offset, frag.data) for frag in slot.fragments] == [(0, b"AAAA")] def test_bridge_fragment_evicts_victim(self): @@ -502,10 +502,10 @@ def test_bridge_fragment_evicts_victim(self): ), 0, ) - assert slot._accept_fragment(0, b"AAAA", 0) - assert slot._accept_fragment(4, b"BB", 0) - assert slot._accept_fragment(6, b"CCCC", 0) - assert slot._accept_fragment(2, b"XXXXXX", 0) + assert slot._accept_fragment(0, b"AAAA") + assert slot._accept_fragment(4, b"BB") + assert slot._accept_fragment(6, b"CCCC") + assert slot._accept_fragment(2, b"XXXXXX") assert [(frag.offset, frag.data) for frag in slot.fragments] == [(0, b"AAAA"), (2, b"XXXXXX"), (6, b"CCCC")] def test_furthest_reaching_crc_is_used(self): @@ -1187,3 +1187,124 @@ async def mock_sock_sendto(s, data, addr): await t.async_sendto(sock, b"fail", ("127.0.0.1", 9999), deadline) finally: t.close() + + +@pytest.mark.asyncio +async def test_subject_send_succeeds_when_one_redundant_interface_fails() -> None: + """Redundant transport: a transfer delivered on >=1 interface is a success, not a raise. + + Regression: a partial-interface failure used to raise an ExceptionGroup, making the caller treat a + delivered transfer as failed and retry it, duplicating it on the interfaces that already succeeded. 
+ """ + iface = Interface(address=IPv4Address("127.0.0.1"), mtu_link=1500) + pub = UDPTransport.new(interfaces=[iface, iface]) + assert isinstance(pub, _UDPTransportImpl) + try: + real_sendto = pub.async_sendto + + async def flaky_sendto(sock, data, addr, deadline): # type: ignore[no-untyped-def] + if sock is pub.tx_socks[0]: + raise OSError("interface 0 is down") + await real_sendto(sock, data, addr, deadline) + + with patch.object(pub, "async_sendto", flaky_sendto): + writer = pub.subject_advertise(10) + await writer(Instant.now() + 2.0, Priority.NOMINAL, b"redundant") # Must not raise. + + async def all_fail(sock, data, addr, deadline): # type: ignore[no-untyped-def] + raise OSError("all interfaces down") + + with patch.object(pub, "async_sendto", all_fail): + writer_all = pub.subject_advertise(11) + with pytest.raises(SendError): + await writer_all(Instant.now() + 2.0, Priority.NOMINAL, b"nope") + finally: + pub.close() + + +def test_interface_rejects_subminimum_mtu() -> None: + """A link MTU below the Cyphal minimum is rejected at construction, not via a strippable assert.""" + with pytest.raises(ValueError, match="mtu_link must be"): + Interface(address=IPv4Address("127.0.0.1"), mtu_link=100) + + +@pytest.mark.asyncio +async def test_raising_subject_handler_does_not_kill_rx_loop() -> None: + """A subject handler that raises must not tear down the receive loop; later transfers still arrive.""" + pub = UDPTransport.new_loopback() + sub = UDPTransport.new_loopback() + try: + received: list[bytes] = [] + + def handler(arrival: TransportArrival) -> None: + received.append(arrival.message) + raise RuntimeError("handler boom") # Raised on every delivery. 
+ + sub.subject_listen(10, handler) + writer = pub.subject_advertise(10) + await writer(Instant.now() + 2.0, Priority.NOMINAL, b"first") + await writer(Instant.now() + 2.0, Priority.NOMINAL, b"second") + + for _ in range(200): + if len(received) >= 2: + break + await asyncio.sleep(0.01) + assert received == [b"first", b"second"] # Second arrived => loop survived the first raise. + finally: + pub.close() + sub.close() + + +@pytest.mark.asyncio +async def test_unicast_succeeds_when_one_redundant_interface_fails() -> None: + """Unicast, like subject send, must succeed if delivered on >=1 interface (used by reliable replies).""" + iface = Interface(address=IPv4Address("127.0.0.1"), mtu_link=1500) + t = UDPTransport.new(interfaces=[iface, iface]) + assert isinstance(t, _UDPTransportImpl) + try: + remote = 5 + # Pretend both interfaces have learned an endpoint for the remote node. + t._remote_endpoints[(remote, 0)] = ("127.0.0.1", t.tx_socks[0].getsockname()[1]) + t._remote_endpoints[(remote, 1)] = ("127.0.0.1", t.tx_socks[1].getsockname()[1]) + real_sendto = t.async_sendto + + async def flaky_sendto(sock, data, addr, deadline): # type: ignore[no-untyped-def] + if sock is t.tx_socks[0]: + raise OSError("interface 0 is down") + await real_sendto(sock, data, addr, deadline) + + with patch.object(t, "async_sendto", flaky_sendto): + await t.unicast(Instant.now() + 2.0, Priority.NOMINAL, remote, b"redundant") # Must not raise. 
+ + async def all_fail(sock, data, addr, deadline): # type: ignore[no-untyped-def] + raise OSError("all interfaces down") + + with patch.object(t, "async_sendto", all_fail): + with pytest.raises(SendError): + await t.unicast(Instant.now() + 2.0, Priority.NOMINAL, remote, b"nope") + finally: + t.close() + + +@pytest.mark.asyncio +async def test_raising_unicast_handler_does_not_kill_rx() -> None: + """A raising unicast handler must not tear down the receive path; later transfers still arrive.""" + t = UDPTransport.new_loopback() + assert isinstance(t, _UDPTransportImpl) + try: + received: list[bytes] = [] + + def handler(arrival: TransportArrival) -> None: + received.append(arrival.message) + raise RuntimeError("handler boom") + + t.unicast_listen(handler) + uid = 0x1234 + first = _segment_transfer(Priority.NOMINAL, 0, uid, b"first", 1400)[0] + second = _segment_transfer(Priority.NOMINAL, 1, uid, b"second", 1400)[0] + # Feed two single-frame transfers through the same code path the RX loop uses. + t._process_unicast_datagram(first, "127.0.0.1", 40000, 0, Instant.now()) + t._process_unicast_datagram(second, "127.0.0.1", 40000, 0, Instant.now()) + assert received == [b"first", b"second"] # Second processed => the first raise was contained. + finally: + t.close()