88 "crypto/ed25519"
99 "crypto/rand"
1010 "encoding/json"
11+ "errors"
1112 "fmt"
1213 "io"
1314 "net"
@@ -134,119 +135,86 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
134135 }
135136
136137 // Make a request to the API with the enrollment code
137- jv , err := json . Marshal ( message.EnrollRequest {
138+ payload := message.EnrollRequest {
138139 Code : code ,
139140 NebulaPubkeyX25519 : newKeys .NebulaX25519PublicKeyPEM ,
140141 HostPubkeyEd25519 : hostEd25519PublicKeyPEM ,
141142 NebulaPubkeyP256 : newKeys .NebulaP256PublicKeyPEM ,
142143 HostPubkeyP256 : hostP256PublicKeyPEM ,
143144 Timestamp : time .Now (),
144- })
145- if err != nil {
146- return nil , nil , nil , nil , err
147- }
148-
149- enrollURL , err := urlPath (c .dnServer , message .EnrollEndpoint )
150- if err != nil {
151- return nil , nil , nil , nil , err
152- }
153-
154- req , err := http .NewRequestWithContext (ctx , "POST" , enrollURL , bytes .NewBuffer (jv ))
155- if err != nil {
156- return nil , nil , nil , nil , err
157- }
158-
159- resp , err := c .client .Do (req )
160- if err != nil {
161- return nil , nil , nil , nil , err
162- }
163- defer resp .Body .Close ()
164-
165- // Log the request ID returned from the server
166- reqID := resp .Header .Get ("X-Request-ID" )
167- l := logger .WithFields (logrus.Fields {"statusCode" : resp .StatusCode , "reqID" : reqID })
168- if resp .StatusCode == http .StatusOK {
169- l .Info ("Enrollment request returned success code" )
170- } else {
171- l .Error ("Enrollment request returned error code" )
172145 }
173146
174- // Decode the response
175- r := message.APIResponse [message.EnrollResponseData ]{}
176- b , err := io .ReadAll (resp .Body )
147+ reqID , r , err := callAPI [message.EnrollResponseData ](ctx , c , "POST" , message .EnrollEndpoint , payload )
148+ l := logger .WithFields (logrus.Fields {"reqID" : reqID })
177149 if err != nil {
178- return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("error reading response body: %s" , err ), ReqID : reqID }
179- }
180-
181- if err := json .Unmarshal (b , & r ); err != nil {
182- return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , b ), ReqID : reqID }
183- }
184-
185- if len (r .Errors ) == 1 {
186- // Check for *only* an "invalid code" error returned by the API
187- if err := r .Errors [0 ]; err .Path == "code" && err .Code == "ERR_INVALID_VALUE" {
188- return nil , nil , nil , nil , & APIError {e : ErrInvalidCode , ReqID : reqID }
189- }
150+ var apiErrors message.APIErrors
151+ if errors .As (err , & apiErrors ) && len (apiErrors ) == 1 {
152+ // Check for *only* an "invalid code" error returned by the API
153+ if err := apiErrors [0 ]; err .Path == "code" && err .Code == "ERR_INVALID_VALUE" {
154+ l .Warn ("Enrollment request failed for invalid code" )
155+ return nil , nil , nil , nil , & APIError {e : ErrInvalidCode , ReqID : reqID }
156+ }
190157
191- // Check for *only* a blocked host error returned by the API
192- if err := r .Errors [0 ]; err .Path == "" && err .Code == "ERR_HOST_BLOCKED" {
193- return nil , nil , nil , nil , & APIError {e : ErrHostBlocked , ReqID : reqID }
158+ // Check for *only* a blocked host error returned by the API
159+ if err := apiErrors [0 ]; err .Path == "" && err .Code == "ERR_HOST_BLOCKED" {
160+ l .Warn ("Enrollment request failed for blocked host" )
161+ return nil , nil , nil , nil , & APIError {e : ErrHostBlocked , ReqID : reqID }
162+ }
194163 }
195- }
196164
197- // Check for any errors returned by the API
198- if err := r .Errors .ToError (); err != nil {
199- return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("unexpected error during enrollment: %v" , err ), ReqID : reqID }
165+ l .WithError (err ).Error ("Enrollment request failed with unexpected error" )
166+ return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("unexpected error during enrollment: %w" , err ), ReqID : reqID }
200167 }
168+ l .Info ("Enrollment request succeeded" )
201169
202170 meta := & ConfigMeta {
203171 Org : ConfigOrg {
204- ID : r .Data . Organization .ID ,
205- Name : r .Data . Organization .Name ,
172+ ID : r .Organization .ID ,
173+ Name : r .Organization .Name ,
206174 },
207175 Network : ConfigNetwork {
208- ID : r .Data . Network .ID ,
209- Name : r .Data . Network .Name ,
176+ ID : r .Network .ID ,
177+ Name : r .Network .Name ,
210178 },
211179 Host : ConfigHost {
212- ID : r .Data . HostID ,
213- Name : r .Data . Host .Name ,
214- IPAddress : r .Data . Host .IPAddress ,
180+ ID : r .HostID ,
181+ Name : r .Host .Name ,
182+ IPAddress : r .Host .IPAddress ,
215183 },
216184 }
217185
218- if r .Data . EndpointOIDCMeta != nil {
186+ if r .EndpointOIDCMeta != nil {
219187 meta .EndpointOIDC = & ConfigEndpointOIDC {
220- Email : r .Data . EndpointOIDCMeta .Email ,
188+ Email : r .EndpointOIDCMeta .Email ,
221189 }
222190 }
223191
224192 // Determine the private keys to save based on the network curve type
225193 var privkeyPEM []byte
226194 var privkey keys.PrivateKey
227- switch r .Data . Network .Curve {
195+ switch r .Network .Curve {
228196 case message .NetworkCurve25519 :
229197 privkeyPEM = newKeys .NebulaX25519PrivateKeyPEM
230198 privkey = newKeys .HostEd25519PrivateKey
231199 case message .NetworkCurveP256 :
232200 privkeyPEM = newKeys .NebulaP256PrivateKeyPEM
233201 privkey = newKeys .HostP256PrivateKey
234202 default :
235- return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("unsupported curve type: %s" , r .Data . Network .Curve ), ReqID : reqID }
203+ return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("unsupported curve type: %s" , r .Network .Curve ), ReqID : reqID }
236204 }
237205
238- trustedKeys , err := keys .TrustedKeysFromPEM (r .Data . TrustedKeys )
206+ trustedKeys , err := keys .TrustedKeysFromPEM (r .TrustedKeys )
239207 if err != nil {
240208 return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("failed to load trusted keys from bundle: %s" , err ), ReqID : reqID }
241209 }
242210
243211 creds := & keys.Credentials {
244- HostID : r .Data . HostID ,
212+ HostID : r .HostID ,
245213 PrivateKey : privkey ,
246- Counter : r .Data . Counter ,
214+ Counter : r .Counter ,
247215 TrustedKeys : trustedKeys ,
248216 }
249- return r .Data . Config , privkeyPEM , creds , meta , nil
217+ return r .Config , privkeyPEM , creds , meta , nil
250218}
251219
252220// CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available.
@@ -514,12 +482,12 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
514482 sc .err .Store (ErrInvalidCredentials )
515483 default :
516484 var errors struct {
517- Errors message.APIErrors
485+ Errors message.APIResponseErrors
518486 }
519487 if err := json .Unmarshal (respBody , & errors ); err != nil {
520488 sc .err .Store (fmt .Errorf ("dnclient endpoint returned bad status code '%d', body: %s" , resp .StatusCode , respBody ))
521489 } else {
522- sc .err .Store (errors .Errors .ToError ())
490+ sc .err .Store (errors .Errors .Err ())
523491 }
524492 }
525493 }()
@@ -561,38 +529,39 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
561529 return nil , ErrInvalidCredentials
562530 default :
563531 var errors struct {
564- Errors message.APIErrors
532+ Errors message.APIResponseErrors
565533 }
566534 if err := json .Unmarshal (respBody , & errors ); err != nil {
567535 return nil , fmt .Errorf ("dnclient endpoint returned bad status code '%d', body: %s" , resp .StatusCode , respBody )
568536 }
569- return nil , errors .Errors .ToError ()
537+ return nil , errors .Errors .Err ()
570538 }
571539}
572540
573- func callAPI [T any ](ctx context.Context , c * Client , method string , endpoint string , payload map [string ]any ) (* T , error ) {
541+ // callAPI returns the request ID, requested response data, and any error if applicable.
542+ func callAPI [T any ](ctx context.Context , c * Client , method string , endpoint string , payload any ) (string , * T , error ) {
574543 dest , err := urlPath (c .dnServer , endpoint )
575544 if err != nil {
576- return nil , err
545+ return "" , nil , err
577546 }
578547
579548 var br io.Reader
580549 if payload != nil {
581550 b , err := json .Marshal (payload )
582551 if err != nil {
583- return nil , fmt .Errorf ("failed to marshal payload: %s" , err )
552+ return "" , nil , fmt .Errorf ("failed to marshal payload: %s" , err )
584553 }
585554 br = bytes .NewReader (b )
586555 }
587556
588557 req , err := http .NewRequestWithContext (ctx , method , dest , br )
589558 if err != nil {
590- return nil , err
559+ return "" , nil , err
591560 }
592561
593562 resp , err := c .client .Do (req )
594563 if err != nil {
595- return nil , err
564+ return "" , nil , err
596565 }
597566 defer resp .Body .Close ()
598567
@@ -601,24 +570,24 @@ func callAPI[T any](ctx context.Context, c *Client, method string, endpoint stri
601570 r := message.APIResponse [T ]{}
602571 b , err := io .ReadAll (resp .Body )
603572 if err != nil {
604- return nil , & APIError {e : fmt .Errorf ("error reading response body: %s" , err ), ReqID : reqID }
573+ return reqID , nil , & APIError {e : fmt .Errorf ("error reading response body: %s" , err ), ReqID : reqID }
605574 }
606575
607576 if err := json .Unmarshal (b , & r ); err != nil {
608- return nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , b ), ReqID : reqID }
577+ return reqID , nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , b ), ReqID : reqID }
609578 }
610579
611580 // Check for any errors returned by the API
612- if err := r .Errors .ToError (); err != nil {
613- return nil , & APIError {e : err , ReqID : reqID }
581+ if err := r .Errors .Err (); err != nil {
582+ return reqID , nil , & APIError {e : err , ReqID : reqID }
614583 }
615584
616585 // If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
617586 if resp .StatusCode >= 400 {
618- return nil , & APIError {e : fmt .Errorf ("received HTTP %d from API without error details\n body: %s" , resp .StatusCode , b ), ReqID : reqID }
587+ return reqID , nil , & APIError {e : fmt .Errorf ("received HTTP %d from API without error details\n body: %s" , resp .StatusCode , b ), ReqID : reqID }
619588 }
620589
621- return & r .Data , nil
590+ return reqID , & r .Data , nil
622591}
623592
624593// StreamController is used for interacting with streaming requests to the API.
@@ -694,12 +663,14 @@ func nonce() []byte {
694663}
695664
696665func (c * Client ) EndpointPreAuth (ctx context.Context ) (* message.PreAuthData , error ) {
697- return callAPI [message.PreAuthData ](ctx , c , "POST" , message .PreAuthEndpoint , nil )
666+ _ , d , err := callAPI [message.PreAuthData ](ctx , c , "POST" , message .PreAuthEndpoint , nil )
667+ return d , err
698668}
699669
700670func (c * Client ) EndpointAuthPoll (ctx context.Context , pollCode string ) (* message.EndpointAuthPollData , error ) {
701- pollURL := fmt .Sprintf ("%s?pollToken=%s" , message .EndpointAuthPoll , url .QueryEscape (pollCode ))
702- return callAPI [message.EndpointAuthPollData ](ctx , c , "GET" , pollURL , nil )
671+ pollURL := fmt .Sprintf ("%s?pollToken=%s" , message .AuthPollEndpoint , url .QueryEscape (pollCode ))
672+ _ , d , err := callAPI [message.EndpointAuthPollData ](ctx , c , "GET" , pollURL , nil )
673+ return d , err
703674}
704675
705676func urlPath (base , path string ) (string , error ) {
0 commit comments