Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mongo/transport/proxy_protocol_header_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ void parseSubTLVVectors(StringData buffer, boost::optional<ProxiedSSLData>& sslT
void parseTLVVectors(StringData buffer,
std::vector<ProxiedSupplementaryDataEntry>& tlvs,
boost::optional<ProxiedSSLData>& sslTlvs) {
size_t tlvCount = 0;
while (buffer.size()) {
static constexpr size_t kTLVHeaderSize = 3;
uassert(ErrorCodes::FailedToParse,
Expand All @@ -230,6 +231,11 @@ void parseTLVVectors(StringData buffer,
buffer.size(),
kTLVHeaderSize),
buffer.size() > kTLVHeaderSize);
uassert(ErrorCodes::FailedToParse,
fmt::format("Proxy Protocol Version 2 TLV entry count exceeds {}",
kMaxProxyProtocolTLVEntriesPerVector),
tlvCount < kMaxProxyProtocolTLVEntriesPerVector);
++tlvCount;

auto type = extract<uint8_t>(buffer);
uassert(ErrorCodes::FailedToParse,
Expand Down
7 changes: 7 additions & 0 deletions src/mongo/transport/proxy_protocol_header_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ constexpr uint8_t kProxyProtocolTypeAuthority = 0x02;
*/
constexpr uint8_t kProxyProtocolSSLTlvType = 0x20;

/**
* Maximum TLV entries parsed from a single proxy protocol TLV vector, including SSL sub-TLV
* vectors. This bounds per-connection allocations on the proxy unix-socket path while remaining
* well above expected production metadata usage.
*/
constexpr size_t kMaxProxyProtocolTLVEntriesPerVector = 64;

/**
* MongoDB custom PP2 TLV type as per MongoDB Proxy Protocol Technical Design document.
* The kProxyProtocolSSLTlvDN TLV is used to indicate the distinguished name (DN) from the client's
Expand Down
65 changes: 64 additions & 1 deletion src/mongo/transport/proxy_protocol_header_parser_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,19 @@ std::string buildTLV(uint8_t type, const std::string& data) {
return tlv;
}

std::string buildRepeatedTLVs(size_t count, uint8_t type = 0xE0, StringData data = StringData()) {
std::string tlvs;
tlvs.reserve(count * (3 + data.size()));
for (size_t i = 0; i < count; ++i) {
tlvs += buildTLV(type, std::string{data});
}
return tlvs;
}

std::string buildSSLTLVPayload(uint8_t clientFlags,
uint32_t verify,
const std::string& subTlvData = "");

class ProxyProtocolParameterizedTestFixture : public testing::TestWithParam<AddressFamily> {};

INSTANTIATE_TEST_SUITE_P(ProxyProtocolHeaderParser,
Expand Down Expand Up @@ -680,6 +693,56 @@ TEST_P(ProxyProtocolParameterizedTestFixture, TLVParsingManyTLVs) {
ASSERT_EQ(result->tlvs[4].data, "custom");
}

TEST(ProxyProtocolHeaderParser, TLVParsingAcceptsBoundedEntryCounts) {
for (const auto tlvCount :
{size_t{1}, size_t{16}, kMaxProxyProtocolTLVEntriesPerVector}) {
auto result = parseWithTLV(AddressFamily::TCP4, buildRepeatedTLVs(tlvCount, 0xE0, "x"_sd));
ASSERT_TRUE(result) << tlvCount;
ASSERT_TRUE(result->endpoints) << tlvCount;
ASSERT_EQ(result->tlvs.size(), tlvCount) << tlvCount;
ASSERT_FALSE(result->sslTlvs) << tlvCount;
}
}

TEST(ProxyProtocolHeaderParser, TLVParsingRejectsExcessiveEntryCounts) {
ASSERT_THROWS_WITH_CHECK(
parseWithTLV(AddressFamily::TCP4,
buildRepeatedTLVs(kMaxProxyProtocolTLVEntriesPerVector + 1, 0xE0, "x"_sd)),
DBException,
[](const DBException& ex) {
ASSERT_THAT(ex.toStatus(),
StatusIs(Eq(ErrorCodes::FailedToParse),
ContainsRegex("TLV entry count exceeds")));
});
}

TEST(ProxyProtocolHeaderParser, TLVParsingRejectsDenseEntriesAfterSingleSslTLV) {
auto sslTlv = buildTLV(kProxyProtocolSSLTlvType, buildSSLTLVPayload(0x07, 0));
auto denseTlvs = buildRepeatedTLVs(kMaxProxyProtocolTLVEntriesPerVector, 0xE0, "x"_sd);

ASSERT_THROWS_WITH_CHECK(parseWithTLV(AddressFamily::TCP4, sslTlv + denseTlvs),
DBException,
[](const DBException& ex) {
ASSERT_THAT(ex.toStatus(),
StatusIs(Eq(ErrorCodes::FailedToParse),
ContainsRegex("TLV entry count exceeds")));
});
}

TEST(ProxyProtocolHeaderParser, ParseSubTLVVectorsRejectsExcessiveEntryCounts) {
auto subTlvs =
buildRepeatedTLVs(kMaxProxyProtocolTLVEntriesPerVector + 1, 0x21, "x"_sd);
auto sslTlv = buildTLV(kProxyProtocolSSLTlvType, buildSSLTLVPayload(0x07, 0, subTlvs));

ASSERT_THROWS_WITH_CHECK(parseWithTLV(AddressFamily::TCP4, sslTlv),
DBException,
[](const DBException& ex) {
ASSERT_THAT(ex.toStatus(),
StatusIs(Eq(ErrorCodes::FailedToParse),
ContainsRegex("TLV entry count exceeds")));
});
}

TEST_P(ProxyProtocolParameterizedTestFixture, TLVParsingFails) {
auto type = GetParam();

Expand Down Expand Up @@ -736,7 +799,7 @@ TEST_P(ProxyProtocolParameterizedTestFixture, UnixProxySocketWithV1ProtocolIsRej
// The payload format is: clientFlags (1 byte) + verify (4 bytes big-endian) + optional sub-TLVs.
std::string buildSSLTLVPayload(uint8_t clientFlags,
uint32_t verify,
const std::string& subTlvData = "") {
const std::string& subTlvData) {
std::string payload;
payload += static_cast<char>(clientFlags);
// verify in big-endian.
Expand Down