Skip to content

Commit 1568017

Browse files
authored
Additional LongPollWait supporting code (#11)
1 parent 44f82fa commit 1568017

4 files changed

Lines changed: 132 additions & 25 deletions

File tree

client.go

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ type Client struct {
3131
func NewClient(useragent string, dnServer string) *Client {
3232
return &Client{
3333
client: &http.Client{
34-
Timeout: 1 * time.Minute,
34+
Timeout: 2 * time.Minute,
3535
Transport: &uaTransport{
3636
T: &http.Transport{
37-
Proxy: http.ProxyFromEnvironment,
38-
TLSHandshakeTimeout: 10 * time.Second,
39-
ResponseHeaderTimeout: 10 * time.Second,
37+
Proxy: http.ProxyFromEnvironment,
38+
TLSHandshakeTimeout: 10 * time.Second,
4039
DialContext: (&net.Dialer{
4140
Timeout: 10 * time.Second,
4241
}).DialContext,
@@ -48,9 +47,8 @@ func NewClient(useragent string, dnServer string) *Client {
4847
Timeout: 15 * time.Minute,
4948
Transport: &uaTransport{
5049
T: &http.Transport{
51-
Proxy: http.ProxyFromEnvironment,
52-
TLSHandshakeTimeout: 10 * time.Second,
53-
ResponseHeaderTimeout: 10 * time.Second,
50+
Proxy: http.ProxyFromEnvironment,
51+
TLSHandshakeTimeout: 10 * time.Second,
5452
DialContext: (&net.Dialer{
5553
Timeout: 10 * time.Second,
5654
}).DialContext,
@@ -185,24 +183,24 @@ func (c *Client) CheckForUpdate(ctx context.Context, creds Credentials) (bool, e
185183

186184
// LongPollWait sends a signed message to a DNClient API endpoint that will block, returning only
187185
// if there is an action the client should take before the timeout (config updates, debug commands)
188-
func (c *Client) LongPollWait(ctx context.Context, creds Credentials, supportedActions []string) (string, error) {
186+
func (c *Client) LongPollWait(ctx context.Context, creds Credentials, supportedActions []string) (*message.LongPollWaitResponse, error) {
189187
value, err := json.Marshal(message.LongPollWaitRequest{
190188
SupportedActions: supportedActions,
191189
})
192190
if err != nil {
193-
return "", fmt.Errorf("failed to marshal DNClient message: %s", err)
191+
return nil, fmt.Errorf("failed to marshal DNClient message: %s", err)
194192
}
195193

196194
respBody, err := c.postDNClient(ctx, message.LongPollWait, value, creds.HostID, creds.Counter, creds.PrivateKey)
197195
if err != nil {
198-
return "", fmt.Errorf("failed to post message to dnclient api: %w", err)
196+
return nil, fmt.Errorf("failed to post message to dnclient api: %w", err)
199197
}
200198
result := message.LongPollWaitResponseWrapper{}
201199
err = json.Unmarshal(respBody, &result)
202200
if err != nil {
203-
return "", fmt.Errorf("failed to interpret API response: %s", err)
201+
return nil, fmt.Errorf("failed to interpret API response: %s", err)
204202
}
205-
return result.Data.Action, nil
203+
return &result.Data, nil
206204
}
207205

208206
// DoUpdate sends a signed message to the DNClient API to fetch the new configuration update. During this call a new
@@ -282,6 +280,19 @@ func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byt
282280
return result.Config, dhPrivkeyPEM, newCreds, nil
283281
}
284282

283+
func (c *Client) CommandResponse(ctx context.Context, creds Credentials, responseToken string, response any) error {
284+
value, err := json.Marshal(message.CommandResponseRequest{
285+
ResponseToken: responseToken,
286+
Response: response,
287+
})
288+
if err != nil {
289+
return fmt.Errorf("failed to marshal DNClient message: %s", err)
290+
}
291+
292+
_, err = c.postDNClient(ctx, message.CommandResponse, value, creds.HostID, creds.Counter, creds.PrivateKey)
293+
return err
294+
}
295+
285296
func (c *Client) StreamCommandResponse(ctx context.Context, creds Credentials, responseToken string) (*StreamController, error) {
286297
value, err := json.Marshal(message.CommandResponseRequest{
287298
ResponseToken: responseToken,
@@ -313,13 +324,11 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
313324
sc := &StreamController{w: pw, done: done}
314325

315326
go func() {
316-
defer func() {
317-
close(done)
318-
}()
327+
defer close(done)
319328

320329
resp, err := c.streamingClient.Do(req)
321330
if err != nil {
322-
sc.err.Store(fmt.Errorf("failed to call dnclient endpoint: %s", err))
331+
sc.err.Store(fmt.Errorf("failed to call dnclient endpoint: %w", err))
323332
return
324333
}
325334
defer resp.Body.Close()
@@ -362,7 +371,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
362371
}
363372
resp, err := c.client.Do(req)
364373
if err != nil {
365-
return nil, fmt.Errorf("failed to call dnclient endpoint: %s", err)
374+
return nil, fmt.Errorf("failed to call dnclient endpoint: %w", err)
366375
}
367376
defer resp.Body.Close()
368377

@@ -409,14 +418,18 @@ func (sc *StreamController) Err() error {
409418
return err.(error)
410419
}
411420

412-
// Write implements the io.Writer interface for StreamController. It writes to the request body. It never returns an
413-
// error. If the StreamController has already encountered an error, Write will return immediately without writing.
414-
// To check for errors, call Err.
415-
func (sc *StreamController) Write(p []byte) (n int, err error) {
421+
// Write implements the io.Writer interface for StreamController. It writes to the request body. If the StreamController
422+
// has already encountered an error, it will be returned and nothing will be written.
423+
func (sc *StreamController) Write(p []byte) (int, error) {
416424
if sc.Err() != nil {
417425
return 0, sc.Err()
418426
}
419-
return sc.w.Write(p)
427+
428+
n, err := sc.w.Write(p)
429+
if err != nil {
430+
sc.err.Store(err)
431+
}
432+
return n, err
420433
}
421434

422435
// Close closes the StreamController, signaling that the request body is complete and the response can be read.

client_test.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,97 @@ func TestDoUpdate(t *testing.T) {
317317

318318
}
319319

320+
func TestCommandResponse(t *testing.T) {
321+
t.Parallel()
322+
323+
useragent := "testClient"
324+
ts := dnapitest.NewServer(useragent)
325+
t.Cleanup(func() { ts.Close() })
326+
327+
ca, _ := dnapitest.NebulaCACert()
328+
caPEM, err := ca.MarshalToPEM()
329+
require.NoError(t, err)
330+
331+
c := NewClient(useragent, ts.URL)
332+
333+
code := "foobar"
334+
ts.ExpectEnrollment(code, func(req message.EnrollRequest) []byte {
335+
cfg, err := yaml.Marshal(m{
336+
// we need to send this or we'll get an error from the api client
337+
"pki": m{"ca": string(caPEM)},
338+
// here we reflect values back to the client for test purposes
339+
"test": m{"code": req.Code, "dhPubkey": req.DHPubkey},
340+
})
341+
if err != nil {
342+
return jsonMarshal(message.EnrollResponse{
343+
Errors: message.APIErrors{{
344+
Code: "ERR_FAILED_TO_MARSHAL_YAML",
345+
Message: "failed to marshal test response config",
346+
}},
347+
})
348+
}
349+
350+
return jsonMarshal(message.EnrollResponse{
351+
Data: message.EnrollResponseData{
352+
HostID: "foobar",
353+
Counter: 1,
354+
Config: cfg,
355+
TrustedKeys: cert.MarshalEd25519PublicKey(ca.Details.PublicKey),
356+
Organization: message.EnrollResponseDataOrg{
357+
ID: "foobaz",
358+
Name: "foobar's foo org",
359+
},
360+
},
361+
})
362+
})
363+
364+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
365+
defer cancel()
366+
config, pkey, creds, _, err := c.Enroll(ctx, testutil.NewTestLogger(), "foobar")
367+
require.NoError(t, err)
368+
369+
// make sure all credential values were set
370+
assert.NotEmpty(t, creds.HostID)
371+
assert.NotEmpty(t, creds.PrivateKey)
372+
assert.NotEmpty(t, creds.TrustedKeys)
373+
assert.NotEmpty(t, creds.Counter)
374+
375+
// make sure we got a config back
376+
assert.NotEmpty(t, config)
377+
assert.NotEmpty(t, pkey)
378+
379+
// This time sign the response with the correct CA key.
380+
responseToken := "abc123"
381+
res := map[string]any{"msg": "Hello, world!"}
382+
ts.ExpectRequest(message.CommandResponse, http.StatusOK, func(r message.RequestWrapper) []byte {
383+
var val map[string]any
384+
err := json.Unmarshal(r.Value, &val)
385+
require.NoError(t, err)
386+
require.Contains(t, val, "responseToken")
387+
require.Equal(t, responseToken, val["responseToken"])
388+
require.Contains(t, val, "response")
389+
require.Equal(t, res, val["response"])
390+
return jsonMarshal(struct{}{})
391+
})
392+
393+
err = c.CommandResponse(context.Background(), *creds, responseToken, res)
394+
require.NoError(t, err)
395+
396+
// Test error handling
397+
errorMsg := "sample error"
398+
ts.ExpectRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
399+
return jsonMarshal(message.EnrollResponse{
400+
Errors: message.APIErrors{{
401+
Code: "ERR_INVALID_VALUE",
402+
Message: errorMsg,
403+
}},
404+
})
405+
})
406+
407+
err = c.CommandResponse(context.Background(), *creds, "responseToken", map[string]any{"msg": "Hello, world!"})
408+
require.Error(t, err)
409+
}
410+
320411
func TestStreamCommandResponse(t *testing.T) {
321412
t.Parallel()
322413

@@ -451,7 +542,7 @@ func TestTimeout(t *testing.T) {
451542
useragent := "TestTimeout agent"
452543
c := NewClient(useragent, ts.URL)
453544
// The default timeout is 1 minutes. Assert the default value.
454-
assert.Equal(t, c.client.Timeout, 1*time.Minute)
545+
assert.Equal(t, c.client.Timeout, 2*time.Minute)
455546
// The default streaming timeout is 15 minutes. Assert the default value.
456547
assert.Equal(t, c.streamingClient.Timeout, 15*time.Minute)
457548
// Overwrite the default value with a 10 millisecond timeout for test brevity.

dnapitest/dnapitest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func (s *Server) handlerDNClient(w http.ResponseWriter, r *http.Request) {
156156
// Require the expected request type, otherwise we have derailed.
157157
if msg.Type != res.expectedType {
158158
s.errors = append(s.errors, fmt.Errorf("%s is not expected message type %s", msg.Type, res.expectedType))
159-
http.Error(w, "unexpected message type", http.StatusInternalServerError)
159+
http.Error(w, fmt.Sprintf("unexpected message type %s, wanted %s", msg.Type, res.expectedType), http.StatusInternalServerError)
160160
return
161161
}
162162

message/message.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package message
22

33
import (
4+
"encoding/json"
45
"errors"
56
"strings"
67
"time"
@@ -87,7 +88,8 @@ type LongPollWaitRequest struct {
8788

8889
// LongPollWaitResponse is the response message associated with a LongPollWait call.
8990
type LongPollWaitResponse struct {
90-
Action string `json:"action"` // e.g. NoOp, StreamLogs, DoUpdate
91+
Action json.RawMessage `json:"action"` // e.g. NoOp, StreamLogs, DoUpdate
92+
ResponseToken string `json:"responseToken"`
9193
}
9294

9395
// CommandResponseResponseWrapper contains a response to CommandResponse inside "data."
@@ -98,6 +100,7 @@ type CommandResponseResponseWrapper struct {
98100
// CommandResponseRequest is the request message associated with a CommandResponse call.
99101
type CommandResponseRequest struct {
100102
ResponseToken string `json:"responseToken"`
103+
Response any `json:"response"`
101104
}
102105

103106
// DNClientCommandResponseResponse is the response message associated with a CommandResponse call.

0 commit comments

Comments
 (0)