Skip to content

Commit 53a8387

Browse files
authored
Use callAPI in Enroll call and expose API errors (#33)
1 parent bd115dd commit 53a8387

4 files changed

Lines changed: 91 additions & 108 deletions

File tree

client.go

Lines changed: 55 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"crypto/ed25519"
99
"crypto/rand"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"io"
1314
"net"
@@ -134,119 +135,86 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
134135
}
135136

136137
// Make a request to the API with the enrollment code
137-
jv, err := json.Marshal(message.EnrollRequest{
138+
payload := message.EnrollRequest{
138139
Code: code,
139140
NebulaPubkeyX25519: newKeys.NebulaX25519PublicKeyPEM,
140141
HostPubkeyEd25519: hostEd25519PublicKeyPEM,
141142
NebulaPubkeyP256: newKeys.NebulaP256PublicKeyPEM,
142143
HostPubkeyP256: hostP256PublicKeyPEM,
143144
Timestamp: time.Now(),
144-
})
145-
if err != nil {
146-
return nil, nil, nil, nil, err
147-
}
148-
149-
enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint)
150-
if err != nil {
151-
return nil, nil, nil, nil, err
152-
}
153-
154-
req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, bytes.NewBuffer(jv))
155-
if err != nil {
156-
return nil, nil, nil, nil, err
157-
}
158-
159-
resp, err := c.client.Do(req)
160-
if err != nil {
161-
return nil, nil, nil, nil, err
162-
}
163-
defer resp.Body.Close()
164-
165-
// Log the request ID returned from the server
166-
reqID := resp.Header.Get("X-Request-ID")
167-
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
168-
if resp.StatusCode == http.StatusOK {
169-
l.Info("Enrollment request returned success code")
170-
} else {
171-
l.Error("Enrollment request returned error code")
172145
}
173146

174-
// Decode the response
175-
r := message.APIResponse[message.EnrollResponseData]{}
176-
b, err := io.ReadAll(resp.Body)
147+
reqID, r, err := callAPI[message.EnrollResponseData](ctx, c, "POST", message.EnrollEndpoint, payload)
148+
l := logger.WithFields(logrus.Fields{"reqID": reqID})
177149
if err != nil {
178-
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
179-
}
180-
181-
if err := json.Unmarshal(b, &r); err != nil {
182-
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
183-
}
184-
185-
if len(r.Errors) == 1 {
186-
// Check for *only* an "invalid code" error returned by the API
187-
if err := r.Errors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" {
188-
return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID}
189-
}
150+
var apiErrors message.APIErrors
151+
if errors.As(err, &apiErrors) && len(apiErrors) == 1 {
152+
// Check for *only* an "invalid code" error returned by the API
153+
if err := apiErrors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" {
154+
l.Warn("Enrollment request failed for invalid code")
155+
return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID}
156+
}
190157

191-
// Check for *only* a blocked host error returned by the API
192-
if err := r.Errors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" {
193-
return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID}
158+
// Check for *only* a blocked host error returned by the API
159+
if err := apiErrors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" {
160+
l.Warn("Enrollment request failed for blocked host")
161+
return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID}
162+
}
194163
}
195-
}
196164

197-
// Check for any errors returned by the API
198-
if err := r.Errors.ToError(); err != nil {
199-
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %v", err), ReqID: reqID}
165+
l.WithError(err).Error("Enrollment request failed with unexpected error")
166+
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %w", err), ReqID: reqID}
200167
}
168+
l.Info("Enrollment request succeeded")
201169

202170
meta := &ConfigMeta{
203171
Org: ConfigOrg{
204-
ID: r.Data.Organization.ID,
205-
Name: r.Data.Organization.Name,
172+
ID: r.Organization.ID,
173+
Name: r.Organization.Name,
206174
},
207175
Network: ConfigNetwork{
208-
ID: r.Data.Network.ID,
209-
Name: r.Data.Network.Name,
176+
ID: r.Network.ID,
177+
Name: r.Network.Name,
210178
},
211179
Host: ConfigHost{
212-
ID: r.Data.HostID,
213-
Name: r.Data.Host.Name,
214-
IPAddress: r.Data.Host.IPAddress,
180+
ID: r.HostID,
181+
Name: r.Host.Name,
182+
IPAddress: r.Host.IPAddress,
215183
},
216184
}
217185

218-
if r.Data.EndpointOIDCMeta != nil {
186+
if r.EndpointOIDCMeta != nil {
219187
meta.EndpointOIDC = &ConfigEndpointOIDC{
220-
Email: r.Data.EndpointOIDCMeta.Email,
188+
Email: r.EndpointOIDCMeta.Email,
221189
}
222190
}
223191

224192
// Determine the private keys to save based on the network curve type
225193
var privkeyPEM []byte
226194
var privkey keys.PrivateKey
227-
switch r.Data.Network.Curve {
195+
switch r.Network.Curve {
228196
case message.NetworkCurve25519:
229197
privkeyPEM = newKeys.NebulaX25519PrivateKeyPEM
230198
privkey = newKeys.HostEd25519PrivateKey
231199
case message.NetworkCurveP256:
232200
privkeyPEM = newKeys.NebulaP256PrivateKeyPEM
233201
privkey = newKeys.HostP256PrivateKey
234202
default:
235-
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Data.Network.Curve), ReqID: reqID}
203+
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Network.Curve), ReqID: reqID}
236204
}
237205

238-
trustedKeys, err := keys.TrustedKeysFromPEM(r.Data.TrustedKeys)
206+
trustedKeys, err := keys.TrustedKeysFromPEM(r.TrustedKeys)
239207
if err != nil {
240208
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("failed to load trusted keys from bundle: %s", err), ReqID: reqID}
241209
}
242210

243211
creds := &keys.Credentials{
244-
HostID: r.Data.HostID,
212+
HostID: r.HostID,
245213
PrivateKey: privkey,
246-
Counter: r.Data.Counter,
214+
Counter: r.Counter,
247215
TrustedKeys: trustedKeys,
248216
}
249-
return r.Data.Config, privkeyPEM, creds, meta, nil
217+
return r.Config, privkeyPEM, creds, meta, nil
250218
}
251219

252220
// CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available.
@@ -514,12 +482,12 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
514482
sc.err.Store(ErrInvalidCredentials)
515483
default:
516484
var errors struct {
517-
Errors message.APIErrors
485+
Errors message.APIResponseErrors
518486
}
519487
if err := json.Unmarshal(respBody, &errors); err != nil {
520488
sc.err.Store(fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody))
521489
} else {
522-
sc.err.Store(errors.Errors.ToError())
490+
sc.err.Store(errors.Errors.Err())
523491
}
524492
}
525493
}()
@@ -561,38 +529,39 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
561529
return nil, ErrInvalidCredentials
562530
default:
563531
var errors struct {
564-
Errors message.APIErrors
532+
Errors message.APIResponseErrors
565533
}
566534
if err := json.Unmarshal(respBody, &errors); err != nil {
567535
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
568536
}
569-
return nil, errors.Errors.ToError()
537+
return nil, errors.Errors.Err()
570538
}
571539
}
572540

573-
func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) {
541+
// callAPI returns the request ID, requested response data, and any error if applicable.
542+
func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload any) (string, *T, error) {
574543
dest, err := urlPath(c.dnServer, endpoint)
575544
if err != nil {
576-
return nil, err
545+
return "", nil, err
577546
}
578547

579548
var br io.Reader
580549
if payload != nil {
581550
b, err := json.Marshal(payload)
582551
if err != nil {
583-
return nil, fmt.Errorf("failed to marshal payload: %s", err)
552+
return "", nil, fmt.Errorf("failed to marshal payload: %s", err)
584553
}
585554
br = bytes.NewReader(b)
586555
}
587556

588557
req, err := http.NewRequestWithContext(ctx, method, dest, br)
589558
if err != nil {
590-
return nil, err
559+
return "", nil, err
591560
}
592561

593562
resp, err := c.client.Do(req)
594563
if err != nil {
595-
return nil, err
564+
return "", nil, err
596565
}
597566
defer resp.Body.Close()
598567

@@ -601,24 +570,24 @@ func callAPI[T any](ctx context.Context, c *Client, method string, endpoint stri
601570
r := message.APIResponse[T]{}
602571
b, err := io.ReadAll(resp.Body)
603572
if err != nil {
604-
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
573+
return reqID, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
605574
}
606575

607576
if err := json.Unmarshal(b, &r); err != nil {
608-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
577+
return reqID, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
609578
}
610579

611580
// Check for any errors returned by the API
612-
if err := r.Errors.ToError(); err != nil {
613-
return nil, &APIError{e: err, ReqID: reqID}
581+
if err := r.Errors.Err(); err != nil {
582+
return reqID, nil, &APIError{e: err, ReqID: reqID}
614583
}
615584

616585
// If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
617586
if resp.StatusCode >= 400 {
618-
return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
587+
return reqID, nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
619588
}
620589

621-
return &r.Data, nil
590+
return reqID, &r.Data, nil
622591
}
623592

624593
// StreamController is used for interacting with streaming requests to the API.
@@ -694,12 +663,14 @@ func nonce() []byte {
694663
}
695664

696665
func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
697-
return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
666+
_, d, err := callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
667+
return d, err
698668
}
699669

700670
func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
701-
pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode))
702-
return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
671+
pollURL := fmt.Sprintf("%s?pollToken=%s", message.AuthPollEndpoint, url.QueryEscape(pollCode))
672+
_, d, err := callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
673+
return d, err
703674
}
704675

705676
func urlPath(base, path string) (string, error) {

client_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestEnroll(t *testing.T) {
6565
})
6666
if err != nil {
6767
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
68-
Errors: message.APIErrors{{
68+
Errors: message.APIResponseErrors{{
6969
Code: "ERR_FAILED_TO_MARSHAL_YAML",
7070
Message: "failed to marshal test response config",
7171
}},
@@ -149,7 +149,7 @@ func TestEnroll(t *testing.T) {
149149
errorMsg := "invalid enrollment code"
150150
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
151151
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
152-
Errors: message.APIErrors{{
152+
Errors: message.APIResponseErrors{{
153153
Code: "ERR_INVALID_ENROLLMENT_CODE",
154154
Message: errorMsg,
155155
}},
@@ -194,7 +194,7 @@ func TestDoUpdate(t *testing.T) {
194194
})
195195
if err != nil {
196196
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
197-
Errors: message.APIErrors{{
197+
Errors: message.APIResponseErrors{{
198198
Code: "ERR_FAILED_TO_MARSHAL_YAML",
199199
Message: "failed to marshal test response config",
200200
}},
@@ -463,7 +463,7 @@ func TestDoUpdate_P256(t *testing.T) {
463463
})
464464
if err != nil {
465465
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
466-
Errors: message.APIErrors{{
466+
Errors: message.APIResponseErrors{{
467467
Code: "ERR_FAILED_TO_MARSHAL_YAML",
468468
Message: "failed to marshal test response config",
469469
}},
@@ -557,7 +557,7 @@ func TestDoUpdate_P256(t *testing.T) {
557557
sig, err := nk.HostP256PrivateKey.Sign(rawRes)
558558
if err != nil {
559559
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
560-
Errors: message.APIErrors{{
560+
Errors: message.APIResponseErrors{{
561561
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
562562
Message: "failed to sign message",
563563
}},
@@ -601,7 +601,7 @@ func TestDoUpdate_P256(t *testing.T) {
601601
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
602602
if err != nil {
603603
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
604-
Errors: message.APIErrors{{
604+
Errors: message.APIResponseErrors{{
605605
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
606606
Message: "failed to sign message",
607607
}},
@@ -655,7 +655,7 @@ func TestDoUpdate_P256(t *testing.T) {
655655
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
656656
if err != nil {
657657
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
658-
Errors: message.APIErrors{{
658+
Errors: message.APIResponseErrors{{
659659
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
660660
Message: "failed to sign message",
661661
}},
@@ -703,7 +703,7 @@ func TestCommandResponse(t *testing.T) {
703703
})
704704
if err != nil {
705705
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
706-
Errors: message.APIErrors{{
706+
Errors: message.APIResponseErrors{{
707707
Code: "ERR_FAILED_TO_MARSHAL_YAML",
708708
Message: "failed to marshal test response config",
709709
}},
@@ -774,7 +774,7 @@ func TestCommandResponse(t *testing.T) {
774774
errorMsg := "sample error"
775775
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
776776
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
777-
Errors: message.APIErrors{{
777+
Errors: message.APIResponseErrors{{
778778
Code: "ERR_INVALID_VALUE",
779779
Message: errorMsg,
780780
}},
@@ -808,7 +808,7 @@ func TestStreamCommandResponse(t *testing.T) {
808808
})
809809
if err != nil {
810810
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
811-
Errors: message.APIErrors{{
811+
Errors: message.APIResponseErrors{{
812812
Code: "ERR_FAILED_TO_MARSHAL_YAML",
813813
Message: "failed to marshal test response config",
814814
}},
@@ -885,7 +885,7 @@ func TestStreamCommandResponse(t *testing.T) {
885885
errorMsg := "sample error"
886886
ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
887887
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
888-
Errors: message.APIErrors{{
888+
Errors: message.APIResponseErrors{{
889889
Code: "ERR_INVALID_VALUE",
890890
Message: errorMsg,
891891
}},
@@ -934,7 +934,7 @@ func TestReauthenticate(t *testing.T) {
934934
})
935935
if err != nil {
936936
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
937-
Errors: message.APIErrors{{
937+
Errors: message.APIResponseErrors{{
938938
Code: "ERR_FAILED_TO_MARSHAL_YAML",
939939
Message: "failed to marshal test response config",
940940
}},
@@ -1094,7 +1094,7 @@ func TestGetOidcPollCode(t *testing.T) {
10941094
//unhappy path
10951095
ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte {
10961096
return jsonMarshal(message.APIResponse[message.PreAuthData]{
1097-
Errors: message.APIErrors{{
1097+
Errors: message.APIResponseErrors{{
10981098
Code: "ERR_INTERNAL_SERVER_ERROR",
10991099
Message: "internal server error",
11001100
}},

dnapitest/dnapitest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
7272
s.expectedRequests = s.expectedRequests[1:]
7373
w.WriteHeader(expected.StatusCode())
7474
_, _ = w.Write(expected.Respond(nil))
75-
case message.EndpointAuthPoll:
75+
case message.AuthPollEndpoint:
7676
s.handlerDoOidcPoll(w, r)
7777
default:
7878
s.errors = append(s.errors, fmt.Errorf("invalid request path %s", r.URL.Path))

0 commit comments

Comments
 (0)