From b9e8516656dcf10ca6618ebd7bf64d0870366036 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Fri, 29 May 2026 16:41:19 +0200 Subject: [PATCH] refactor(varint): parametrize the canonical LEB128 codec with max_bytes The networking and snappy modules each carried a self-contained LEB128 implementation. The two algorithms were identical, differing only in byte cap (10 vs 5). One implementation was a slow drift away from the other waiting to happen. Keep the networking varint as the single source of truth and give both encode and decode a max_bytes parameter (default 10 = uint64 cap). Snappy now imports the canonical codec and passes 5 via a new SNAPPY_VARINT_MAX_BYTES constant. Encode now also enforces the cap, so per-cap semantics are symmetric. The snappy length-prefix tests run against the canonical codec, plus an integration test that asserts an oversize prefix surfaces as a SnappyDecompressionError. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lean_spec/node/networking/varint.py | 37 ++-- src/lean_spec/node/snappy/compress.py | 9 +- src/lean_spec/node/snappy/constants.py | 14 +- src/lean_spec/node/snappy/decompress.py | 11 +- src/lean_spec/node/snappy/encoding.py | 159 +----------------- .../lean_spec/node/networking/test_reqresp.py | 2 +- .../lean_spec/node/networking/test_varint.py | 43 ++++- tests/lean_spec/node/snappy/test_snappy.py | 52 +++--- 8 files changed, 126 insertions(+), 201 deletions(-) diff --git a/src/lean_spec/node/networking/varint.py b/src/lean_spec/node/networking/varint.py index 4e419b40a..8624e3b8c 100644 --- a/src/lean_spec/node/networking/varint.py +++ b/src/lean_spec/node/networking/varint.py @@ -114,7 +114,7 @@ class VarintError(Exception): """Raised when varint encoding or decoding fails.""" -def encode_varint(value: int) -> bytes: +def encode_varint(value: int, max_bytes: int = 10) -> bytes: """ Encode an unsigned integer as LEB128 varint. 
@@ -122,7 +122,10 @@ def encode_varint(value: int) -> bytes: All bytes except the last have the continuation bit (0x80) set. Args: - value: Non-negative integer to encode. Maximum: 2^64 - 1. + value: Non-negative integer to encode. + max_bytes: Upper bound on the encoded byte count. + Defaults to 10, which is the cap for a 64-bit value. + Pass 5 for a 35-bit cap, the smallest that covers the 32-bit snappy length prefix. Returns: Varint-encoded bytes. Length depends on value: @@ -130,14 +133,21 @@ def encode_varint(value: int) -> bytes: - 0-127: 1 byte - 128-16383: 2 bytes - 16384-2097151: 3 bytes - - Up to 10 bytes for 64-bit values + - Up to max_bytes for the largest representable value Raises: - ValueError: If value is negative. + ValueError: If value is negative or does not fit in max_bytes. """ if value < 0: raise ValueError("Varint must be non-negative") + # Reject values that would need more than max_bytes to encode. + # + # Each output byte carries 7 data bits. + # Anything that does not fit in max_bytes * 7 bits is rejected here. + if value >> (7 * max_bytes): + raise ValueError(f"Varint value does not fit in {max_bytes} bytes") + result = bytearray() # Process 7 bits at a time until the value fits in 7 bits. @@ -160,7 +170,7 @@ def encode_varint(value: int) -> bytes: return bytes(result) -def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: +def decode_varint(data: bytes, offset: int = 0, max_bytes: int = 10) -> tuple[int, int]: """ Decode a varint from bytes at the given offset. @@ -170,6 +180,9 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: Args: data: Input bytes containing the varint. offset: Starting position in data. Defaults to 0. + max_bytes: Upper bound on the encoded byte count. + Defaults to 10, which is the cap for a 64-bit value. + Pass 5 for a 35-bit cap, the smallest that covers the 32-bit snappy length prefix. Returns: Tuple of (decoded_value, bytes_consumed). 
@@ -179,8 +192,8 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: Raises: VarintError: If the input is truncated (runs out of bytes - before finding the final byte) or exceeds 10 bytes - (would overflow 64 bits). + before finding the final byte) or exceeds max_bytes + (would overflow the declared range). """ result = 0 shift = 0 @@ -212,10 +225,10 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: # Guard against malformed input that never terminates. # - # A 64-bit value needs at most 10 bytes (70 bits, with 6 unused). - # If we've shifted 70+ bits and still see continuation, the input - # is invalid or represents a value larger than we can handle. - if shift >= 70: - raise VarintError("Varint too long") + # A varint capped at max_bytes carries at most max_bytes * 7 bits. + # Once we have already consumed that many bytes and still see + # a continuation bit, the input is invalid or out of range. + if pos - offset >= max_bytes: + raise VarintError(f"Varint exceeds {max_bytes} bytes") return result, pos - offset diff --git a/src/lean_spec/node/snappy/compress.py b/src/lean_spec/node/snappy/compress.py index 4d7e2fa64..d77324bfc 100644 --- a/src/lean_spec/node/snappy/compress.py +++ b/src/lean_spec/node/snappy/compress.py @@ -63,14 +63,17 @@ from __future__ import annotations +from lean_spec.node.networking.varint import encode_varint + from .constants import ( BLOCK_SIZE, HASH_MULTIPLIER, INPUT_MARGIN_BYTES, MAX_HASH_TABLE_BITS, MIN_HASH_TABLE_BITS, + SNAPPY_VARINT_MAX_BYTES, ) -from .encoding import encode_copy_tag, encode_literal_tag, encode_varint32 +from .encoding import encode_copy_tag, encode_literal_tag def compress(data: bytes) -> bytes: @@ -91,13 +94,13 @@ def compress(data: bytes) -> bytes: # # Even empty data needs a length prefix (varint 0). if not data: - return encode_varint32(0) + return encode_varint(0, max_bytes=SNAPPY_VARINT_MAX_BYTES) # Build output buffer. 
# # Start with the uncompressed length as a varint. # The decompressor reads this first to allocate the output buffer. - output = bytearray(encode_varint32(len(data))) + output = bytearray(encode_varint(len(data), max_bytes=SNAPPY_VARINT_MAX_BYTES)) # Process input in blocks. # diff --git a/src/lean_spec/node/snappy/constants.py b/src/lean_spec/node/snappy/constants.py index 2f130d445..f8b0d7a21 100644 --- a/src/lean_spec/node/snappy/constants.py +++ b/src/lean_spec/node/snappy/constants.py @@ -146,11 +146,13 @@ # # The uncompressed length is encoded as a varint at the start of the -# compressed data. Varints use 7 bits per byte, with the high bit -# indicating continuation. +# compressed data. The shared LEB128 codec from the networking layer +# handles encoding and decoding. The cap below bounds the prefix length +# to five bytes, enough for the 32-bit range defined by the Snappy format. -VARINT_CONTINUATION_BIT: Final = 0x80 -"""High bit set in varint bytes to indicate more bytes follow.""" +SNAPPY_VARINT_MAX_BYTES: Final = 5 +"""Maximum byte count for the uncompressed length prefix. -VARINT_DATA_MASK: Final = 0x7F -"""Mask to extract the 7 data bits from a varint byte.""" +Five bytes carry thirty-five data bits. +This is the smallest LEB128 length that covers the full 32-bit range +used by the Snappy format for the uncompressed payload size.""" diff --git a/src/lean_spec/node/snappy/decompress.py b/src/lean_spec/node/snappy/decompress.py index 413e344bd..4442e62b5 100644 --- a/src/lean_spec/node/snappy/decompress.py +++ b/src/lean_spec/node/snappy/decompress.py @@ -71,7 +71,10 @@ from __future__ import annotations -from .encoding import decode_tag, decode_varint32 +from lean_spec.node.networking.varint import VarintError, decode_varint + +from .constants import SNAPPY_VARINT_MAX_BYTES +from .encoding import decode_tag class SnappyDecompressionError(Exception): @@ -100,8 +103,10 @@ def decompress(data: bytes) -> bytes: # # Example: data = [0x08, ...] 
-> length = 8 try: - uncompressed_length, varint_bytes = decode_varint32(data, 0) - except ValueError as e: + uncompressed_length, varint_bytes = decode_varint( + data, 0, max_bytes=SNAPPY_VARINT_MAX_BYTES + ) + except VarintError as e: raise SnappyDecompressionError(f"Invalid length varint: {e}") from e # Length = 0 is valid: the original data was empty. diff --git a/src/lean_spec/node/snappy/encoding.py b/src/lean_spec/node/snappy/encoding.py index 9e14bbf88..917292d9a 100644 --- a/src/lean_spec/node/snappy/encoding.py +++ b/src/lean_spec/node/snappy/encoding.py @@ -4,12 +4,13 @@ This module provides the low-level encoding and decoding primitives used by both the compressor and decompressor: -1. **Varint encoding**: Variable-length integer encoding for the uncompressed - length prefix. Small values use fewer bytes, saving space. - -2. **Tag byte encoding**: Compact representation of literal and copy operations. +1. **Tag byte encoding**: Compact representation of literal and copy operations. The tag byte format packs operation type and length into minimal space. +The uncompressed length prefix uses the shared LEB128 varint codec from the +networking layer with a five-byte cap, matching the 32-bit range used by the +Snappy format specification. + Reference: https://github.com/google/snappy/blob/main/format_description.txt """ @@ -31,161 +32,11 @@ MAX_COPY_2_OFFSET, MAX_INLINE_LITERAL_LENGTH, MIN_COPY_1_LENGTH, - VARINT_CONTINUATION_BIT, - VARINT_DATA_MASK, ) type TagType = Literal["literal", "copy"] """Snappy tag type: either a literal (raw bytes) or a copy (back-reference).""" -# Varint Encoding -# -# Varints encode integers using as few bytes as possible. -# - Small values use fewer bytes. -# - Large values use more. -# -# Each byte has 8 bits: -# - Bit 7 (high): continuation flag. -# - 1 = more bytes follow, -# - 0 = this is the last byte. -# - Bits 0-6 (low): 7 bits of the integer value. -# -# Bytes are emitted least-significant chunk first. 
-# -# Byte count by value: -# 0 .. 127 -> 1 byte -# 128 .. 16,383 -> 2 bytes -# 16,384 .. 2,097,151 -> 3 bytes -# 2,097,152 .. 268,435,455 -> 4 bytes -# 268,435,456 .. 2^32 - 1 -> 5 bytes -# -# Example: encoding 300 -# -# 300 in binary: 100101100 (9 bits, needs 2 chunks of 7 bits) -# -# Chunk 1 (bits 0-6): 0101100 = 44. More bits remain, so continuation = 1. -# Byte 1 = 0x80 | 44 = 0xAC -# -# Chunk 2 (bits 7+): 0000010 = 2. No more bits, so continuation = 0. -# Byte 2 = 0x00 | 2 = 0x02 -# -# Encoded: [0xAC, 0x02] -# -# Example: decoding [0xAC, 0x02] -# -# For each byte: check bit 7 for continuation, mask with 0x7F to get data. -# -# Byte 1 = 0xAC = 10101100: -# bit 7 = 1 -> more bytes coming -# data = 0xAC & 0x7F = 0101100 = 44 -# result = 44 -# -# Byte 2 = 0x02 = 00000010: -# bit 7 = 0 -> done -# data = 0x02 & 0x7F = 0000010 = 2 (mask has no effect here) -# result = 44 | (2 << 7) = 44 + 256 = 300 - - -def encode_varint32(value: int) -> bytes: - """Encode a 32-bit integer as a variable-length byte sequence. - - The varint format uses 7 bits per byte for data, with the high bit - indicating whether more bytes follow. This efficiently encodes small - values in fewer bytes. - - Algorithm: - 1. Take the lowest 7 bits of the value. - 2. If more bits remain, set the continuation bit (0x80). - 3. Repeat until all bits are encoded. - - Args: - value: Non-negative integer to encode (must fit in 32 bits). - - Returns: - Variable-length bytes encoding the integer (1-5 bytes). - - Raises: - ValueError: If value is negative or exceeds 32 bits. - """ - # Validate input range. - # Varints in Snappy are unsigned 32-bit integers. - if value < 0: - raise ValueError(f"Varint value must be non-negative, got {value}") - if value > 0xFFFFFFFF: - raise ValueError(f"Varint value exceeds 32 bits: {value}") - - # Build the encoding byte by byte. - # We accumulate bytes in a list for efficiency. - result: list[int] = [] - - while True: - # Extract the lowest 7 bits. 
- byte = value & VARINT_DATA_MASK - - # Shift out the bits we just encoded. - value >>= 7 - - if value != 0: - # More bits remain: set continuation bit. - byte |= VARINT_CONTINUATION_BIT - - result.append(byte) - - if value == 0: - # All bits encoded. - break - - return bytes(result) - - -def decode_varint32(data: bytes, offset: int = 0) -> tuple[int, int]: - """Decode a varint from a byte sequence at the given offset. - - Reads bytes starting at offset, accumulating 7 bits per byte into - the result. Stops when a byte without the continuation bit is found. - - Args: - data: Byte sequence containing the varint. - offset: Position in data where the varint starts. - - Returns: - Tuple of (decoded_value, bytes_consumed). - - Raises: - ValueError: If the varint is malformed (too long or truncated). - """ - result = 0 - shift = 0 - bytes_read = 0 - - while True: - # Check bounds. - if offset + bytes_read >= len(data): - raise ValueError("Truncated varint: unexpected end of data") - - # Read next byte. - byte = data[offset + bytes_read] - bytes_read += 1 - - # Accumulate the 7 data bits at the current shift position. - result |= (byte & VARINT_DATA_MASK) << shift - shift += 7 - - # Check if this is the last byte (no continuation bit). - if (byte & VARINT_CONTINUATION_BIT) == 0: - break - - # Safety check: varints should not exceed 5 bytes for 32-bit values. - # (5 bytes * 7 bits = 35 bits, which covers 32-bit range) - if bytes_read >= 5: - raise ValueError("Varint too long: exceeds 5 bytes") - - # Verify the result fits in 32 bits. 
- if result > 0xFFFFFFFF: - raise ValueError(f"Varint overflow: {result} exceeds 32 bits") - - return result, bytes_read - # Tag Byte Encoding - Literals # diff --git a/tests/lean_spec/node/networking/test_reqresp.py b/tests/lean_spec/node/networking/test_reqresp.py index 0f956640c..f40edbf80 100644 --- a/tests/lean_spec/node/networking/test_reqresp.py +++ b/tests/lean_spec/node/networking/test_reqresp.py @@ -204,7 +204,7 @@ def test_varint_11_bytes_rejected(self) -> None: malformed = bytes([0x80] * 10 + [0x01]) assert len(malformed) == 11 - with pytest.raises(VarintError, match="too long"): + with pytest.raises(VarintError, match="exceeds 10 bytes"): decode_varint(malformed) def test_payload_at_max_size(self) -> None: diff --git a/tests/lean_spec/node/networking/test_varint.py b/tests/lean_spec/node/networking/test_varint.py index c45aaa7a5..485ea8a05 100644 --- a/tests/lean_spec/node/networking/test_varint.py +++ b/tests/lean_spec/node/networking/test_varint.py @@ -74,7 +74,7 @@ def test_empty_raises(self) -> None: def test_too_long_raises(self) -> None: """More than 10 continuation bytes (>64-bit) raises.""" - with pytest.raises(VarintError, match="too long"): + with pytest.raises(VarintError, match="exceeds 10 bytes"): decode_varint(b"\x80" * 11, 0) @@ -120,3 +120,44 @@ def test_large_values(self, value: int) -> None: """Large multi-byte values roundtrip correctly.""" encoded = encode_varint(value) assert decode_varint(encoded, 0) == (value, len(encoded)) + + +class TestMaxBytesParameter: + """Tests for the max_bytes cap shared by both encode and decode.""" + + @pytest.mark.parametrize( + ("value", "byte_count"), + [ + (0, 1), + (127, 1), + (128, 2), + (16383, 2), + (16384, 3), + (2**28 - 1, 4), + (2**28, 5), + (2**35 - 1, 5), + ], + ) + def test_five_byte_cap_accepts_values_up_to_five_bytes( + self, value: int, byte_count: int + ) -> None: + """A five-byte cap fits values that encode in five or fewer bytes.""" + encoded = encode_varint(value, max_bytes=5) + 
assert len(encoded) == byte_count + assert decode_varint(encoded, 0, max_bytes=5) == (value, byte_count) + + def test_five_byte_cap_rejects_value_needing_six_bytes(self) -> None: + """A value past the five-byte ceiling is rejected on encode.""" + with pytest.raises(ValueError, match="does not fit in 5 bytes"): + encode_varint(2**35, max_bytes=5) + + def test_five_byte_cap_rejects_six_byte_input(self) -> None: + """A six-byte continuation run is rejected on decode.""" + with pytest.raises(VarintError, match="exceeds 5 bytes"): + decode_varint(b"\x80" * 6, 0, max_bytes=5) + + def test_five_byte_cap_accepts_five_bytes_at_boundary(self) -> None: + """Four continuation bytes followed by a terminator decode successfully.""" + encoded = bytes([0x80, 0x80, 0x80, 0x80, 0x01]) + value, consumed = decode_varint(encoded, 0, max_bytes=5) + assert (value, consumed) == (1 << 28, 5) diff --git a/tests/lean_spec/node/snappy/test_snappy.py b/tests/lean_spec/node/snappy/test_snappy.py index 53e5e7931..e3a2620b8 100644 --- a/tests/lean_spec/node/snappy/test_snappy.py +++ b/tests/lean_spec/node/snappy/test_snappy.py @@ -8,18 +8,18 @@ import pytest +from lean_spec.node.networking.varint import VarintError, decode_varint, encode_varint from lean_spec.node.snappy import ( SnappyDecompressionError, compress, decompress, max_compressed_length, ) +from lean_spec.node.snappy.constants import SNAPPY_VARINT_MAX_BYTES from lean_spec.node.snappy.encoding import ( decode_tag, - decode_varint32, encode_copy_tag, encode_literal_tag, - encode_varint32, ) # Path to test data files @@ -58,61 +58,71 @@ def iter_test_files() -> Iterator[tuple[str, bytes]]: yield label, load_test_file(filename, size_limit) -class TestVarintEncoding: - """Tests for varint encoding/decoding.""" +class TestSnappyLengthPrefix: + """Tests for the LEB128 length prefix used by the snappy block format.""" def test_encode_zero(self) -> None: """Zero encodes to a single null byte.""" - assert encode_varint32(0) == b"\x00" + assert 
encode_varint(0, max_bytes=SNAPPY_VARINT_MAX_BYTES) == b"\x00" def test_encode_small_values(self) -> None: """Values 0-127 encode to a single byte.""" - assert encode_varint32(1) == b"\x01" - assert encode_varint32(127) == b"\x7f" + assert encode_varint(1, max_bytes=SNAPPY_VARINT_MAX_BYTES) == b"\x01" + assert encode_varint(127, max_bytes=SNAPPY_VARINT_MAX_BYTES) == b"\x7f" def test_encode_two_byte_values(self) -> None: """Values 128-16383 encode to two bytes.""" - assert encode_varint32(128) == b"\x80\x01" - assert encode_varint32(300) == b"\xac\x02" + assert encode_varint(128, max_bytes=SNAPPY_VARINT_MAX_BYTES) == b"\x80\x01" + assert encode_varint(300, max_bytes=SNAPPY_VARINT_MAX_BYTES) == b"\xac\x02" def test_encode_large_values(self) -> None: """Large values encode correctly.""" for value in [65536, 2**20, 2**24, 2**32 - 1]: - encoded = encode_varint32(value) - decoded, _ = decode_varint32(encoded) + encoded = encode_varint(value, max_bytes=SNAPPY_VARINT_MAX_BYTES) + decoded, _ = decode_varint(encoded, max_bytes=SNAPPY_VARINT_MAX_BYTES) assert decoded == value def test_decode_roundtrip(self) -> None: """Encoding then decoding returns the original value.""" test_values = [0, 1, 127, 128, 255, 256, 16383, 16384, 65535, 65536, 2**20, 2**32 - 1] for value in test_values: - encoded = encode_varint32(value) - decoded, bytes_consumed = decode_varint32(encoded) + encoded = encode_varint(value, max_bytes=SNAPPY_VARINT_MAX_BYTES) + decoded, bytes_consumed = decode_varint(encoded, max_bytes=SNAPPY_VARINT_MAX_BYTES) assert decoded == value assert bytes_consumed == len(encoded) def test_encode_negative_raises(self) -> None: """Negative values raise ValueError.""" with pytest.raises(ValueError, match="non-negative"): - encode_varint32(-1) + encode_varint(-1, max_bytes=SNAPPY_VARINT_MAX_BYTES) def test_encode_overflow_raises(self) -> None: - """Values exceeding 32 bits raise ValueError.""" - with pytest.raises(ValueError, match="32 bits"): - encode_varint32(2**32) + 
"""Values past the five-byte cap raise ValueError.""" + with pytest.raises(ValueError, match="does not fit in 5 bytes"): + encode_varint(2**35, max_bytes=SNAPPY_VARINT_MAX_BYTES) def test_decode_truncated_raises(self) -> None: - """Truncated varints raise ValueError.""" - with pytest.raises(ValueError, match="Truncated"): - decode_varint32(b"\x80") + """Truncated varints raise a varint error.""" + with pytest.raises(VarintError, match="Truncated"): + decode_varint(b"\x80", max_bytes=SNAPPY_VARINT_MAX_BYTES) + + def test_decode_too_long_raises(self) -> None: + """A six-byte continuation run exceeds the snappy cap.""" + with pytest.raises(VarintError, match="exceeds 5 bytes"): + decode_varint(b"\x80" * 6, max_bytes=SNAPPY_VARINT_MAX_BYTES) def test_decode_with_offset(self) -> None: """Decoding at an offset works correctly.""" data = b"prefix\xac\x02suffix" - value, consumed = decode_varint32(data, offset=6) + value, consumed = decode_varint(data, offset=6, max_bytes=SNAPPY_VARINT_MAX_BYTES) assert value == 300 assert consumed == 2 + def test_decompress_wraps_oversize_prefix(self) -> None: + """A length prefix that overruns the cap surfaces as a decompression error.""" + with pytest.raises(SnappyDecompressionError, match="Invalid length varint"): + decompress(b"\x80" * 6) + class TestTagEncoding: """Tests for literal and copy tag encoding/decoding."""