Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions slack_sdk/socket_mode/builtin/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _parse_handshake_response(sock: ssl.SSLSocket) -> Tuple[Optional[int], dict,
if len(elements) > 2:
status = int(elements[1])
else:
elements = line.split(":")
elements = line.split(":", 1)
if len(elements) == 2:
headers[elements[0].strip().lower()] = elements[1].strip()
if line is None or len(line.strip()) == 0:
Expand Down Expand Up @@ -337,7 +337,7 @@ def _fetch_messages(
)
else:
# This pattern is unexpected but set data with the expected length anyway
_append_message(current_header, current_data[:current_data_length]) # type: ignore[call-arg, arg-type]
_append_message(messages, current_header, current_data[:current_data_length])
return messages

# work in progress with the current_header/current_data
Expand Down
8 changes: 5 additions & 3 deletions slack_sdk/socket_mode/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,17 @@ def disconnect(self) -> None:
raise NotImplementedError()

def connect_to_new_endpoint(self, force: bool = False):
acquired = False
try:
self.connect_operation_lock.acquire(blocking=True, timeout=5)
if force or not self.is_connected():
acquired = self.connect_operation_lock.acquire(blocking=True, timeout=5)
if force or (acquired and not self.is_connected()):
self.logger.info("Connecting to a new endpoint...")
self.wss_uri = self.issue_new_wss_url()
self.connect()
self.logger.info("Connected to a new endpoint...")
finally:
self.connect_operation_lock.release()
if acquired:
self.connect_operation_lock.release()

def close(self) -> None:
self.closed = True
Expand Down
2 changes: 1 addition & 1 deletion slack_sdk/socket_mode/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def from_dict(cls, message: dict) -> Optional["SocketModeRequest"]:
return None

def to_dict(self) -> dict:
d = {"envelope_id": self.envelope_id}
d = {"type": self.type, "envelope_id": self.envelope_id}
if self.payload is not None:
d["payload"] = self.payload # type: ignore[assignment]
return d
3 changes: 2 additions & 1 deletion slack_sdk/socket_mode/websocket_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""websocket-client bassd Socket Mode client
"""websocket-client based Socket Mode client

* https://docs.slack.dev/apis/events-api/using-socket-mode/
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/
Expand Down Expand Up @@ -229,6 +229,7 @@ def close(self) -> None:
self.closed = True
self.auto_reconnect_enabled = False
self.disconnect()
self.current_session_runner.shutdown()
self.current_app_monitor.shutdown()
self.message_processor.shutdown()
self.message_workers.shutdown()
Expand Down
2 changes: 1 addition & 1 deletion slack_sdk/socket_mode/websockets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""websockets bassd Socket Mode client
"""websockets based Socket Mode client
* https://docs.slack.dev/apis/events-api/using-socket-mode/
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/
Expand Down
25 changes: 25 additions & 0 deletions tests/slack_sdk/socket_mode/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,28 @@ def test(self):
req = SocketModeRequest.from_dict(body)
self.assertIsNotNone(req)
self.assertEqual(req.envelope_id, "1d3c79ab-0ffb-41f3-a080-d19e85f53649")

def test_to_dict(self):
req = SocketModeRequest(
type="slash_commands",
envelope_id="abc-123",
payload={"text": "hello"},
)
self.assertDictEqual(
req.to_dict(), {"type": "slash_commands", "envelope_id": "abc-123", "payload": {"text": "hello"}}
)

def test_to_dict_from_dict_round_trip(self):
expected = SocketModeRequest(
type="slash_commands",
envelope_id="1d3c79ab-0ffb-41f3-a080-d19e85f53649",
payload={"token": "xxx", "team_id": "T111", "command": "/hello"},
accepts_response_payload=True,
retry_attempt=2,
retry_reason="timeout",
)
actual = SocketModeRequest.from_dict(expected.to_dict())
self.assertIsNotNone(actual)
self.assertEqual(actual.type, expected.type)
self.assertEqual(actual.envelope_id, expected.envelope_id)
self.assertEqual(actual.payload, expected.payload)
68 changes: 68 additions & 0 deletions tests/slack_sdk/socket_mode/test_socket_mode_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
import ssl
import unittest
from threading import Lock
from unittest.mock import patch, MagicMock, create_autospec

from slack_sdk.socket_mode.client import BaseSocketModeClient
from slack_sdk.socket_mode.builtin.internals import (
_parse_handshake_response,
_fetch_messages,
)


class TestSocketModeClient(unittest.TestCase):
logger = logging.getLogger(__name__)

def test_connect_to_new_endpoint_does_not_release_lock_on_acquisition_timeout(self):
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
client.logger = self.logger
client.connect_operation_lock = create_autospec(Lock(), acquire=MagicMock(return_value=False))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this is where unit tests are failing. this is the fix according to claude

Suggested change
client.connect_operation_lock = create_autospec(Lock(), acquire=MagicMock(return_value=False))
mock_lock = create_autospec(Lock())
mock_lock.acquire.return_value = False
client.connect_operation_lock = mock_lock

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats spot on 💯 thanks for sharing the solution 🥇


client.connect_to_new_endpoint()

client.connect_operation_lock.release.assert_not_called()

def test_connect_to_new_endpoint_releases_lock_on_successful_acquisition(self):
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
client.logger = self.logger
client.connect_operation_lock = Lock()

with patch.object(client, client.is_connected.__name__, return_value=True):
client.connect_to_new_endpoint()

acquired = client.connect_operation_lock.acquire(blocking=False)
self.assertTrue(acquired)
client.connect_operation_lock.release()

def test_parse_handshake_response_preserves_colons_in_header_values(self):
lines = [
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Location: https://example.com:8080/path",
"",
]
with patch(
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
side_effect=lines,
):
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))

self.assertEqual(status, 101)
self.assertEqual(headers["upgrade"], "websocket")
self.assertEqual(headers["location"], "https://example.com:8080/path")

def test_parse_handshake_response_parses_standard_headers(self):
lines = [
"HTTP/1.1 200 OK",
"Content-Type: text/html",
"",
]
with patch(
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
side_effect=lines,
):
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))

self.assertEqual(status, 200)
self.assertEqual(headers["content-type"], "text/html")
Loading