Skip to content

Commit d7fa4c8

Browse files
authored
Add support for CommandResponse streams (#9)
1 parent 5dd626b commit d7fa4c8

6 files changed

Lines changed: 413 additions & 69 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
- name: Run unit tests
3737
run: make testvv
3838
env:
39+
TEST_FLAGS: -race
3940

4041
- name: Report failures to Slack
4142
if: ${{ always() && github.ref == 'refs/heads/main' }}
@@ -47,4 +48,3 @@ jobs:
4748
notify_when: 'failure'
4849
env:
4950
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_REPORTING_WEBHOOK }}
50-
TEST_FLAGS: -race

client.go

Lines changed: 147 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ import (
55
"bytes"
66
"context"
77
"crypto/ed25519"
8-
"encoding/base64"
98
"encoding/json"
109
"fmt"
1110
"io"
11+
"net"
1212
"net/http"
13+
"sync/atomic"
1314
"time"
1415

1516
"github.com/DefinedNet/dnapi/message"
@@ -19,19 +20,40 @@ import (
1920

2021
// Client communicates with the API server.
2122
type Client struct {
22-
http *http.Client
2323
dnServer string
24+
25+
client *http.Client
26+
streamingClient *http.Client
2427
}
2528

2629
// NewClient returns new Client configured with the given useragent.
2730
// It also supports reading Proxy information from the environment.
2831
func NewClient(useragent string, dnServer string) *Client {
2932
return &Client{
30-
http: &http.Client{
33+
client: &http.Client{
3134
Timeout: 1 * time.Minute,
3235
Transport: &uaTransport{
3336
T: &http.Transport{
34-
Proxy: http.ProxyFromEnvironment,
37+
Proxy: http.ProxyFromEnvironment,
38+
TLSHandshakeTimeout: 10 * time.Second,
39+
ResponseHeaderTimeout: 10 * time.Second,
40+
DialContext: (&net.Dialer{
41+
Timeout: 10 * time.Second,
42+
}).DialContext,
43+
},
44+
useragent: useragent,
45+
},
46+
},
47+
streamingClient: &http.Client{
48+
Timeout: 15 * time.Minute,
49+
Transport: &uaTransport{
50+
T: &http.Transport{
51+
Proxy: http.ProxyFromEnvironment,
52+
TLSHandshakeTimeout: 10 * time.Second,
53+
ResponseHeaderTimeout: 10 * time.Second,
54+
DialContext: (&net.Dialer{
55+
Timeout: 10 * time.Second,
56+
}).DialContext,
3557
},
3658
useragent: useragent,
3759
},
@@ -40,9 +62,9 @@ func NewClient(useragent string, dnServer string) *Client {
4062
}
4163
}
4264

43-
// APIError contains an error, and a hidden wrapped error that contains the RequestID
44-
// contained in the X-Request-ID header of an API response. Defaults to empty string
45-
// if the header is not in the response.
65+
// APIError wraps an error and contains the RequestID from the X-Request-ID
66+
// header of an API response. ReqID defaults to empty string if the header is
67+
// not in the response.
4668
type APIError struct {
4769
e error
4870
ReqID string
@@ -67,12 +89,6 @@ type EnrollMeta struct {
6789
OrganizationName string
6890
}
6991

70-
func (c *Client) EnrollWithTimeout(ctx context.Context, t time.Duration, logger logrus.FieldLogger, code string) ([]byte, []byte, *Credentials, *EnrollMeta, error) {
71-
toCtx, cancel := context.WithTimeout(ctx, t)
72-
defer cancel()
73-
return c.Enroll(toCtx, logger, code)
74-
}
75-
7692
// Enroll issues an enrollment request against the REST API using the given enrollment code, passing along a locally
7793
// generated DH X25519 public key to be signed by the CA, and an Ed 25519 public key for future API call authentication.
7894
// On success it returns the Nebula config generated by the server, a Nebula private key PEM to be inserted into the
@@ -103,7 +119,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
103119
return nil, nil, nil, nil, err
104120
}
105121

106-
resp, err := c.http.Do(req)
122+
resp, err := c.client.Do(req)
107123
if err != nil {
108124
return nil, nil, nil, nil, err
109125
}
@@ -148,12 +164,6 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
148164
return r.Data.Config, dhPrivkeyPEM, creds, meta, nil
149165
}
150166

151-
func (c *Client) CheckForUpdateWithTimeout(ctx context.Context, t time.Duration, creds Credentials) (bool, error) {
152-
toCtx, cancel := context.WithTimeout(ctx, t)
153-
defer cancel()
154-
return c.CheckForUpdate(toCtx, creds)
155-
}
156-
157167
// CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available.
158168
func (c *Client) CheckForUpdate(ctx context.Context, creds Credentials) (bool, error) {
159169
respBody, err := c.postDNClient(ctx, message.CheckForUpdate, nil, creds.HostID, creds.Counter, creds.PrivateKey)
@@ -190,12 +200,6 @@ func (c *Client) LongPollWait(ctx context.Context, creds Credentials, supportedA
190200
return result.Data.Action, nil
191201
}
192202

193-
func (c *Client) DoUpdateWithTimeout(ctx context.Context, t time.Duration, creds Credentials) ([]byte, []byte, *Credentials, error) {
194-
toCtx, cancel := context.WithTimeout(ctx, t)
195-
defer cancel()
196-
return c.DoUpdate(toCtx, creds)
197-
}
198-
199203
// DoUpdate sends a signed message to the DNClient API to fetch the new configuration update. During this call a new
200204
// DH X25519 keypair is generated for the new Nebula certificate as well as a new Ed25519 keypair for DNClient API
201205
// communication. On success it returns the new config, a Nebula private key PEM to be inserted into the config (see
@@ -273,35 +277,85 @@ func (c *Client) DoUpdate(ctx context.Context, creds Credentials) ([]byte, []byt
273277
return result.Config, dhPrivkeyPEM, newCreds, nil
274278
}
275279

276-
// postDNClient wraps and signs the given dnclientRequestWrapper message, and makes the API call.
277-
// On success, it returns the response message body. On error, the error is returned.
278-
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey ed25519.PrivateKey) ([]byte, error) {
279-
encMsg, err := json.Marshal(message.RequestWrapper{
280-
Type: reqType,
281-
Value: value,
282-
Timestamp: time.Now(),
280+
func (c *Client) StreamCommandResponse(ctx context.Context, creds Credentials, responseToken string) (*StreamController, error) {
281+
value, err := json.Marshal(message.CommandResponseRequest{
282+
ResponseToken: responseToken,
283283
})
284+
if err != nil {
285+
return nil, fmt.Errorf("failed to marshal DNClient message: %s", err)
286+
}
287+
288+
return c.streamingPostDNClient(ctx, message.CommandResponse, value, creds.HostID, creds.Counter, creds.PrivateKey)
289+
}
290+
291+
// streamingPostDNClient wraps and signs the given dnclientRequestWrapper message, and makes a streaming API call.
292+
// On success, it returns a StreamController to interact with the request. On error, the error is returned.
293+
func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey ed25519.PrivateKey) (*StreamController, error) {
294+
pr, pw := io.Pipe()
295+
296+
postBody, err := SignRequestV1(reqType, value, hostID, counter, privkey)
284297
if err != nil {
285298
return nil, err
286299
}
287-
signedMsg := base64.StdEncoding.EncodeToString(encMsg)
288-
sig := ed25519.Sign(privkey, []byte(signedMsg))
289-
body := message.RequestV1{
290-
Version: 1,
291-
HostID: hostID,
292-
Counter: counter,
293-
Message: signedMsg,
294-
Signature: sig,
300+
pbb := bytes.NewBuffer(postBody)
301+
302+
req, err := http.NewRequestWithContext(ctx, "POST", c.dnServer+message.EndpointV1, io.MultiReader(pbb, pr))
303+
if err != nil {
304+
return nil, err
295305
}
296-
postBody, err := json.Marshal(body)
306+
307+
done := make(chan struct{})
308+
sc := &StreamController{w: pw, done: done}
309+
310+
go func() {
311+
defer func() {
312+
close(done)
313+
}()
314+
315+
resp, err := c.streamingClient.Do(req)
316+
if err != nil {
317+
sc.err.Store(fmt.Errorf("failed to call dnclient endpoint: %s", err))
318+
return
319+
}
320+
defer resp.Body.Close()
321+
322+
respBody, err := io.ReadAll(resp.Body)
323+
if err != nil {
324+
sc.err.Store(fmt.Errorf("failed to read the response body: %s", err))
325+
}
326+
327+
switch resp.StatusCode {
328+
case http.StatusOK:
329+
sc.respBytes = respBody
330+
case http.StatusUnauthorized:
331+
sc.err.Store(InvalidCredentialsError{})
332+
default:
333+
var errors struct {
334+
Errors message.APIErrors
335+
}
336+
if err := json.Unmarshal(respBody, &errors); err != nil {
337+
sc.err.Store(fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody))
338+
}
339+
sc.err.Store(errors.Errors.ToError())
340+
}
341+
}()
342+
343+
return sc, nil
344+
}
345+
346+
// postDNClient wraps and signs the given dnclientRequestWrapper message, and makes the API call.
347+
// On success, it returns the response message body. On error, the error is returned.
348+
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey ed25519.PrivateKey) ([]byte, error) {
349+
postBody, err := SignRequestV1(reqType, value, hostID, counter, privkey)
297350
if err != nil {
298351
return nil, err
299352
}
353+
300354
req, err := http.NewRequestWithContext(ctx, "POST", c.dnServer+message.EndpointV1, bytes.NewReader(postBody))
301355
if err != nil {
302356
return nil, err
303357
}
304-
resp, err := c.http.Do(req)
358+
resp, err := c.client.Do(req)
305359
if err != nil {
306360
return nil, fmt.Errorf("failed to call dnclient endpoint: %s", err)
307361
}
@@ -328,6 +382,56 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
328382
}
329383
}
330384

385+
// StreamController is used for interacting with streaming requests to the API.
386+
//
387+
// When a streaming request is started in a background goroutine, a StreamController is returned to the caller to allow
388+
// writing to the request body. The request will be sent when the caller closes the StreamController. The response body
389+
// can be read by calling ResponseBytes, which will block until the response is received.
390+
type StreamController struct {
391+
w *io.PipeWriter
392+
respBytes []byte
393+
err atomic.Value
394+
done chan struct{}
395+
}
396+
397+
// Err returns any error that occurred during the streaming request. If the request was successful, Err will return nil.
398+
// Err should be called after Close to ensure the request has completed.
399+
func (sc *StreamController) Err() error {
400+
err := sc.err.Load()
401+
if err == nil {
402+
return nil
403+
}
404+
return err.(error)
405+
}
406+
407+
// Write implements the io.Writer interface for StreamController. It writes to the request body. It never returns an
408+
// error. If the StreamController has already encountered an error, Write will return immediately without writing.
409+
// To check for errors, call Err.
410+
func (sc *StreamController) Write(p []byte) (n int, err error) {
411+
if sc.Err() != nil {
412+
return 0, sc.Err()
413+
}
414+
return sc.w.Write(p)
415+
}
416+
417+
// Close closes the StreamController, signaling that the request body is complete and the response can be read.
418+
func (sc *StreamController) Close() error {
419+
err := sc.w.Close()
420+
<-sc.done
421+
return err
422+
}
423+
424+
// ResponseBytes blocks until the response is received, then returns the response body. If an error occurred during the
425+
// request, ResponseBytes will return the error.
426+
func (sc *StreamController) ResponseBytes() ([]byte, error) {
427+
<-sc.done
428+
if sc.Err() != nil {
429+
return nil, sc.Err()
430+
}
431+
return sc.respBytes, nil
432+
}
433+
434+
// uaTransport wraps an http.RoundTripper and sets the User-Agent header on all requests.
331435
type uaTransport struct {
332436
useragent string
333437
T http.RoundTripper

0 commit comments

Comments
 (0)