@@ -2,11 +2,10 @@ package graphql.servlet
22
33import com.fasterxml.jackson.databind.ObjectMapper
44import graphql.Scalars
5- import graphql.annotations.annotationTypes.GraphQLType
65import graphql.execution.ExecutionStepInfo
76import graphql.execution.instrumentation.ChainedInstrumentation
8-
97import graphql.execution.instrumentation.Instrumentation
8+ import graphql.execution.reactive.SingleSubscriberPublisher
109import graphql.schema.GraphQLNonNull
1110import org.dataloader.DataLoaderRegistry
1211import org.springframework.mock.web.MockHttpServletRequest
@@ -17,6 +16,9 @@ import spock.lang.Specification
1716
1817import javax.servlet.ServletInputStream
1918import javax.servlet.http.HttpServletRequest
19+ import java.util.concurrent.CountDownLatch
20+ import java.util.concurrent.TimeUnit
21+ import java.util.concurrent.atomic.AtomicReference
2022
2123/**
2224 * @author Andrew Potter
@@ -27,17 +29,32 @@ class AbstractGraphQLHttpServletSpec extends Specification {
2729 public static final int STATUS_BAD_REQUEST = 400
2830 public static final int STATUS_ERROR = 500
2931 public static final String CONTENT_TYPE_JSON_UTF8 = ' application/json;charset=UTF-8'
32+ public static final String CONTENT_TYPE_SERVER_SENT_EVENTS = ' text/event-stream;charset=UTF-8'
3033
3134 @Shared
3235 ObjectMapper mapper = new ObjectMapper ()
3336
3437 AbstractGraphQLHttpServlet servlet
3538 MockHttpServletRequest request
3639 MockHttpServletResponse response
40+ CountDownLatch subscriptionLatch
3741
3842 def setup () {
39- servlet = TestUtils . createServlet()
43+ subscriptionLatch = new CountDownLatch (1 )
44+ servlet = TestUtils . createServlet({ env -> env. arguments. arg }, { env -> env. arguments. arg }, { env ->
45+ AtomicReference<SingleSubscriberPublisher<String > > publisherRef = new AtomicReference<> ()
46+ publisherRef. set(new SingleSubscriberPublisher<String > ({
47+ SingleSubscriberPublisher<String > publisher = publisherRef. get()
48+ publisher. offer(" First\n\n " + env. arguments. arg)
49+ publisher. offer(" Second\n\n " + env. arguments. arg)
50+ publisher. noMoreData()
51+ subscriptionLatch. countDown()
52+ }))
53+ return publisherRef. get()
54+ })
55+
4056 request = new MockHttpServletRequest ()
57+ request. asyncSupported = true
4158 response = new MockHttpServletResponse ()
4259 }
4360
@@ -46,6 +63,17 @@ class AbstractGraphQLHttpServletSpec extends Specification {
4663 mapper. readValue(response. getContentAsByteArray(), Map )
4764 }
4865
66+ List<Map<String , Object > > getSubscriptionResponseContent () {
67+ String [] data = response. getContentAsString(). split(" \n\n " )
68+ return data. collect { dataLine ->
69+ if (dataLine. startsWith(" data: " )) {
70+ return mapper. readValue(dataLine. substring(5 ), Map )
71+ } else {
72+ throw new IllegalStateException (" Could not read event stream" )
73+ }
74+ }
75+ }
76+
4977 List<Map<String , Object > > getBatchedResponseContent () {
5078 mapper. readValue(response. getContentAsByteArray(), List )
5179 }
@@ -263,6 +291,26 @@ class AbstractGraphQLHttpServletSpec extends Specification {
263291 getBatchedResponseContent()[1 ]. errors. size() == 1
264292 }
265293
294+ def " subscription query over HTTP GET with variables as string returns data" () {
295+ setup :
296+ request. addParameter(' query' , ' subscription Subscription($arg: String!) { echo(arg: $arg) }' )
297+ request. addParameter(' operationName' , ' Subscription' )
298+ request. addParameter( ' variables' , ' {"arg": "test"}' )
299+ request. setAsyncSupported(true )
300+
301+ when :
302+ servlet. doGet(request, response)
303+ then :
304+ response. getStatus() == STATUS_OK
305+ response. getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
306+
307+ when :
308+ subscriptionLatch. await(1 , TimeUnit . SECONDS )
309+ then :
310+ getSubscriptionResponseContent()[0 ]. data. echo == " First\n\n test"
311+ getSubscriptionResponseContent()[1 ]. data. echo == " Second\n\n test"
312+ }
313+
266314 def " query over HTTP POST without part or body returns bad request" () {
267315 when :
268316 servlet. doPost(request, response)
@@ -903,6 +951,24 @@ class AbstractGraphQLHttpServletSpec extends Specification {
903951 getBatchedResponseContent()[1 ]. data. echo == " test"
904952 }
905953
954+ def " subscription query over HTTP POST with variables as string returns data" () {
955+ setup :
956+ request. setContent(' {"query": "subscription Subscription($arg: String!) { echo(arg: $arg) }", "operationName": "Subscription", "variables": {"arg": "test"}}' . bytes)
957+ request. setAsyncSupported(true )
958+
959+ when :
960+ servlet. doPost(request, response)
961+ then :
962+ response. getStatus() == STATUS_OK
963+ response. getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
964+
965+ when :
966+ subscriptionLatch. await(1 , TimeUnit . SECONDS )
967+ then :
968+ getSubscriptionResponseContent()[0 ]. data. echo == " First\n\n test"
969+ getSubscriptionResponseContent()[1 ]. data. echo == " Second\n\n test"
970+ }
971+
906972 def " errors before graphql schema execution return internal server error" () {
907973 setup :
908974 servlet = SimpleGraphQLHttpServlet . newBuilder(GraphQLInvocationInputFactory . newBuilder {
0 commit comments