From 330f3ea487e6ba63878b1f250b2f6af5278fafd1 Mon Sep 17 00:00:00 2001 From: Hongzhi Gao <761417898@qq.com> Date: Wed, 15 Apr 2026 15:05:16 +0800 Subject: [PATCH] [To dev/1.3] enhance cppclient tsblock deserialize validation (#17464) * fix tsblock deserialize * fix ut error on win * Revert "fix ut error on win" This reverts commit 34b8de482c864abd64d721bdd51fe08e18b742b8. --- .../client-cpp/src/main/ColumnDecoder.cpp | 3 + iotdb-client/client-cpp/src/main/Common.cpp | 46 ++++++++++++- iotdb-client/client-cpp/src/main/Common.h | 8 +++ iotdb-client/client-cpp/src/main/TsBlock.cpp | 12 ++++ .../client-cpp/src/test/cpp/sessionIT.cpp | 67 +++++++++++++++++++ 5 files changed, 134 insertions(+), 2 deletions(-) diff --git a/iotdb-client/client-cpp/src/main/ColumnDecoder.cpp b/iotdb-client/client-cpp/src/main/ColumnDecoder.cpp index e45cb49409a0d..32f29d876f368 100644 --- a/iotdb-client/client-cpp/src/main/ColumnDecoder.cpp +++ b/iotdb-client/client-cpp/src/main/ColumnDecoder.cpp @@ -151,6 +151,9 @@ std::unique_ptr BinaryArrayColumnDecoder::readColumn( if (!nullIndicators.empty() && nullIndicators[i]) continue; int32_t length = buffer.getInt(); + if (length < 0) { + throw IoTDBException("BinaryArrayColumnDecoder: negative TEXT length"); + } std::vector value(length); for (int32_t j = 0; j < length; j++) { diff --git a/iotdb-client/client-cpp/src/main/Common.cpp b/iotdb-client/client-cpp/src/main/Common.cpp index 38e8a31d2ff37..f58b6cc21ddc1 100644 --- a/iotdb-client/client-cpp/src/main/Common.cpp +++ b/iotdb-client/client-cpp/src/main/Common.cpp @@ -19,6 +19,7 @@ #include "Common.h" #include +#include int32_t parseDateExpressionToInt(const boost::gregorian::date& date) { if (date.is_not_a_date()) { @@ -292,6 +293,10 @@ double MyStringBuffer::getDouble() { } char MyStringBuffer::getChar() { + if (pos >= str.size()) { + throw IoTDBException("MyStringBuffer::getChar: read past end (pos=" + std::to_string(pos) + + ", size=" + std::to_string(str.size()) + ")"); + } return str[pos++]; } @@ -300,8 +305,16 @@ bool MyStringBuffer::getBool() { } std::string MyStringBuffer::getString() { - size_t len = getInt(); - size_t tmpPos = pos; + const int lenInt = getInt(); + if (lenInt < 0) { + throw IoTDBException("MyStringBuffer::getString: negative length"); + } + const size_t len = static_cast(lenInt); + if (pos > str.size() || len > str.size() - pos) { + throw IoTDBException("MyStringBuffer::getString: length exceeds buffer (pos=" + std::to_string(pos) + + ", len=" + std::to_string(len) + ", size=" + std::to_string(str.size()) + ")"); + } + const size_t tmpPos = pos; pos += len; return str.substr(tmpPos, len); } @@ -350,6 +363,10 @@ void MyStringBuffer::checkBigEndian() { } const char* MyStringBuffer::getOrderedByte(size_t len) { + if (pos > str.size() || len > str.size() - pos) { + throw IoTDBException("MyStringBuffer::getOrderedByte: read past end (pos=" + std::to_string(pos) + + ", len=" + std::to_string(len) + ", size=" + std::to_string(str.size()) + ")"); + } const char* p = nullptr; if (isBigEndian) { p = str.c_str() + pos; @@ -454,3 +471,28 @@ const std::vector& BitMap::getByteArray() const { size_t BitMap::getSize() const { return this->size; } + +TEndPoint UrlUtils::parseTEndPointIpv4AndIpv6Url(const std::string& endPointUrl) { + TEndPoint endPoint; + const size_t colonPos = endPointUrl.find_last_of(':'); + if (colonPos == std::string::npos) { + endPoint.__set_ip(endPointUrl); + endPoint.__set_port(0); + return endPoint; + } + std::string ip = endPointUrl.substr(0, colonPos); + const std::string portStr = endPointUrl.substr(colonPos + 1); + try { + const int port = std::stoi(portStr); + endPoint.__set_port(port); + } catch (const std::logic_error&) { + endPoint.__set_ip(endPointUrl); + endPoint.__set_port(0); + return endPoint; + } + if (ip.size() >= 2 && ip.front() == '[' && ip.back() == ']') { + ip = ip.substr(1, ip.size() - 2); + } + endPoint.__set_ip(ip); + return endPoint; +} diff --git a/iotdb-client/client-cpp/src/main/Common.h b/iotdb-client/client-cpp/src/main/Common.h index a9f4552ecc5fd..af9cf46e62e54 100644 --- a/iotdb-client/client-cpp/src/main/Common.h +++ b/iotdb-client/client-cpp/src/main/Common.h @@ -480,5 +480,13 @@ class RpcUtils { static std::shared_ptr getTSFetchResultsResp(const TSStatus& status); }; +class UrlUtils { +public: + UrlUtils() = delete; + + /** Parse host:port; aligns with Java UrlUtils.parseTEndPointIpv4AndIpv6Url plus test edge cases. */ + static TEndPoint parseTEndPointIpv4AndIpv6Url(const std::string& endPointUrl); +}; + #endif diff --git a/iotdb-client/client-cpp/src/main/TsBlock.cpp b/iotdb-client/client-cpp/src/main/TsBlock.cpp index 7c2bac272a601..92afbef3f270f 100644 --- a/iotdb-client/client-cpp/src/main/TsBlock.cpp +++ b/iotdb-client/client-cpp/src/main/TsBlock.cpp @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include "TsBlock.h" @@ -34,6 +35,14 @@ std::shared_ptr TsBlock::deserialize(const std::string& data) { // Read value column count int32_t valueColumnCount = buffer.getInt(); + if (valueColumnCount < 0) { + throw IoTDBException("TsBlock::deserialize: negative valueColumnCount"); + } + const int64_t minHeaderBytes = + 9LL + 2LL * static_cast(valueColumnCount); + if (minHeaderBytes > static_cast(data.size())) { + throw IoTDBException("TsBlock::deserialize: truncated header"); + } // Read value column data types std::vector valueColumnDataTypes(valueColumnCount); @@ -43,6 +52,9 @@ std::shared_ptr TsBlock::deserialize(const std::string& data) { // Read position count int32_t positionCount = buffer.getInt(); + if (positionCount < 0) { + throw IoTDBException("TsBlock::deserialize: negative positionCount"); + } // Read column encodings std::vector columnEncodings(valueColumnCount + 1); diff --git a/iotdb-client/client-cpp/src/test/cpp/sessionIT.cpp b/iotdb-client/client-cpp/src/test/cpp/sessionIT.cpp index 09154b868e47c..b2db95cc4eca0 100644 --- a/iotdb-client/client-cpp/src/test/cpp/sessionIT.cpp +++ b/iotdb-client/client-cpp/src/test/cpp/sessionIT.cpp @@ -19,6 +19,8 @@ #include "catch.hpp" #include "Session.h" +#include "TsBlock.h" +#include using namespace std; @@ -728,3 +730,68 @@ TEST_CASE("Test executeLastDataQuery ", "[testExecuteLastDataQuery]") { sessionDataSet->setFetchSize(1024); REQUIRE(sessionDataSet->hasNext() == false); } + +// Helper function for comparing TEndPoint with detailed error message +void assertTEndPointEqual(const TEndPoint& actual, + const std::string& expectedIp, + int expectedPort, + const char* file, + int line) { + if (actual.ip != expectedIp || actual.port != expectedPort) { + std::stringstream ss; + ss << "\nTEndPoint mismatch:\nExpected: " << expectedIp << ":" << expectedPort + << "\nActual: " << actual.ip << ":" << actual.port; + Catch::SourceLineInfo location(file, line); + Catch::AssertionHandler handler("TEndPoint comparison", location, ss.str(), Catch::ResultDisposition::Normal); + handler.handleMessage(Catch::ResultWas::ExplicitFailure, ss.str()); + handler.complete(); + } +} + +// Macro to simplify test assertions +#define REQUIRE_TENDPOINT(actual, expectedIp, expectedPort) \ + assertTEndPointEqual(actual, expectedIp, expectedPort, __FILE__, __LINE__) + +TEST_CASE("UrlUtils - parseTEndPointIpv4AndIpv6Url", "[UrlUtils]") { + // Test valid IPv4 addresses + SECTION("Valid IPv4") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("192.168.1.1:8080"), "192.168.1.1", 8080); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("10.0.0.1:80"), "10.0.0.1", 80); + } + + // Test valid IPv6 addresses + SECTION("Valid IPv6") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("[2001:db8::1]:8080"), "2001:db8::1", 8080); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("[::1]:80"), "::1", 80); + } + + // Test hostnames + SECTION("Hostnames") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("localhost:8080"), "localhost", 8080); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("example.com:443"), "example.com", 443); + } + + // Test edge cases + SECTION("Edge cases") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url(""), "", 0); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("127.0.0.1"), "127.0.0.1", 0); + } + + // Test invalid inputs + SECTION("Invalid inputs") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("192.168.1.1:abc"), "192.168.1.1:abc", 0); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("]invalid[:80"), "]invalid[", 80); + } + + // Test port ranges + SECTION("Port ranges") { + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("localhost:0"), "localhost", 0); + REQUIRE_TENDPOINT(UrlUtils::parseTEndPointIpv4AndIpv6Url("127.0.0.1:65535"), "127.0.0.1", 65535); + } +} + +TEST_CASE("TsBlock deserialize rejects truncated malicious payload", "[TsBlockDeserialize]") { + std::string data(18, '\0'); + data[3] = '\x10'; + REQUIRE_THROWS_AS(TsBlock::deserialize(data), IoTDBException); +}