@@ -10,14 +10,18 @@ import (
1010
1111 "github.com/stretchr/testify/assert"
1212 "github.com/stretchr/testify/require"
13+
14+ "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model"
1315)
1416
1517func TestNewRieInvokeRequest (t * testing.T ) {
1618 tests := []struct {
17- name string
18- request func () * http.Request
19- writer http.ResponseWriter
20- want * rieInvokeRequest
19+ name string
20+ request func () * http.Request
21+ writer http.ResponseWriter
22+ want * rieInvokeRequest
23+ wantError bool
24+ wantErrorContain string
2125 }{
2226 {
2327 name : "no_headers_in_request" ,
@@ -37,6 +41,7 @@ func TestNewRieInvokeRequest(t *testing.T) {
3741 cognitoIdentityPoolId : "" ,
3842 clientContext : "" ,
3943 },
44+ wantError : false ,
4045 },
4146 {
4247 name : "all_headers_present_in_request" ,
@@ -46,6 +51,7 @@ func TestNewRieInvokeRequest(t *testing.T) {
4651 r .Header .Set ("X-Amzn-Trace-Id" , "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9" )
4752 r .Header .Set ("X-Amz-Client-Context" , "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19" )
4853 r .Header .Set ("X-Amzn-RequestId" , "test-invoke-id" )
54+ r .Header .Set ("X-Amz-Cognito-Identity" , `{"cognitoIdentityId":"us-east-1:12345678-1234-1234-1234-123456789012","cognitoIdentityPoolId":"us-east-1:87654321-4321-4321-4321-210987654321"}` )
4955 require .NoError (t , err )
5056 return r
5157 },
@@ -57,16 +63,80 @@ func TestNewRieInvokeRequest(t *testing.T) {
5763 responseBandwidthRate : 2 * 1024 * 1024 ,
5864 responseBandwidthBurstSize : 6 * 1024 * 1024 ,
5965 traceId : "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9" ,
60- cognitoIdentityId : "" ,
66+ cognitoIdentityId : "us-east-1:12345678-1234-1234-1234-123456789012" ,
67+ cognitoIdentityPoolId : "us-east-1:87654321-4321-4321-4321-210987654321" ,
68+ clientContext : `{"custom":{"test":"value"}}` ,
69+ },
70+ wantError : false ,
71+ },
72+ {
73+ name : "malformed_cognito_identity_header" ,
74+ request : func () * http.Request {
75+ r , err := http .NewRequest ("GET" , "http://localhost/" , nil )
76+ r .Header .Set ("X-Amzn-RequestId" , "test-invoke-id" )
77+ r .Header .Set ("X-Amz-Cognito-Identity" , "not-valid-json{" )
78+ require .NoError (t , err )
79+ return r
80+ },
81+ writer : httptest .NewRecorder (),
82+ want : nil ,
83+ wantError : true ,
84+ wantErrorContain : "X-Amz-Cognito-Identity must be a valid JSON string" ,
85+ },
86+ {
87+ name : "malformed_client_context_header" ,
88+ request : func () * http.Request {
89+ r , err := http .NewRequest ("GET" , "http://localhost/" , nil )
90+ r .Header .Set ("X-Amzn-RequestId" , "test-invoke-id" )
91+ r .Header .Set ("X-Amz-Client-Context" , "not-valid-base64!!!" )
92+ require .NoError (t , err )
93+ return r
94+ },
95+ writer : httptest .NewRecorder (),
96+ want : nil ,
97+ wantError : true ,
98+ wantErrorContain : "X-Amz-Client-Context must be a valid base64 encoded string" ,
99+ },
100+ {
101+ name : "partial_cognito_identity_header" ,
102+ request : func () * http.Request {
103+ r , err := http .NewRequest ("GET" , "http://localhost/" , nil )
104+ r .Header .Set ("X-Amzn-RequestId" , "test-invoke-id" )
105+ r .Header .Set ("X-Amz-Cognito-Identity" , `{"cognitoIdentityId":"us-east-1:only-id"}` )
106+ require .NoError (t , err )
107+ return r
108+ },
109+ writer : httptest .NewRecorder (),
110+ want : & rieInvokeRequest {
111+ invokeID : "test-invoke-id" ,
112+ contentType : "application/json" ,
113+ maxPayloadSize : 6 * 1024 * 1024 + 100 ,
114+ responseBandwidthRate : 2 * 1024 * 1024 ,
115+ responseBandwidthBurstSize : 6 * 1024 * 1024 ,
116+ traceId : "" ,
117+ cognitoIdentityId : "us-east-1:only-id" ,
61118 cognitoIdentityPoolId : "" ,
62- clientContext : "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19 " ,
119+ clientContext : "" ,
63120 },
121+ wantError : false ,
64122 },
65123 }
66124 for _ , tt := range tests {
67125 t .Run (tt .name , func (t * testing.T ) {
68126 r := tt .request ()
69- got := NewRieInvokeRequest (r , tt .writer )
127+ got , err := NewRieInvokeRequest (r , tt .writer )
128+
129+ if tt .wantError {
130+ assert .NotNil (t , err )
131+ assert .Nil (t , got )
132+ assert .Equal (t , model .ErrorMalformedRequest , err .ErrorType ())
133+ assert .Equal (t , http .StatusBadRequest , err .ReturnCode ())
134+ assert .Contains (t , err .Error (), tt .wantErrorContain )
135+ return
136+ }
137+
138+ assert .Nil (t , err )
139+ require .NotNil (t , got )
70140
71141 tt .want .request = r
72142 tt .want .writer = tt .writer
0 commit comments