Skip to content

Commit aa3a16d

Browse files
fix: improve Socket Mode client stability and correctness (#1854)
1 parent 067560a commit aa3a16d

File tree

7 files changed

+103
-8
lines changed

7 files changed

+103
-8
lines changed

slack_sdk/socket_mode/builtin/internals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _parse_handshake_response(sock: ssl.SSLSocket) -> Tuple[Optional[int], dict,
145145
if len(elements) > 2:
146146
status = int(elements[1])
147147
else:
148-
elements = line.split(":")
148+
elements = line.split(":", 1)
149149
if len(elements) == 2:
150150
headers[elements[0].strip().lower()] = elements[1].strip()
151151
if line is None or len(line.strip()) == 0:
@@ -337,7 +337,7 @@ def _fetch_messages(
337337
)
338338
else:
339339
# This pattern is unexpected but set data with the expected length anyway
340-
_append_message(current_header, current_data[:current_data_length]) # type: ignore[call-arg, arg-type]
340+
_append_message(messages, current_header, current_data[:current_data_length])
341341
return messages
342342

343343
# work in progress with the current_header/current_data

slack_sdk/socket_mode/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,17 @@ def disconnect(self) -> None:
7070
raise NotImplementedError()
7171

7272
def connect_to_new_endpoint(self, force: bool = False):
73+
acquired = False
7374
try:
74-
self.connect_operation_lock.acquire(blocking=True, timeout=5)
75-
if force or not self.is_connected():
75+
acquired = self.connect_operation_lock.acquire(blocking=True, timeout=5)
76+
if force or (acquired and not self.is_connected()):
7677
self.logger.info("Connecting to a new endpoint...")
7778
self.wss_uri = self.issue_new_wss_url()
7879
self.connect()
7980
self.logger.info("Connected to a new endpoint...")
8081
finally:
81-
self.connect_operation_lock.release()
82+
if acquired:
83+
self.connect_operation_lock.release()
8284

8385
def close(self) -> None:
8486
self.closed = True

slack_sdk/socket_mode/request.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def from_dict(cls, message: dict) -> Optional["SocketModeRequest"]:
5151
return None
5252

5353
def to_dict(self) -> dict:
54-
d = {"envelope_id": self.envelope_id}
54+
d = {"type": self.type, "envelope_id": self.envelope_id}
5555
if self.payload is not None:
5656
d["payload"] = self.payload # type: ignore[assignment]
5757
return d

slack_sdk/socket_mode/websocket_client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""websocket-client bassd Socket Mode client
1+
"""websocket-client based Socket Mode client
22
33
* https://docs.slack.dev/apis/events-api/using-socket-mode/
44
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/
@@ -229,6 +229,7 @@ def close(self) -> None:
229229
self.closed = True
230230
self.auto_reconnect_enabled = False
231231
self.disconnect()
232+
self.current_session_runner.shutdown()
232233
self.current_app_monitor.shutdown()
233234
self.message_processor.shutdown()
234235
self.message_workers.shutdown()

slack_sdk/socket_mode/websockets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""websockets bassd Socket Mode client
1+
"""websockets based Socket Mode client
22
33
* https://docs.slack.dev/apis/events-api/using-socket-mode/
44
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/

tests/slack_sdk/socket_mode/test_request.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,28 @@ def test(self):
1818
req = SocketModeRequest.from_dict(body)
1919
self.assertIsNotNone(req)
2020
self.assertEqual(req.envelope_id, "1d3c79ab-0ffb-41f3-a080-d19e85f53649")
21+
22+
def test_to_dict(self):
23+
req = SocketModeRequest(
24+
type="slash_commands",
25+
envelope_id="abc-123",
26+
payload={"text": "hello"},
27+
)
28+
self.assertDictEqual(
29+
req.to_dict(), {"type": "slash_commands", "envelope_id": "abc-123", "payload": {"text": "hello"}}
30+
)
31+
32+
def test_to_dict_from_dict_round_trip(self):
33+
expected = SocketModeRequest(
34+
type="slash_commands",
35+
envelope_id="1d3c79ab-0ffb-41f3-a080-d19e85f53649",
36+
payload={"token": "xxx", "team_id": "T111", "command": "/hello"},
37+
accepts_response_payload=True,
38+
retry_attempt=2,
39+
retry_reason="timeout",
40+
)
41+
actual = SocketModeRequest.from_dict(expected.to_dict())
42+
self.assertIsNotNone(actual)
43+
self.assertEqual(actual.type, expected.type)
44+
self.assertEqual(actual.envelope_id, expected.envelope_id)
45+
self.assertEqual(actual.payload, expected.payload)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import logging
2+
import ssl
3+
import unittest
4+
from threading import Lock
5+
from unittest.mock import MagicMock, patch
6+
7+
from slack_sdk.socket_mode.builtin.internals import _parse_handshake_response
8+
from slack_sdk.socket_mode.client import BaseSocketModeClient
9+
10+
11+
class TestSocketModeClient(unittest.TestCase):
12+
logger = logging.getLogger(__name__)
13+
14+
def test_connect_to_new_endpoint_does_not_release_lock_on_acquisition_timeout(self):
15+
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
16+
client.logger = self.logger
17+
lock_mock = MagicMock(spec=Lock())
18+
lock_mock.acquire.return_value = False
19+
client.connect_operation_lock = lock_mock
20+
21+
client.connect_to_new_endpoint()
22+
23+
client.connect_operation_lock.release.assert_not_called()
24+
25+
def test_connect_to_new_endpoint_releases_lock_on_successful_acquisition(self):
26+
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
27+
client.logger = self.logger
28+
client.connect_operation_lock = Lock()
29+
30+
with patch.object(client, client.is_connected.__name__, return_value=True):
31+
client.connect_to_new_endpoint()
32+
33+
acquired = client.connect_operation_lock.acquire(blocking=False)
34+
self.assertTrue(acquired)
35+
client.connect_operation_lock.release()
36+
37+
def test_parse_handshake_response_preserves_colons_in_header_values(self):
38+
lines = [
39+
"HTTP/1.1 101 Switching Protocols",
40+
"Upgrade: websocket",
41+
"Location: https://example.com:8080/path",
42+
"",
43+
]
44+
with patch(
45+
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
46+
side_effect=lines,
47+
):
48+
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))
49+
50+
self.assertEqual(status, 101)
51+
self.assertEqual(headers["upgrade"], "websocket")
52+
self.assertEqual(headers["location"], "https://example.com:8080/path")
53+
54+
def test_parse_handshake_response_parses_standard_headers(self):
55+
lines = [
56+
"HTTP/1.1 200 OK",
57+
"Content-Type: text/html",
58+
"",
59+
]
60+
with patch(
61+
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
62+
side_effect=lines,
63+
):
64+
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))
65+
66+
self.assertEqual(status, 200)
67+
self.assertEqual(headers["content-type"], "text/html")

0 commit comments

Comments
 (0)