diff --git a/CMakeLists.txt b/CMakeLists.txt index 49831bfebf9..333a86b137d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,7 @@ include(CheckOpenSSLIsBoringSSL) include(CheckOpenSSLIsQuictls) include(CheckOpenSSLIsAwsLc) find_package(OpenSSL REQUIRED) +auto_option(SIMDUTF FEATURE_VAR TS_USE_SIMDUTF PACKAGE_DEPENDS simdutf) check_openssl_is_boringssl(SSLLIB_IS_BORINGSSL BORINGSSL_VERSION "${OPENSSL_INCLUDE_DIR}") check_openssl_is_awslc(SSLLIB_IS_AWSLC AWSLC_VERSION "${OPENSSL_INCLUDE_DIR}") diff --git a/include/tscore/ink_config.h.cmake.in b/include/tscore/ink_config.h.cmake.in index 73c8b860fb9..988ed5cb8d0 100644 --- a/include/tscore/ink_config.h.cmake.in +++ b/include/tscore/ink_config.h.cmake.in @@ -163,6 +163,7 @@ const int DEFAULT_STACKSIZE = @DEFAULT_STACK_SIZE@; #cmakedefine01 TS_USE_POSIX_CAP #cmakedefine01 TS_USE_QUIC #cmakedefine01 TS_USE_REMOTE_UNWINDING +#cmakedefine01 TS_USE_SIMDUTF #cmakedefine01 TS_USE_TLS13 #cmakedefine01 TS_USE_TLS_ASYNC #cmakedefine01 TS_USE_TPROXY diff --git a/src/tscore/CMakeLists.txt b/src/tscore/CMakeLists.txt index 7790adc87dd..6fa2b55594c 100644 --- a/src/tscore/CMakeLists.txt +++ b/src/tscore/CMakeLists.txt @@ -110,6 +110,10 @@ target_link_libraries( tscore PUBLIC OpenSSL::Crypto libswoc::libswoc yaml-cpp::yaml-cpp systemtap::systemtap resolv::resolv ts::tsutil ) +if(TS_USE_SIMDUTF) + target_link_libraries(tscore PUBLIC simdutf::simdutf) +endif() + if(TS_USE_POSIX_CAP) target_link_libraries(tscore PUBLIC cap::cap) endif() diff --git a/src/tscore/ink_base64.cc b/src/tscore/ink_base64.cc index 849d7c8ce83..bf581569496 100644 --- a/src/tscore/ink_base64.cc +++ b/src/tscore/ink_base64.cc @@ -1,6 +1,41 @@ /** @file - A brief file description + Base64 encoding and decoding. + + The public entry points (`ats_base64_encode` / `ats_base64_decode`, also + exposed through `TSBase64Encode` / `TSBase64Decode`) dispatch between two + internal implementations: + + - A hand-rolled scalar path, always present, used directly when + TS_USE_SIMDUTF is disabled, and used for inputs below the SIMD + crossover threshold when TS_USE_SIMDUTF is enabled. The scalar path + avoids simdutf's runtime ISA dispatch and virtual-call overhead, + which would otherwise dominate the cost for tiny inputs (e.g. the + 8-byte SnowflakeID encode). + + - simdutf, used for larger inputs when TS_USE_SIMDUTF is enabled. + simdutf provides SIMD-accelerated kernels and is several times + faster than the scalar path once the input is big enough to amortize + its per-call overhead. + + Thresholds were chosen empirically on a 2.1 GHz Broadwell-EP Xeon + (AVX2) using tools/benchmark/benchmark_ink_base64. The exact crossover + shifts on different cores but lies within an order of magnitude of these + values everywhere we've measured. + + Both paths preserve the same public contract: + + - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line + breaks, trailing NUL written at outBuffer[length]. + - decode: accepts both standard (`+`, `/`) and URL-safe (`-`, `_`) + alphabets in the same input; tolerates missing padding; on an + invalid character, truncates and returns success with whatever was + decoded up to that point; trailing NUL written at outBuffer[length]. + + Note: simdutf's forgiving-base64 mode silently skips ASCII whitespace + (space, tab, CR, LF, FF) inside the input, whereas the scalar path + treats whitespace as an end-of-input marker. No caller in-tree feeds + whitespace to these functions. @section license License @@ -20,32 +55,50 @@ See the License for the specific language governing permissions and limitations under the License. */ - -/* - * Base64 encoding and decoding as according to RFC1521. Similar to uudecode. - * - * RFC 1521 requires inserting line breaks for long lines. The basic web - * authentication scheme does not require them. This implementation is - * intended for web-related use, and line breaks are not implemented. - * - */ #include "tscore/ink_platform.h" #include "tscore/ink_base64.h" #include "tscore/ink_assert.h" -// TODO: The code here seems a bit klunky, and could probably be improved a bit. +#if TS_USE_SIMDUTF +#include -bool -ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +// Inputs at or below these byte counts stay on the scalar path, where they +// outrun simdutf's per-call overhead. inBufferSize for encode is the binary +// plaintext length; for decode it is the base64-encoded length. +constexpr size_t BASE64_ENCODE_SIMD_THRESHOLD = 24; +constexpr size_t BASE64_DECODE_SIMD_THRESHOLD = 48; +#endif + +namespace +{ + +/* Converts a printable character to its six bit representation. */ +const unsigned char printableToSixBit[256] = { + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, + 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; + +constexpr unsigned char MAX_PRINT_VAL = 63; + +inline unsigned char +decode_byte(char c) +{ + return printableToSixBit[static_cast(c)]; +} + +// Hand-rolled scalar encode. Caller has already validated outBufSize. +void +encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t *length) { static const char _codes[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; char *obuf = outBuffer; char in_tail[4]; - if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { - return false; - } - while (inBufferSize > 2) { *obuf++ = _codes[(inBuffer[0] >> 2) & 077]; *obuf++ = _codes[((inBuffer[0] & 03) << 4) | ((inBuffer[1] >> 4) & 017)]; @@ -56,14 +109,6 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB inBuffer += 3; } - /* - * We've done all the input groups of three chars. We're left - * with 0, 1, or 2 input chars. We have to add zero-bits to the - * right if we don't have enough input chars. - * If 0 chars left, we're done. - * If 1 char left, form 2 output chars, and add 2 pad chars to output. - * If 2 chars left, form 3 output chars, add 1 pad char to output. - */ if (inBufferSize == 0) { *obuf = '\0'; if (length) { @@ -88,60 +133,26 @@ ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outB *length = (obuf + 4) - outBuffer; } } - - return true; -} - -bool -ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) -{ - return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); } -/*------------------------------------------------------------------------- - This is a reentrant, and malloc free implementation of ats_base64_decode. - -------------------------------------------------------------------------*/ -#ifdef DECODE -#undef DECODE -#endif - -#define DECODE(x) printableToSixBit[(unsigned char)x] -#define MAX_PRINT_VAL 63 - -/* Converts a printable character to it's six bit representation */ -const unsigned char printableToSixBit[256] = { - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, - 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; - -bool -ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) +// Hand-rolled scalar decode. Caller has already validated outBufSize. +void +decode_scalar(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) { size_t inBytes = 0; size_t decodedBytes = 0; unsigned char *buf = outBuffer; int inputBytesDecoded = 0; - // Make sure there is sufficient space in the output buffer - if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { - return false; - } - // Ignore any trailing ='s or other undecodable characters. - // TODO: Perhaps that ought to be an error instead? - while (inBytes < inBufferSize && printableToSixBit[static_cast(inBuffer[inBytes])] <= MAX_PRINT_VAL) { + while (inBytes < inBufferSize && decode_byte(inBuffer[inBytes]) <= MAX_PRINT_VAL) { ++inBytes; } for (size_t i = 0; i < inBytes; i += 4) { - buf[0] = static_cast(DECODE(inBuffer[0]) << 2 | DECODE(inBuffer[1]) >> 4); - buf[1] = static_cast(DECODE(inBuffer[1]) << 4 | DECODE(inBuffer[2]) >> 2); - buf[2] = static_cast(DECODE(inBuffer[2]) << 6 | DECODE(inBuffer[3])); + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); buf += 3; inBuffer += 4; @@ -149,10 +160,10 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB inputBytesDecoded += 4; } - // Check to see if we decoded a multiple of 4 four - // bytes + // If the consumed input wasn't a multiple of 4 we over-counted the last + // group; trim the trailing 1 or 2 bytes back off. if ((inBytes - inputBytesDecoded) & 0x3) { - if (DECODE(inBuffer[-2]) > MAX_PRINT_VAL) { + if (decode_byte(inBuffer[-2]) > MAX_PRINT_VAL) { decodedBytes -= 2; } else { decodedBytes -= 1; @@ -163,6 +174,69 @@ ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outB if (length) { *length = decodedBytes; } +} + +} // namespace + +bool +ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +{ + if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { + return false; + } + +#if TS_USE_SIMDUTF + if (inBufferSize > BASE64_ENCODE_SIMD_THRESHOLD) { + size_t written = simdutf::binary_to_base64(reinterpret_cast(inBuffer), inBufferSize, outBuffer); + outBuffer[written] = '\0'; + if (length) { + *length = written; + } + return true; + } +#endif + + encode_scalar(inBuffer, inBufferSize, outBuffer, length); + return true; +} + +bool +ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) +{ + return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); +} + +bool +ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) +{ + if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { + return false; + } + +#if TS_USE_SIMDUTF + if (inBufferSize > BASE64_DECODE_SIMD_THRESHOLD) { + // Reserve one byte for the trailing NUL we always emit. + size_t out_len = outBufSize - 1; + auto r = simdutf::base64_to_binary_safe(inBuffer, inBufferSize, reinterpret_cast(outBuffer), out_len, + simdutf::base64_default_or_url, simdutf::last_chunk_handling_options::loose, + /*decode_up_to_bad_char=*/true); + + // OUTPUT_BUFFER_TOO_SMALL is impossible given the upfront dstlen check; + // be defensive anyway. INVALID_BASE64_CHARACTER is expected: scalar + // behavior truncated at bad chars without surfacing an error, so we do + // the same. + if (r.error == simdutf::error_code::OUTPUT_BUFFER_TOO_SMALL) { + return false; + } + + outBuffer[out_len] = '\0'; + if (length) { + *length = out_len; + } + return true; + } +#endif + decode_scalar(inBuffer, inBufferSize, outBuffer, length); return true; } diff --git a/tools/benchmark/CMakeLists.txt b/tools/benchmark/CMakeLists.txt index 49f25fad1c1..d6fa658ca9e 100644 --- a/tools/benchmark/CMakeLists.txt +++ b/tools/benchmark/CMakeLists.txt @@ -36,6 +36,9 @@ target_link_libraries(benchmark_SharedMutex PRIVATE Catch2::Catch2 ts::tscore li add_executable(benchmark_Random benchmark_Random.cc) target_link_libraries(benchmark_Random PRIVATE Catch2::Catch2WithMain ts::tscore) +add_executable(benchmark_ink_base64 benchmark_ink_base64.cc) +target_link_libraries(benchmark_ink_base64 PRIVATE Catch2::Catch2WithMain ts::tscore) + add_executable(benchmark_HostDB benchmark_HostDB.cc) target_link_libraries( benchmark_HostDB diff --git a/tools/benchmark/benchmark_ink_base64.cc b/tools/benchmark/benchmark_ink_base64.cc new file mode 100644 index 00000000000..bec34456ab8 --- /dev/null +++ b/tools/benchmark/benchmark_ink_base64.cc @@ -0,0 +1,205 @@ +/** @file + + Micro benchmark for ats_base64_encode / ats_base64_decode and the bulk + scalar tolower path used by URL canonicalization. Establishes a baseline + prior to any SIMD work. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#define CATCH_CONFIG_ENABLE_BENCHMARKING + +#include +#include + +#include "tscore/ink_base64.h" +#include "tscore/ParseRules.h" + +#include +#include +#include +#include +#include +#include + +namespace +{ + +// Sizes chosen to mirror real callers and to bracket the scalar↔SIMD +// crossover. +// 8B - SnowflakeID (uint64_t) +// 16-48B - HMAC-SHA1/SHA256 and crossover region for encode +// 64-128B - crossover region for decode +// 200B - typical OCSP DER request (RFC6960 caps at 255B encoded) +// 512B / 4KB - stress the inner loop where SIMD wins most +constexpr std::array kPayloadSizes{8, 16, 24, 32, 48, 64, 96, 200, 512, 4096}; + +std::vector +make_random_bytes(size_t n, uint64_t seed = 0xC0FFEEULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (size_t i = 0; i < n; ++i) { + v[i] = static_cast(rng() & 0xFFU); + } + return v; +} + +std::string +encode_with_ats(const std::vector &in) +{ + std::string out; + out.resize(ats_base64_encode_dstlen(in.size())); + size_t n = 0; + bool ok = ats_base64_encode(in.data(), in.size(), out.data(), out.size(), &n); + REQUIRE(ok); + out.resize(n); + return out; +} + +std::vector +make_mixed_case_ascii(size_t n, uint64_t seed = 0xABCDEFULL) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (size_t i = 0; i < n; ++i) { + // Mix of uppercase, lowercase, and a few non-letter bytes that should + // pass through tolower unchanged. Models a URL/header byte stream. + auto r = static_cast(rng() & 0x3FU); + if (r < 26U) { + v[i] = static_cast('A' + r); + } else if (r < 52U) { + v[i] = static_cast('a' + (r - 26U)); + } else { + static constexpr char kNonAlpha[] = "0123456789-_./:"; + v[i] = kNonAlpha[r % (sizeof(kNonAlpha) - 1U)]; + } + } + return v; +} + +// Equivalent of the static inline memcpy_tolower() in src/proxy/hdrs/URL.cc. +// Reproduced here because that definition has internal linkage and isn't +// reachable from this TU. +inline void +memcpy_tolower_scalar(char *d, const char *s, int n) +{ + while (n--) { + *d = ParseRules::ink_tolower(*s); + ++s; + ++d; + } +} + +} // namespace + +TEST_CASE("ats_base64 round-trip correctness", "[base64][correctness]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + auto encoded = encode_with_ats(input); + std::vector decoded(ats_base64_decode_dstlen(encoded.size()) + 1); + size_t dec_len = 0; + REQUIRE(ats_base64_decode(encoded.data(), encoded.size(), decoded.data(), decoded.size(), &dec_len)); + REQUIRE(dec_len == sz); + REQUIRE(std::memcmp(decoded.data(), input.data(), sz) == 0); + } +} + +// Lock the same byte-exact fixture used by InkAPITest's SDK_API_ENCODING +// regression test. Any future implementation swap must keep this passing. +TEST_CASE("ats_base64 InkAPITest fixture", "[base64][correctness][fixture]") +{ + const char *url = "http://www.example.com/foo?fie= \"#%<>[]\\^`{}~&bar={test}&fum=Apache Traffic Server"; + const char *url_b64 = + "aHR0cDovL3d3dy5leGFtcGxlLmNvbS9mb28/ZmllPSAiIyU8PltdXF5ge31+JmJhcj17dGVzdH0mZnVtPUFwYWNoZSBUcmFmZmljIFNlcnZlcg=="; + const auto url_len = std::strlen(url); + const auto url_b64_len = std::strlen(url_b64); + + SECTION("encode produces byte-identical RFC1521 output with '=' padding") + { + std::array buf{}; + size_t enc_len = 0; + REQUIRE(ats_base64_encode(url, url_len, buf.data(), buf.size(), &enc_len)); + REQUIRE(enc_len == url_b64_len); + REQUIRE(std::strcmp(buf.data(), url_b64) == 0); + } + + SECTION("decode reproduces the original byte-for-byte") + { + std::array buf{}; + size_t dec_len = 0; + REQUIRE(ats_base64_decode(url_b64, url_b64_len, reinterpret_cast(buf.data()), buf.size(), &dec_len)); + REQUIRE(dec_len == url_len); + REQUIRE(std::strcmp(buf.data(), url) == 0); + } +} + +TEST_CASE("ats_base64_encode throughput", "[bench][base64][encode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + std::vector output(ats_base64_encode_dstlen(sz) + 16); + + std::string name = "encode " + std::to_string(sz) + "B"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_encode(input.data(), input.size(), output.data(), output.size(), &out_len); + // Return a value that depends on the work to prevent DCE. + return ok ? out_len : size_t{0}; + }; + } +} + +TEST_CASE("ats_base64_decode throughput", "[bench][base64][decode]") +{ + for (size_t sz : kPayloadSizes) { + auto input = make_random_bytes(sz); + auto encoded = encode_with_ats(input); + std::vector output(ats_base64_decode_dstlen(encoded.size()) + 16); + + // Name reports the *plaintext* size so it lines up with the encode bench. + std::string name = "decode " + std::to_string(sz) + "B (" + std::to_string(encoded.size()) + "B b64)"; + BENCHMARK(name.c_str()) + { + size_t out_len = 0; + bool ok = ats_base64_decode(encoded.data(), encoded.size(), output.data(), output.size(), &out_len); + return ok ? out_len : size_t{0}; + }; + } +} + +TEST_CASE("memcpy_tolower throughput", "[bench][tolower]") +{ + // Sizes chosen to model URL paths / header names / cache-key segments. + constexpr std::array kTolowerSizes{16, 64, 256, 1024}; + + for (size_t sz : kTolowerSizes) { + auto input = make_mixed_case_ascii(sz); + std::vector output(sz); + + std::string name = "tolower " + std::to_string(sz) + "B"; + BENCHMARK(name.c_str()) + { + memcpy_tolower_scalar(output.data(), input.data(), static_cast(sz)); + return output[0]; + }; + } +}