diff --git a/canopen/emcy.py b/canopen/emcy.py index 22d1eba8..8d075008 100644 --- a/canopen/emcy.py +++ b/canopen/emcy.py @@ -39,7 +39,10 @@ def on_emcy(self, can_id, data, timestamp): self.emcy_received.notify_all() for callback in self.callbacks: - callback(entry) + try: + callback(entry) + except Exception: + logger.exception("Exception in EMCY callback") def add_callback(self, callback: Callable[[EmcyError], None]): """Get notified on EMCY messages from this node. diff --git a/test/test_emcy.py b/test/test_emcy.py index d883e9c8..b4b54a19 100644 --- a/test/test_emcy.py +++ b/test/test_emcy.py @@ -6,18 +6,26 @@ import can import canopen -from canopen.emcy import EmcyError TIMEOUT = 0.1 +@contextmanager +def mock_rx_thread(consumer: canopen.emcy.EmcyConsumer, func): + t = threading.Thread(target=func) + try: + with consumer.emcy_received: + t.start() + yield t + finally: + t.join(TIMEOUT) + + class TestEmcy(unittest.TestCase): - def setUp(self): - self.emcy = canopen.emcy.EmcyConsumer() def check_error(self, err, code, reg, data, ts): - self.assertIsInstance(err, EmcyError) + self.assertIsInstance(err, canopen.emcy.EmcyError) self.assertIsInstance(err, Exception) self.assertEqual(err.code, code) self.assertEqual(err.register, reg) @@ -25,59 +33,58 @@ def check_error(self, err, code, reg, data, ts): self.assertAlmostEqual(err.timestamp, ts) def test_emcy_consumer_on_emcy(self): - # Make sure multiple callbacks receive the same information. + """Make sure multiple callbacks receive the same information.""" + emcy = canopen.emcy.EmcyConsumer() acc1 = [] acc2 = [] - self.emcy.add_callback(lambda err: acc1.append(err)) - self.emcy.add_callback(lambda err: acc2.append(err)) + emcy.add_callback(lambda err: acc1.append(err)) + emcy.add_callback(lambda err: acc2.append(err)) - # Dispatch an EMCY datagram. - self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) - self.assertEqual(len(self.emcy.log), 1) - self.assertEqual(len(self.emcy.active), 1) + self.assertEqual(len(emcy.log), 1) + self.assertEqual(len(emcy.active), 1) - error = self.emcy.log[0] - self.assertEqual(self.emcy.active[0], error) + error = emcy.log[0] + self.assertEqual(emcy.active[0], error) for err in error, acc1[0], acc2[0]: self.check_error( error, code=0x2001, reg=0x02, data=bytes([0, 1, 2, 3, 4]), ts=1000, ) - # Dispatch a new EMCY datagram. - self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) - self.assertEqual(len(self.emcy.log), 2) - self.assertEqual(len(self.emcy.active), 2) + emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + self.assertEqual(len(emcy.log), 2) + self.assertEqual(len(emcy.active), 2) - error = self.emcy.log[1] - self.assertEqual(self.emcy.active[1], error) + error = emcy.log[1] + self.assertEqual(emcy.active[1], error) for err in error, acc1[1], acc2[1]: self.check_error( error, code=0x9010, reg=0x01, data=bytes([4, 3, 2, 1, 0]), ts=2000, ) - # Dispatch an EMCY reset. - self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) - self.assertEqual(len(self.emcy.log), 3) - self.assertEqual(len(self.emcy.active), 0) + emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 2000) + self.assertEqual(len(emcy.log), 3) + self.assertEqual(len(emcy.active), 0) def test_emcy_consumer_reset(self): - self.emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) - self.emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) - self.assertEqual(len(self.emcy.log), 2) - self.assertEqual(len(self.emcy.active), 2) + emcy = canopen.emcy.EmcyConsumer() + emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + self.assertEqual(len(emcy.log), 2) + self.assertEqual(len(emcy.active), 2) - self.emcy.reset() - self.assertEqual(len(self.emcy.log), 0) - self.assertEqual(len(self.emcy.active), 0) + emcy.reset() + self.assertEqual(len(emcy.log), 0) + self.assertEqual(len(emcy.active), 0) def test_emcy_consumer_wait(self): - PAUSE = TIMEOUT / 2 + emcy = canopen.emcy.EmcyConsumer() def push_err(): - self.emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) + emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) def check_err(err): self.assertIsNotNone(err) @@ -86,47 +93,98 @@ def check_err(err): data=bytes([1, 2, 3, 4, 5]), ts=100, ) - @contextmanager - def timer(func): - t = threading.Timer(PAUSE, func) - try: - yield t - finally: - t.join(TIMEOUT) - # Check unfiltered wait, on timeout. - self.assertIsNone(self.emcy.wait(timeout=TIMEOUT)) + self.assertIsNone(emcy.wait(timeout=TIMEOUT)) # Check unfiltered wait, on success. - with timer(push_err) as t: - with self.assertLogs(level=logging.INFO): - t.start() - err = self.emcy.wait(timeout=TIMEOUT) - check_err(err) + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread(emcy, push_err), + ): + check_err(emcy.wait(timeout=TIMEOUT)) # Check filtered wait, on success. - with timer(push_err) as t: - with self.assertLogs(level=logging.INFO): - t.start() - err = self.emcy.wait(0x2001, TIMEOUT) - check_err(err) + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread(emcy, push_err), + ): + check_err(emcy.wait(0x2001, TIMEOUT)) # Check filtered wait, on timeout. - with timer(push_err) as t: - t.start() - self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) + with mock_rx_thread(emcy, push_err): + self.assertIsNone(emcy.wait(0x9000, TIMEOUT)) def push_reset(): - self.emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 100) - - with timer(push_reset) as t: - t.start() - self.assertIsNone(self.emcy.wait(0x9000, TIMEOUT)) + emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 100) + + with mock_rx_thread(emcy, push_reset): + self.assertIsNone(emcy.wait(0x9000, TIMEOUT)) + + def test_emcy_consumer_multiple_callbacks(self): + """Test adding multiple callbacks and their execution order.""" + emcy = canopen.emcy.EmcyConsumer() + call_order = [] + emcy.add_callback(lambda err: call_order.append('callback1')) + emcy.add_callback(lambda err: call_order.append('callback2')) + emcy.add_callback(lambda err: call_order.append('callback3')) + emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.assertEqual(call_order, ['callback1', 'callback2', 'callback3']) + + def test_emcy_consumer_callback_exception_handling(self): + """Test that callback exceptions don't break other callbacks or the system.""" + emcy = canopen.emcy.EmcyConsumer() + successful_callbacks = [] + emcy.add_callback(lambda err: successful_callbacks.append('success1')) + emcy.add_callback( + lambda err: exec('raise ValueError("Test exception in callback")') + ) + emcy.add_callback(lambda err: successful_callbacks.append('success2')) + emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + self.assertEqual(successful_callbacks, ['success1', 'success2']) + + def test_emcy_consumer_error_reset_variants(self): + """Test different error reset code patterns.""" + emcy = canopen.emcy.EmcyConsumer() + emcy.on_emcy(0x81, b'\x01\x20\x02\x00\x01\x02\x03\x04', 1000) + emcy.on_emcy(0x81, b'\x10\x90\x01\x04\x03\x02\x01\x00', 2000) + self.assertEqual(len(emcy.active), 2) + emcy.on_emcy(0x81, b'\x00\x00\x00\x00\x00\x00\x00\x00', 3000) + self.assertEqual(len(emcy.active), 0) + emcy.on_emcy(0x81, b'\x01\x30\x02\x00\x01\x02\x03\x04', 4000) + self.assertEqual(len(emcy.active), 1) + emcy.on_emcy(0x81, b'\x99\x00\x01\x00\x00\x00\x00\x00', 5000) + self.assertEqual(len(emcy.active), 0) + + def test_emcy_consumer_wait_timeout_edge_cases(self): + """Test wait method with various timeout scenarios.""" + emcy = canopen.emcy.EmcyConsumer() + result = emcy.wait(timeout=0) + self.assertIsNone(result) + result = emcy.wait(timeout=0.001) + self.assertIsNone(result) + + def test_emcy_consumer_wait_concurrent_errors(self): + """Test wait method when multiple errors arrive concurrently.""" + emcy = canopen.emcy.EmcyConsumer() + + def push_multiple_errors(): + emcy.on_emcy(0x81, b'\x01\x20\x01\x01\x02\x03\x04\x05', 100) + emcy.on_emcy(0x81, b'\x02\x20\x01\x01\x02\x03\x04\x05', 101) + emcy.on_emcy(0x81, b'\x03\x20\x01\x01\x02\x03\x04\x05', 102) + + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread(emcy, push_multiple_errors), + ): + err = emcy.wait(0x2003, timeout=TIMEOUT) + self.assertIsNotNone(err) + self.assertEqual(err.code, 0x2003) class TestEmcyError(unittest.TestCase): + def test_emcy_error(self): - error = EmcyError(0x2001, 0x02, b'\x00\x01\x02\x03\x04', 1000) + error = canopen.emcy.EmcyError(0x2001, 0x02, b'\x00\x01\x02\x03\x04', 1000) self.assertEqual(error.code, 0x2001) self.assertEqual(error.data, b'\x00\x01\x02\x03\x04') self.assertEqual(error.register, 2) @@ -134,7 +192,7 @@ def test_emcy_error(self): def test_emcy_str(self): def check(code, expected): - err = EmcyError(code, 1, b'', 1000) + err = canopen.emcy.EmcyError(code, 1, b'', 1000) actual = str(err) self.assertEqual(actual, expected) @@ -145,7 +203,7 @@ def check(code, expected): def test_emcy_get_desc(self): def check(code, expected): - err = EmcyError(code, 1, b'', 1000) + err = canopen.emcy.EmcyError(code, 1, b'', 1000) actual = err.get_desc() self.assertEqual(actual, expected) @@ -182,6 +240,7 @@ def check(code, expected): class TestEmcyProducer(unittest.TestCase): + def setUp(self): self.txbus = can.Bus(interface="virtual") self.rxbus = can.Bus(interface="virtual") @@ -198,7 +257,7 @@ def tearDown(self): def check_response(self, expected): msg = self.rxbus.recv(TIMEOUT) - self.assertIsNotNone(msg) + assert msg is not None actual = msg.data self.assertEqual(actual, expected) @@ -220,6 +279,85 @@ def check(*args, res): check(3, res=b'\x00\x00\x03\x00\x00\x00\x00\x00') check(3, b"\xaa\xbb", res=b'\x00\x00\x03\xaa\xbb\x00\x00\x00') + def test_emcy_producer_send_edge_cases(self): + self.emcy.send(0xFFFF, 0xFF, b'\xFF\xFF\xFF\xFF\xFF') + self.check_response(b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') + self.emcy.send(0x0000, 0x00) + self.check_response(b'\x00\x00\x00\x00\x00\x00\x00\x00') + self.emcy.send(0x1234, 0x56, b'\xAB\xCD') + self.check_response(b'\x34\x12\x56\xAB\xCD\x00\x00\x00') + self.emcy.send(0x1234, 0x56, b'\xAB\xCD\xEF\x12\x34') + self.check_response(b'\x34\x12\x56\xAB\xCD\xEF\x12\x34') + + def test_emcy_producer_reset_edge_cases(self): + self.emcy.reset(0xFF) + self.check_response(b'\x00\x00\xFF\x00\x00\x00\x00\x00') + self.emcy.reset(0xFF, b'\xFF\xFF\xFF\xFF\xFF') + self.check_response(b'\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF') + self.emcy.reset(0x12, b'\xAB\xCD') + self.check_response(b'\x00\x00\x12\xAB\xCD\x00\x00\x00') + + +class TestEmcyIntegration(unittest.TestCase): + """Integration tests for EMCY producer and consumer.""" + + def setUp(self): + self.txbus = can.Bus(interface="virtual") + self.rxbus = can.Bus(interface="virtual") + self.net = canopen.Network(self.txbus) + self.net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.net.connect() + self.rx_net = canopen.Network(self.rxbus) + self.rx_net.NOTIFIER_SHUTDOWN_TIMEOUT = 0.0 + self.rx_net.connect() + self.producer = canopen.emcy.EmcyProducer(0x081) + self.producer.network = self.net + self.consumer = canopen.emcy.EmcyConsumer() + self.rx_net.subscribe(0x081, self.consumer.on_emcy) + + def tearDown(self): + self.net.disconnect() + self.rx_net.disconnect() + self.txbus.shutdown() + self.rxbus.shutdown() + + def test_producer_consumer_integration(self): + """Test that producer and consumer work together.""" + received_errors = [] + self.consumer.add_callback(lambda err: received_errors.append(err)) + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread( + self.consumer, + lambda: self.producer.send(0x2001, 0x02, b'\x01\x02\x03\x04\x05'), + ), + ): + err = self.consumer.wait(0x2001, timeout=TIMEOUT) + self.assertIsNotNone(err) + self.assertEqual(err.code, 0x2001) + self.assertEqual(err.register, 0x02) + self.assertEqual(err.data, b'\x01\x02\x03\x04\x05') + self.assertEqual(received_errors, [err]) + + def test_producer_reset_consumer_integration(self): + """Test producer reset clears consumer active errors.""" + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread( + self.consumer, + lambda: self.producer.send(0x2001, 0x02, b'\x01\x02\x03\x04\x05'), + ), + ): + self.consumer.wait(0x2001, timeout=TIMEOUT) + self.assertEqual(len(self.consumer.active), 1) + with ( + self.assertLogs(level=logging.INFO), + mock_rx_thread(self.consumer, self.producer.reset), + ): + self.assertIsNotNone(self.consumer.wait(timeout=TIMEOUT)) + self.assertEqual(len(self.consumer.active), 0) + self.assertEqual(len(self.consumer.log), 2) + if __name__ == "__main__": unittest.main()