Skip to content

Commit d82e8b0

Browse files
Copilotedburns
andauthored
Add CreateSessionReKeyEntryTest to cover session-map re-key cleanup paths
Co-authored-by: edburns <75821+edburns@users.noreply.github.com>
1 parent 4002f80 commit d82e8b0

1 file changed

Lines changed: 262 additions & 0 deletions

File tree

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
*--------------------------------------------------------------------------------------------*/
4+
5+
package com.github.copilot;
6+
7+
import static org.junit.jupiter.api.Assertions.*;
8+
9+
import java.io.OutputStream;
10+
import java.lang.reflect.Field;
11+
import java.net.ServerSocket;
12+
import java.net.Socket;
13+
import java.nio.charset.StandardCharsets;
14+
import java.util.Map;
15+
import java.util.concurrent.CompletableFuture;
16+
import java.util.concurrent.ExecutionException;
17+
18+
import org.junit.jupiter.api.Test;
19+
20+
import com.fasterxml.jackson.databind.JsonNode;
21+
import com.fasterxml.jackson.databind.ObjectMapper;
22+
import com.fasterxml.jackson.databind.node.ObjectNode;
23+
import com.github.copilot.rpc.CopilotClientOptions;
24+
import com.github.copilot.rpc.PermissionHandler;
25+
import com.github.copilot.rpc.SessionConfig;
26+
27+
/**
28+
* Tests for the session-map re-key cleanup paths in CopilotClient when the
29+
* server returns a different session ID than the client-supplied one.
30+
*/
31+
class CreateSessionReKeyEntryTest {
32+
33+
private static final ObjectMapper MAPPER = JsonRpcClient.getObjectMapper();
34+
35+
/**
36+
* A connected socket pair where the server replies to "session.create" with a
37+
* configurable sessionId and then replies to "session.options.update" with
38+
* success or failure.
39+
*/
40+
private static final class ReKeyServer implements AutoCloseable {
41+
42+
final Socket clientSocket;
43+
final Socket serverSocket;
44+
final JsonRpcClient rpcClient;
45+
private volatile boolean running = true;
46+
private final Thread replyThread;
47+
48+
/** The sessionId to return in the session.create response. */
49+
private final String returnedSessionId;
50+
/** If true, the session.options.update call will fail. */
51+
private final boolean failOptionsUpdate;
52+
53+
ReKeyServer(String returnedSessionId, boolean failOptionsUpdate) throws Exception {
54+
this.returnedSessionId = returnedSessionId;
55+
this.failOptionsUpdate = failOptionsUpdate;
56+
57+
try (var ss = new ServerSocket(0)) {
58+
clientSocket = new Socket("localhost", ss.getLocalPort());
59+
serverSocket = ss.accept();
60+
}
61+
serverSocket.setSoTimeout(5000);
62+
rpcClient = JsonRpcClient.fromSocket(clientSocket);
63+
64+
replyThread = new Thread(() -> {
65+
try {
66+
var in = serverSocket.getInputStream();
67+
var out = serverSocket.getOutputStream();
68+
while (running) {
69+
// Read Content-Length header
70+
var header = new StringBuilder();
71+
int b;
72+
while ((b = in.read()) != -1) {
73+
if (b == '\n' && header.toString().endsWith("\r")) {
74+
break;
75+
}
76+
header.append((char) b);
77+
}
78+
if (b == -1)
79+
break;
80+
// Skip blank line
81+
in.read(); // '\r'
82+
in.read(); // '\n'
83+
84+
String hdr = header.toString().trim();
85+
int colon = hdr.indexOf(':');
86+
int len = Integer.parseInt(hdr.substring(colon + 1).trim());
87+
byte[] body = in.readNBytes(len);
88+
JsonNode msg = MAPPER.readTree(body);
89+
90+
String method = msg.get("method").asText();
91+
long id = msg.get("id").asLong();
92+
93+
if ("session.create".equals(method)) {
94+
// Return a response with the (possibly different) session ID
95+
ObjectNode result = MAPPER.createObjectNode();
96+
result.put("sessionId", returnedSessionId);
97+
String response = MAPPER.writeValueAsString(MAPPER.createObjectNode().put("jsonrpc", "2.0")
98+
.put("id", id).set("result", result));
99+
sendRpcMessage(out, response);
100+
} else if ("session.options.update".equals(method)) {
101+
if (failOptionsUpdate) {
102+
// Send an error response
103+
ObjectNode error = MAPPER.createObjectNode();
104+
error.put("code", -32000);
105+
error.put("message", "simulated options update failure");
106+
String response = MAPPER.writeValueAsString(MAPPER.createObjectNode()
107+
.put("jsonrpc", "2.0").put("id", id).set("error", error));
108+
sendRpcMessage(out, response);
109+
} else {
110+
// Send a success response
111+
String response = MAPPER.writeValueAsString(
112+
MAPPER.createObjectNode().put("jsonrpc", "2.0").put("id", id).set("result",
113+
MAPPER.createObjectNode().put("success", true)));
114+
sendRpcMessage(out, response);
115+
}
116+
} else {
117+
// Generic success for anything else
118+
String response = MAPPER.writeValueAsString(MAPPER.createObjectNode().put("jsonrpc", "2.0")
119+
.put("id", id).set("result", MAPPER.createObjectNode().put("success", true)));
120+
sendRpcMessage(out, response);
121+
}
122+
}
123+
} catch (Exception e) {
124+
if (running) {
125+
// Ignore expected exceptions on shutdown
126+
}
127+
}
128+
});
129+
replyThread.setDaemon(true);
130+
replyThread.start();
131+
}
132+
133+
private static void sendRpcMessage(OutputStream out, String json) throws Exception {
134+
byte[] bytes = json.getBytes(StandardCharsets.UTF_8);
135+
String header = "Content-Length: " + bytes.length + "\r\n\r\n";
136+
out.write(header.getBytes(StandardCharsets.UTF_8));
137+
out.write(bytes);
138+
out.flush();
139+
}
140+
141+
@Override
142+
public void close() throws Exception {
143+
running = false;
144+
rpcClient.close();
145+
clientSocket.close();
146+
serverSocket.close();
147+
replyThread.join(3000);
148+
}
149+
}
150+
151+
@SuppressWarnings("unchecked")
152+
private static Map<String, CopilotSession> getSessionsMap(CopilotClient client) throws Exception {
153+
Field f = CopilotClient.class.getDeclaredField("sessions");
154+
f.setAccessible(true);
155+
return (Map<String, CopilotSession>) f.get(client);
156+
}
157+
158+
private static void injectConnection(CopilotClient client, JsonRpcClient rpc) throws Exception {
159+
// Build a Connection record via the private record constructor
160+
Class<?> connClass = null;
161+
for (Class<?> c : CopilotClient.class.getDeclaredClasses()) {
162+
if (c.getSimpleName().equals("Connection")) {
163+
connClass = c;
164+
break;
165+
}
166+
}
167+
assertNotNull(connClass, "Could not find Connection inner class");
168+
169+
var ctor = connClass.getDeclaredConstructors()[0];
170+
ctor.setAccessible(true);
171+
// Connection(JsonRpcClient rpc, Process process, ServerRpc serverRpc)
172+
Object connection = ctor.newInstance(rpc, null, null);
173+
174+
Field f = CopilotClient.class.getDeclaredField("connectionFuture");
175+
f.setAccessible(true);
176+
f.set(client, CompletableFuture.completedFuture(connection));
177+
}
178+
179+
@Test
180+
void createSessionReKeyEntry_successfulReKey_removesOldKeyAndAddsNewKey() throws Exception {
181+
String clientSessionId = "client-supplied-id";
182+
String serverSessionId = "server-returned-id";
183+
184+
try (var server = new ReKeyServer(serverSessionId, false)) {
185+
var client = new CopilotClient(new CopilotClientOptions().setAutoStart(false));
186+
injectConnection(client, server.rpcClient);
187+
188+
var config = new SessionConfig().setSessionId(clientSessionId)
189+
.setOnPermissionRequest(PermissionHandler.APPROVE_ALL);
190+
191+
CopilotSession session = client.createSession(config).get();
192+
193+
Map<String, CopilotSession> sessions = getSessionsMap(client);
194+
195+
// The old client-supplied key should be removed
196+
assertNull(sessions.get(clientSessionId),
197+
"Old client-supplied sessionId should be removed from sessions map after re-key");
198+
// The new server-returned key should be present
199+
assertSame(session, sessions.get(serverSessionId),
200+
"Server-returned sessionId should be the key in sessions map");
201+
// The session object should report the server-returned ID
202+
assertEquals(serverSessionId, session.getSessionId(),
203+
"Session should report the server-returned sessionId");
204+
205+
client.close();
206+
}
207+
}
208+
209+
@Test
210+
void createSessionReKeyEntry_failureAfterReKey_removesBothKeys() throws Exception {
211+
String clientSessionId = "client-supplied-id";
212+
String serverSessionId = "server-returned-id";
213+
214+
try (var server = new ReKeyServer(serverSessionId, true)) {
215+
var client = new CopilotClient(new CopilotClientOptions().setAutoStart(false));
216+
injectConnection(client, server.rpcClient);
217+
218+
// Set skipCustomInstructions so that session.options.update is actually invoked
219+
var config = new SessionConfig().setSessionId(clientSessionId)
220+
.setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setSkipCustomInstructions(true);
221+
222+
// The session.options.update will fail, triggering the exceptionally handler
223+
ExecutionException ex = assertThrows(ExecutionException.class, () -> client.createSession(config).get());
224+
assertNotNull(ex.getCause());
225+
226+
Map<String, CopilotSession> sessions = getSessionsMap(client);
227+
228+
// Both the original and re-keyed entries should be cleaned up
229+
assertNull(sessions.get(clientSessionId),
230+
"Original client-supplied sessionId should be removed on failure");
231+
assertNull(sessions.get(serverSessionId),
232+
"Re-keyed server-returned sessionId should be removed on failure");
233+
assertTrue(sessions.isEmpty(), "Sessions map should be empty after failed create with re-key");
234+
235+
client.close();
236+
}
237+
}
238+
239+
@Test
240+
void createSessionReKeyEntry_noReKey_sameIdKept() throws Exception {
241+
String sessionId = "same-id-for-both";
242+
243+
try (var server = new ReKeyServer(sessionId, false)) {
244+
var client = new CopilotClient(new CopilotClientOptions().setAutoStart(false));
245+
injectConnection(client, server.rpcClient);
246+
247+
var config = new SessionConfig().setSessionId(sessionId)
248+
.setOnPermissionRequest(PermissionHandler.APPROVE_ALL);
249+
250+
CopilotSession session = client.createSession(config).get();
251+
252+
Map<String, CopilotSession> sessions = getSessionsMap(client);
253+
254+
// When IDs match, the session stays under the original key
255+
assertSame(session, sessions.get(sessionId),
256+
"Session should remain under original key when server returns same ID");
257+
assertEquals(1, sessions.size(), "Should have exactly one entry in sessions map");
258+
259+
client.close();
260+
}
261+
}
262+
}

0 commit comments

Comments
 (0)