|
45 | 45 | ) |
46 | 46 | from .generated.rpc import ModelCapabilitiesOverride as _RpcModelCapabilitiesOverride |
47 | 47 | from .generated.session_events import ( |
| 48 | + AssistantMessageData, |
48 | 49 | CapabilitiesChangedData, |
49 | 50 | CommandExecuteData, |
50 | 51 | ElicitationRequestedData, |
51 | 52 | ExternalToolRequestedData, |
52 | 53 | PermissionRequest, |
53 | 54 | PermissionRequestedData, |
54 | 55 | SessionEvent, |
| 56 | + SessionErrorData, |
55 | 57 | SessionEventType, |
| 58 | + SessionIdleData, |
56 | 59 | session_event_from_dict, |
57 | 60 | ) |
58 | 61 | from .tools import Tool, ToolHandler, ToolInvocation, ToolResult |
@@ -1134,24 +1137,25 @@ async def send_and_wait( |
1134 | 1137 | Example: |
1135 | 1138 | >>> from copilot.generated.session_events import AssistantMessageData |
1136 | 1139 | >>> response = await session.send_and_wait("What is 2+2?") |
1137 | | - >>> if response and isinstance(response.data, AssistantMessageData): |
1138 | | - ... print(response.data.content) |
| 1140 | + >>> if response: |
| 1141 | + ... match response.data: |
| 1142 | + ... case AssistantMessageData() as data: |
| 1143 | + ... print(data.content) |
1139 | 1144 | """ |
1140 | 1145 | idle_event = asyncio.Event() |
1141 | 1146 | error_event: Exception | None = None |
1142 | 1147 | last_assistant_message: SessionEvent | None = None |
1143 | 1148 |
|
1144 | 1149 | def handler(event: SessionEventTypeAlias) -> None: |
1145 | 1150 | nonlocal last_assistant_message, error_event |
1146 | | - if event.type == SessionEventType.ASSISTANT_MESSAGE: |
1147 | | - last_assistant_message = event |
1148 | | - elif event.type == SessionEventType.SESSION_IDLE: |
1149 | | - idle_event.set() |
1150 | | - elif event.type == SessionEventType.SESSION_ERROR: |
1151 | | - error_event = Exception( |
1152 | | - f"Session error: {getattr(event.data, 'message', str(event.data))}" |
1153 | | - ) |
1154 | | - idle_event.set() |
| 1151 | + match event.data: |
| 1152 | + case AssistantMessageData(): |
| 1153 | + last_assistant_message = event |
| 1154 | + case SessionIdleData(): |
| 1155 | + idle_event.set() |
| 1156 | + case SessionErrorData() as data: |
| 1157 | + error_event = Exception(f"Session error: {data.message or str(data)}") |
| 1158 | + idle_event.set() |
1155 | 1159 |
|
1156 | 1160 | unsubscribe = self.on(handler) |
1157 | 1161 | try: |
@@ -1183,10 +1187,11 @@ def on(self, handler: Callable[[SessionEvent], None]) -> Callable[[], None]: |
1183 | 1187 | Example: |
1184 | 1188 | >>> from copilot.generated.session_events import AssistantMessageData, SessionErrorData |
1185 | 1189 | >>> def handle_event(event): |
1186 | | - ... if isinstance(event.data, AssistantMessageData): |
1187 | | - ... print(f"Assistant: {event.data.content}") |
1188 | | - ... elif isinstance(event.data, SessionErrorData): |
1189 | | - ... print(f"Error: {event.data.message}") |
| 1190 | + ... match event.data: |
| 1191 | + ... case AssistantMessageData() as data: |
| 1192 | + ... print(f"Assistant: {data.content}") |
| 1193 | + ... case SessionErrorData() as data: |
| 1194 | + ... print(f"Error: {data.message}") |
1190 | 1195 | >>> unsubscribe = session.on(handle_event) |
1191 | 1196 | >>> # Later, to stop receiving events: |
1192 | 1197 | >>> unsubscribe() |
@@ -1232,90 +1237,89 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: |
1232 | 1237 | Implements the protocol v3 broadcast model where tool calls and permission requests |
1233 | 1238 | are broadcast as session events to all clients. |
1234 | 1239 | """ |
1235 | | - data = event.data |
1236 | | - |
1237 | | - if isinstance(data, ExternalToolRequestedData): |
1238 | | - request_id = data.request_id |
1239 | | - tool_name = data.tool_name |
1240 | | - if not request_id or not tool_name: |
1241 | | - return |
1242 | | - |
1243 | | - handler = self._get_tool_handler(tool_name) |
1244 | | - if not handler: |
1245 | | - return # This client doesn't handle this tool; another client will. |
1246 | | - |
1247 | | - tool_call_id = data.tool_call_id or "" |
1248 | | - arguments = data.arguments |
1249 | | - tp = getattr(data, "traceparent", None) |
1250 | | - ts = getattr(data, "tracestate", None) |
1251 | | - asyncio.ensure_future( |
1252 | | - self._execute_tool_and_respond( |
1253 | | - request_id, tool_name, tool_call_id, arguments, handler, tp, ts |
| 1240 | + match event.data: |
| 1241 | + case ExternalToolRequestedData() as data: |
| 1242 | + request_id = data.request_id |
| 1243 | + tool_name = data.tool_name |
| 1244 | + if not request_id or not tool_name: |
| 1245 | + return |
| 1246 | + |
| 1247 | + handler = self._get_tool_handler(tool_name) |
| 1248 | + if not handler: |
| 1249 | + return # This client doesn't handle this tool; another client will. |
| 1250 | + |
| 1251 | + tool_call_id = data.tool_call_id or "" |
| 1252 | + arguments = data.arguments |
| 1253 | + tp = getattr(data, "traceparent", None) |
| 1254 | + ts = getattr(data, "tracestate", None) |
| 1255 | + asyncio.ensure_future( |
| 1256 | + self._execute_tool_and_respond( |
| 1257 | + request_id, tool_name, tool_call_id, arguments, handler, tp, ts |
| 1258 | + ) |
1254 | 1259 | ) |
1255 | | - ) |
1256 | 1260 |
|
1257 | | - elif isinstance(data, PermissionRequestedData): |
1258 | | - request_id = data.request_id |
1259 | | - permission_request = data.permission_request |
1260 | | - if not request_id or not permission_request: |
1261 | | - return |
| 1261 | + case PermissionRequestedData() as data: |
| 1262 | + request_id = data.request_id |
| 1263 | + permission_request = data.permission_request |
| 1264 | + if not request_id or not permission_request: |
| 1265 | + return |
1262 | 1266 |
|
1263 | | - resolved_by_hook = getattr(data, "resolved_by_hook", None) |
1264 | | - if resolved_by_hook: |
1265 | | - return # Already resolved by a permissionRequest hook; no client action needed. |
| 1267 | + resolved_by_hook = getattr(data, "resolved_by_hook", None) |
| 1268 | + if resolved_by_hook: |
| 1269 | + return # Already resolved by a permissionRequest hook; no client action needed. |
1266 | 1270 |
|
1267 | | - with self._permission_handler_lock: |
1268 | | - perm_handler = self._permission_handler |
1269 | | - if not perm_handler: |
1270 | | - return # This client doesn't handle permissions; another client will. |
| 1271 | + with self._permission_handler_lock: |
| 1272 | + perm_handler = self._permission_handler |
| 1273 | + if not perm_handler: |
| 1274 | + return # This client doesn't handle permissions; another client will. |
1271 | 1275 |
|
1272 | | - asyncio.ensure_future( |
1273 | | - self._execute_permission_and_respond(request_id, permission_request, perm_handler) |
1274 | | - ) |
| 1276 | + asyncio.ensure_future( |
| 1277 | + self._execute_permission_and_respond(request_id, permission_request, perm_handler) |
| 1278 | + ) |
1275 | 1279 |
|
1276 | | - elif isinstance(data, CommandExecuteData): |
1277 | | - request_id = data.request_id |
1278 | | - command_name = data.command_name |
1279 | | - command = data.command |
1280 | | - args = data.args |
1281 | | - if not request_id or not command_name: |
1282 | | - return |
1283 | | - asyncio.ensure_future( |
1284 | | - self._execute_command_and_respond( |
1285 | | - request_id, command_name, command or "", args or "" |
| 1280 | + case CommandExecuteData() as data: |
| 1281 | + request_id = data.request_id |
| 1282 | + command_name = data.command_name |
| 1283 | + command = data.command |
| 1284 | + args = data.args |
| 1285 | + if not request_id or not command_name: |
| 1286 | + return |
| 1287 | + asyncio.ensure_future( |
| 1288 | + self._execute_command_and_respond( |
| 1289 | + request_id, command_name, command or "", args or "" |
| 1290 | + ) |
1286 | 1291 | ) |
1287 | | - ) |
1288 | 1292 |
|
1289 | | - elif isinstance(data, ElicitationRequestedData): |
1290 | | - with self._elicitation_handler_lock: |
1291 | | - handler = self._elicitation_handler |
1292 | | - if not handler: |
1293 | | - return |
1294 | | - request_id = data.request_id |
1295 | | - if not request_id: |
1296 | | - return |
1297 | | - context: ElicitationContext = { |
1298 | | - "session_id": self.session_id, |
1299 | | - "message": data.message or "", |
1300 | | - } |
1301 | | - if data.requested_schema is not None: |
1302 | | - context["requestedSchema"] = data.requested_schema.to_dict() |
1303 | | - if data.mode is not None: |
1304 | | - context["mode"] = data.mode.value |
1305 | | - if data.elicitation_source is not None: |
1306 | | - context["elicitationSource"] = data.elicitation_source |
1307 | | - if data.url is not None: |
1308 | | - context["url"] = data.url |
1309 | | - asyncio.ensure_future(self._handle_elicitation_request(context, request_id)) |
1310 | | - |
1311 | | - elif isinstance(data, CapabilitiesChangedData): |
1312 | | - cap: SessionCapabilities = {} |
1313 | | - if data.ui is not None: |
1314 | | - ui_cap: SessionUiCapabilities = {} |
1315 | | - if data.ui.elicitation is not None: |
1316 | | - ui_cap["elicitation"] = data.ui.elicitation |
1317 | | - cap["ui"] = ui_cap |
1318 | | - self._capabilities = {**self._capabilities, **cap} |
| 1293 | + case ElicitationRequestedData() as data: |
| 1294 | + with self._elicitation_handler_lock: |
| 1295 | + handler = self._elicitation_handler |
| 1296 | + if not handler: |
| 1297 | + return |
| 1298 | + request_id = data.request_id |
| 1299 | + if not request_id: |
| 1300 | + return |
| 1301 | + context: ElicitationContext = { |
| 1302 | + "session_id": self.session_id, |
| 1303 | + "message": data.message or "", |
| 1304 | + } |
| 1305 | + if data.requested_schema is not None: |
| 1306 | + context["requestedSchema"] = data.requested_schema.to_dict() |
| 1307 | + if data.mode is not None: |
| 1308 | + context["mode"] = data.mode.value |
| 1309 | + if data.elicitation_source is not None: |
| 1310 | + context["elicitationSource"] = data.elicitation_source |
| 1311 | + if data.url is not None: |
| 1312 | + context["url"] = data.url |
| 1313 | + asyncio.ensure_future(self._handle_elicitation_request(context, request_id)) |
| 1314 | + |
| 1315 | + case CapabilitiesChangedData() as data: |
| 1316 | + cap: SessionCapabilities = {} |
| 1317 | + if data.ui is not None: |
| 1318 | + ui_cap: SessionUiCapabilities = {} |
| 1319 | + if data.ui.elicitation is not None: |
| 1320 | + ui_cap["elicitation"] = data.ui.elicitation |
| 1321 | + cap["ui"] = ui_cap |
| 1322 | + self._capabilities = {**self._capabilities, **cap} |
1319 | 1323 |
|
1320 | 1324 | async def _execute_tool_and_respond( |
1321 | 1325 | self, |
@@ -1807,8 +1811,9 @@ async def get_messages(self) -> list[SessionEvent]: |
1807 | 1811 | >>> from copilot.generated.session_events import AssistantMessageData |
1808 | 1812 | >>> events = await session.get_messages() |
1809 | 1813 | >>> for event in events: |
1810 | | - ... if isinstance(event.data, AssistantMessageData): |
1811 | | - ... print(f"Assistant: {event.data.content}") |
| 1814 | + ... match event.data: |
| 1815 | + ... case AssistantMessageData() as data: |
| 1816 | + ... print(f"Assistant: {data.content}") |
1812 | 1817 | """ |
1813 | 1818 | response = await self._client.request("session.getMessages", {"sessionId": self.session_id}) |
1814 | 1819 | # Convert dict events to SessionEvent objects |
|
0 commit comments