diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java index e96403e48..419f40847 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -6,7 +6,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; import io.modelcontextprotocol.util.Assert; @@ -47,27 +46,18 @@ private DefaultServerTransportSecurityValidator(List allowedOrigins, Lis } @Override - public void validateHeaders(Map> headers) throws ServerTransportSecurityException { - boolean missingHost = true; - for (Map.Entry> entry : headers.entrySet()) { - if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(403, "Invalid Origin header"); - } - validateOrigin(values.get(0)); - } - else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) { - missingHost = false; - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); - } - validateHost(values.get(0)); - } + public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException { + List originValues = accessor.getHeader(ORIGIN_HEADER); + if (originValues != null && !originValues.isEmpty()) { + validateOrigin(originValues.get(0)); } - if (!allowedHosts.isEmpty() && missingHost) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); + + if (!allowedHosts.isEmpty()) { + List hostValues = accessor.getHeader(HOST_HEADER); + if (hostValues == null || hostValues.isEmpty()) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + validateHost(hostValues.get(0)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HeaderAccessor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HeaderAccessor.java new file mode 100644 index 000000000..cdad84de5 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HeaderAccessor.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.List; + +/** + * Abstraction for accessing HTTP headers from an incoming request. Implementations should + * provide case-insensitive header name lookups (e.g., when backed by + * {@code HttpServletRequest}). + * + * @author Neeraj Bhatt + * @since 0.16.0 + * @see ServerTransportSecurityValidator + */ +public interface HeaderAccessor { + + /** + * Returns the values of the specified header, or an empty list if the header is not + * present. + * @param name the header name (case-insensitive) + * @return the list of header values, never {@code null} + */ + List getHeader(String name); + + /** + * Returns all header names present in the request. + * @return the list of header names, never {@code null} + */ + List getHeaderNames(); + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletHeaderAccessor.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletHeaderAccessor.java new file mode 100644 index 000000000..1cc92f010 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletHeaderAccessor.java @@ -0,0 +1,41 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.List; + +import jakarta.servlet.http.HttpServletRequest; + +/** + * {@link HeaderAccessor} implementation backed by an {@link HttpServletRequest}. Header + * name lookups are case-insensitive as per the Servlet specification. + * + *

+ * For internal use only. + * + * @author Neeraj Bhatt + * @since 0.16.0 + * @see HeaderAccessor + */ +final class HttpServletHeaderAccessor implements HeaderAccessor { + + private final HttpServletRequest request; + + HttpServletHeaderAccessor(HttpServletRequest request) { + this.request = request; + } + + @Override + public List getHeader(String name) { + return Collections.list(this.request.getHeaders(name)); + } + + @Override + public List getHeaderNames() { + return Collections.list(this.request.getHeaderNames()); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java deleted file mode 100644 index 32246948c..000000000 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2026-2026 the original author or authors. - */ - -package io.modelcontextprotocol.server.transport; - -import java.util.Collections; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import jakarta.servlet.http.HttpServletRequest; - -/** - * Utility methods for working with {@link HttpServletRequest}. For internal use only. - * - * @author Daniel Garnier-Moiroux - */ -final class HttpServletRequestUtils { - - private HttpServletRequestUtils() { - } - - /** - * Extracts all headers from the HTTP request into a map. - * @param request The HTTP servlet request - * @return A map of header names to their values - */ - static Map> extractHeaders(HttpServletRequest request) { - Map> headers = new HashMap<>(); - Enumeration names = request.getHeaderNames(); - while (names.hasMoreElements()) { - String name = names.nextElement(); - headers.put(name, Collections.list(request.getHeaders(name))); - } - return headers; - } - -} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 69d73f7ab..409174966 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -280,8 +280,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -353,8 +352,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 047aeebe8..96e6f33df 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -8,7 +8,6 @@ import java.io.IOException; import java.io.PrintWriter; import java.util.List; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -134,8 +133,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 9a785e150..1a77e16cf 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -10,7 +10,6 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; @@ -271,8 +270,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -407,8 +405,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); @@ -588,8 +585,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response } try { - Map> headers = HttpServletRequestUtils.extractHeaders(request); - this.securityValidator.validateHeaders(headers); + this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request)); } catch (ServerTransportSecurityException e) { response.sendError(e.getStatusCode(), e.getMessage()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java index ce805931f..7a649f875 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java @@ -4,33 +4,101 @@ package io.modelcontextprotocol.server.transport; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** * Interface for validating HTTP requests in server transports. Implementations can * validate Origin headers, Host headers, or any other security-related headers according * to the MCP specification. * + *

+ * New implementations should override {@link #validateHeaders(HeaderAccessor) + * validateHeaders(HeaderAccessor)} for more efficient, case-insensitive header access. + * The older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be + * removed in a future major version. + * * @author Daniel Garnier-Moiroux * @see DefaultServerTransportSecurityValidator * @see ServerTransportSecurityException */ -@FunctionalInterface public interface ServerTransportSecurityValidator { /** * A no-op validator that accepts all requests without validation. */ - ServerTransportSecurityValidator NOOP = headers -> { + ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() { + @Override + public void validateHeaders(Map> headers) throws ServerTransportSecurityException { + } + + @Override + public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException { + } }; /** * Validates the HTTP headers from an incoming request. + * + *

+ * The default implementation converts the map into a {@link HeaderAccessor} and + * delegates to {@link #validateHeaders(HeaderAccessor)}. * @param headers A map of header names to their values (multi-valued headers * supported) * @throws ServerTransportSecurityException if validation fails + * @deprecated Use {@link #validateHeaders(HeaderAccessor)} instead for more + * efficient, case-insensitive header access. This method will be removed in a future + * major version. + */ + @Deprecated + default void validateHeaders(Map> headers) throws ServerTransportSecurityException { + validateHeaders(new HeaderAccessor() { + @Override + public List getHeader(String name) { + return headers.entrySet() + .stream() + .filter(e -> e.getKey().equalsIgnoreCase(name)) + .map(Map.Entry::getValue) + .findFirst() + .orElse(List.of()); + } + + @Override + public List getHeaderNames() { + return List.copyOf(headers.keySet()); + } + }); + } + + /** + * Validates the HTTP headers from an incoming request using a {@link HeaderAccessor}. + * + *

+ * New implementations should override this method. Header name lookup through the + * accessor should be case-insensitive (e.g., when backed by + * {@code HttpServletRequest}). + * + *

+ * The default implementation collects all headers from the accessor into a + * {@link Map} and delegates to the deprecated {@link #validateHeaders(Map)} method, + * so that existing implementations that only override {@link #validateHeaders(Map)} + * continue to work. + * @param accessor provides access to request headers + * @throws ServerTransportSecurityException if validation fails */ - void validateHeaders(Map> headers) throws ServerTransportSecurityException; + default void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException { + var collectedHeaders = accessor.getHeaderNames() + .stream() + .collect(Collectors.>toUnmodifiableMap(String::toLowerCase, + accessor::getHeader, (l1, l2) -> { + var merged = new ArrayList<>(l1); + merged.addAll(l2); + return Collections.unmodifiableList(merged); + })); + validateHeaders(collectedHeaders); + } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java index d4cf8582d..a37414702 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -43,50 +43,50 @@ class OriginHeader { @Test void originHeaderMissing() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void originHeaderListEmpty() { - assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN); + assertThatCode(() -> validator.validateHeaders(headerAccessor())).doesNotThrowAnyException(); } @Test void caseInsensitive() { - var headers = Map.of("origin", List.of("http://localhost:8080")); + var accessor = headerAccessor("Origin", "http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentHost() { - var headers = originHeader("http://example.com:8080"); + var accessor = originAccessor("http://example.com:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentScheme() { - var headers = originHeader("https://localhost:8080"); + var accessor = originAccessor("https://localhost:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Nested @@ -99,37 +99,37 @@ class WildcardPort { @Test void anyPortWithWildcard() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPortWithWildcard() { - var headers = originHeader("http://localhost"); + var accessor = originAccessor("http://localhost"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPortWithWildcard() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHostWithWildcard() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentSchemeWithWildcard() { - var headers = originHeader("https://localhost:3000"); + var accessor = originAccessor("https://localhost:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -146,23 +146,23 @@ class MultipleOrigins { @Test void matchingOneOfMultiple() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void matchingWildcardInMultiple() { - var headers = originHeader("http://myapp.example.com:9999"); + var accessor = originAccessor("http://myapp.example.com:9999"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void notMatchingAny() { - var headers = originHeader("http://malicious.example.com:1234"); + var accessor = originAccessor("http://malicious.example.com:1234"); - assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -176,9 +176,9 @@ void shouldAddMultipleOriginsWithAllowedOriginsMethod() { .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) .build(); - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test @@ -188,11 +188,11 @@ void shouldCombineAllowedOriginMethods() { .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://localhost:8080"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://example.com:9999"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://test.com:3000"))) .doesNotThrowAnyException(); } @@ -210,45 +210,45 @@ class HostHeader { @Test void notConfigured() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void missing() { - assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(emptyAccessor())).isEqualTo(INVALID_HOST); } @Test void listEmpty() { - assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(headerAccessor())).isEqualTo(INVALID_HOST); } @Test void caseInsensitive() { - var headers = Map.of("host", List.of("localhost:8080")); + var accessor = headerAccessor("Host", "localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentHost() { - var headers = hostHeader("example.com:8080"); + var accessor = hostAccessor("example.com:8080"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Nested @@ -261,23 +261,23 @@ class HostWildcardPort { @Test void anyPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPort() { - var headers = hostHeader("localhost"); + var accessor = hostAccessor("localhost"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -293,30 +293,30 @@ class MultipleHosts { @Test void exactMatch() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void wildcard() { - var headers = hostHeader("myapp.example.com:9999"); + var accessor = hostAccessor("myapp.example.com:9999"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("malicious.example.com:3000"); + var accessor = hostAccessor("malicious.example.com:3000"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentPort() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -330,9 +330,9 @@ void multipleHosts() { .allowedHosts(List.of("localhost:8080", "example.com:*")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) .doesNotThrowAnyException(); } @@ -343,11 +343,12 @@ void combined() { .allowedHosts(List.of("example.com:*", "test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:9999"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostAccessor("test.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException(); } } @@ -365,60 +366,222 @@ class CombinedOriginAndHostValidation { @Test void bothValid() { - var header = headers("http://localhost:8080", "localhost:8080"); + var accessor = combinedAccessor("http://localhost:8080", "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostInvalid() { - var header = headers("http://localhost:8080", "malicious.example.com:8080"); + var accessor = combinedAccessor("http://localhost:8080", "malicious.example.com:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void originInvalidHostValid() { - var header = headers("http://malicious.example.com:8080", "localhost:8080"); + var accessor = combinedAccessor("http://malicious.example.com:8080", "localhost:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void originMissingHostValid() { // Origin missing is OK (same-origin request) - var header = headers(null, "localhost:8080"); + var accessor = combinedAccessor(null, "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostMissing() { // Host missing is NOT OK when allowedHosts is configured - var header = headers("http://localhost:8080", null); + var accessor = combinedAccessor("http://localhost:8080", null); + + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class DeprecatedMapBasedApi { + + @Test + void originValidation() { + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void originRejected() { + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://malicious.example.com")); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void caseInsensitiveHeaderLookup() { + Map> headers = new HashMap<>(); + headers.put("origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void hostValidation() { + DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .build(); + + Map> headers = new HashMap<>(); + headers.put("Host", List.of("localhost:8080")); + + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void hostRejected() { + DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .build(); + + Map> headers = new HashMap<>(); + headers.put("Host", List.of("malicious.com:8080")); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Test + void emptyHeaders() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } + + @Test + void combinedOriginAndHost() { + DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build(); + + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://localhost:8080")); + headers.put("Host", List.of("localhost:8080")); + + assertThatCode(() -> combinedValidator.validateHeaders(headers)).doesNotThrowAnyException(); } } - private static Map> originHeader(String origin) { - return Map.of("Origin", List.of(origin)); + @Nested + class InterfaceDefaultBridge { + + @Test + void noopAcceptsAll() { + assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(emptyAccessor())) + .doesNotThrowAnyException(); + assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(new HashMap<>())) + .doesNotThrowAnyException(); + } + + @Test + void mapDefaultBridgesToAccessorOverride() { + // A validator that only overrides the HeaderAccessor method should still work + // when called via the deprecated Map method + ServerTransportSecurityValidator accessorOnlyValidator = new ServerTransportSecurityValidator() { + @Override + public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException { + List origins = accessor.getHeader("Origin"); + if (origins != null && !origins.isEmpty() && origins.get(0).contains("evil")) { + throw new ServerTransportSecurityException(403, "Invalid Origin header"); + } + } + }; + + Map> goodHeaders = new HashMap<>(); + goodHeaders.put("Origin", List.of("http://good.example.com")); + assertThatCode(() -> accessorOnlyValidator.validateHeaders(goodHeaders)).doesNotThrowAnyException(); + + Map> evilHeaders = new HashMap<>(); + evilHeaders.put("Origin", List.of("http://evil.example.com")); + assertThatThrownBy(() -> accessorOnlyValidator.validateHeaders(evilHeaders)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void accessorDefaultBridgesToMapOverride() { + // A validator that only overrides the deprecated Map method should still work + // when called via the new HeaderAccessor method + ServerTransportSecurityValidator mapOnlyValidator = new ServerTransportSecurityValidator() { + @Override + public void validateHeaders(Map> headers) throws ServerTransportSecurityException { + List origins = headers.getOrDefault("origin", List.of()); + if (!origins.isEmpty() && origins.get(0).contains("evil")) { + throw new ServerTransportSecurityException(403, "Invalid Origin header"); + } + } + }; + + assertThatCode(() -> mapOnlyValidator.validateHeaders(originAccessor("http://good.example.com"))) + .doesNotThrowAnyException(); + + assertThatThrownBy(() -> mapOnlyValidator.validateHeaders(originAccessor("http://evil.example.com"))) + .isEqualTo(INVALID_ORIGIN); + } + + } + + private static HeaderAccessor emptyAccessor() { + return headerAccessor(); + } + + private static HeaderAccessor headerAccessor(String... namesAndValues) { + Map> headers = new HashMap<>(); + for (int i = 0; i < namesAndValues.length; i += 2) { + headers.put(namesAndValues[i], List.of(namesAndValues[i + 1])); + } + return new HeaderAccessor() { + @Override + public List getHeader(String name) { + return headers.getOrDefault(name, List.of()); + } + + @Override + public List getHeaderNames() { + return List.copyOf(headers.keySet()); + } + }; } - private static Map> hostHeader(String host) { - return Map.of("Host", List.of(host)); + private static HeaderAccessor originAccessor(String origin) { + return headerAccessor("Origin", origin); } - private static Map> headers(String origin, String host) { - var map = new HashMap>(); + private static HeaderAccessor hostAccessor(String host) { + return headerAccessor("Host", host); + } + + private static HeaderAccessor combinedAccessor(String origin, String host) { + Map> headers = new HashMap<>(); if (origin != null) { - map.put("Origin", List.of(origin)); + headers.put("Origin", List.of(origin)); } if (host != null) { - map.put("Host", List.of(host)); + headers.put("Host", List.of(host)); } - return map; + return new HeaderAccessor() { + @Override + public List getHeader(String name) { + return headers.getOrDefault(name, List.of()); + } + + @Override + public List getHeaderNames() { + return List.copyOf(headers.keySet()); + } + }; } }