Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/lean_spec/node/networking/varint.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,30 +114,40 @@ class VarintError(Exception):
"""Raised when varint encoding or decoding fails."""


def encode_varint(value: int) -> bytes:
def encode_varint(value: int, max_bytes: int = 10) -> bytes:
"""
Encode an unsigned integer as LEB128 varint.

Splits the integer into 7-bit groups, emitting each as one byte.
All bytes except the last have the continuation bit (0x80) set.

Args:
value: Non-negative integer to encode. Maximum: 2^64 - 1.
value: Non-negative integer to encode.
max_bytes: Upper bound on the encoded byte count.
Defaults to 10, which is the cap for a 64-bit value.
Pass 5 for a 32-bit cap, matching the snappy length prefix.

Returns:
Varint-encoded bytes. Length depends on value:

- 0-127: 1 byte
- 128-16383: 2 bytes
- 16384-2097151: 3 bytes
- Up to 10 bytes for 64-bit values
- Up to max_bytes for the largest representable value

Raises:
ValueError: If value is negative.
ValueError: If value is negative or does not fit in max_bytes.
"""
if value < 0:
raise ValueError("Varint must be non-negative")

# Reject values that would need more than max_bytes to encode.
#
# Each output byte carries 7 data bits.
# Anything that does not fit in max_bytes * 7 bits is rejected here.
if value >> (7 * max_bytes):
raise ValueError(f"Varint value does not fit in {max_bytes} bytes")

result = bytearray()

# Process 7 bits at a time until the value fits in 7 bits.
Expand All @@ -160,7 +170,7 @@ def encode_varint(value: int) -> bytes:
return bytes(result)


def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
def decode_varint(data: bytes, offset: int = 0, max_bytes: int = 10) -> tuple[int, int]:
"""
Decode a varint from bytes at the given offset.

Expand All @@ -170,6 +180,9 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
Args:
data: Input bytes containing the varint.
offset: Starting position in data. Defaults to 0.
max_bytes: Upper bound on the encoded byte count.
Defaults to 10, which is the cap for a 64-bit value.
Pass 5 for a 32-bit cap, matching the snappy length prefix.

Returns:
Tuple of (decoded_value, bytes_consumed).
Expand All @@ -179,8 +192,8 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:

Raises:
VarintError: If the input is truncated (runs out of bytes
before finding the final byte) or exceeds 10 bytes
(would overflow 64 bits).
before finding the final byte) or exceeds max_bytes
(would overflow the declared range).
"""
result = 0
shift = 0
Expand Down Expand Up @@ -212,10 +225,10 @@ def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:

# Guard against malformed input that never terminates.
#
# A 64-bit value needs at most 10 bytes (70 bits, with 6 unused).
# If we've shifted 70+ bits and still see continuation, the input
# is invalid or represents a value larger than we can handle.
if shift >= 70:
raise VarintError("Varint too long")
# A varint capped at max_bytes carries at most max_bytes * 7 bits.
# Once we have already consumed that many bytes and still see
# a continuation bit, the input is invalid or out of range.
if pos - offset >= max_bytes:
raise VarintError(f"Varint exceeds {max_bytes} bytes")

return result, pos - offset
9 changes: 6 additions & 3 deletions src/lean_spec/node/snappy/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@

from __future__ import annotations

from lean_spec.node.networking.varint import encode_varint

from .constants import (
BLOCK_SIZE,
HASH_MULTIPLIER,
INPUT_MARGIN_BYTES,
MAX_HASH_TABLE_BITS,
MIN_HASH_TABLE_BITS,
SNAPPY_VARINT_MAX_BYTES,
)
from .encoding import encode_copy_tag, encode_literal_tag, encode_varint32
from .encoding import encode_copy_tag, encode_literal_tag


def compress(data: bytes) -> bytes:
Expand All @@ -91,13 +94,13 @@ def compress(data: bytes) -> bytes:
#
# Even empty data needs a length prefix (varint 0).
if not data:
return encode_varint32(0)
return encode_varint(0, max_bytes=SNAPPY_VARINT_MAX_BYTES)

# Build output buffer.
#
# Start with the uncompressed length as a varint.
# The decompressor reads this first to allocate the output buffer.
output = bytearray(encode_varint32(len(data)))
output = bytearray(encode_varint(len(data), max_bytes=SNAPPY_VARINT_MAX_BYTES))

# Process input in blocks.
#
Expand Down
14 changes: 8 additions & 6 deletions src/lean_spec/node/snappy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,13 @@

#
# The uncompressed length is encoded as a varint at the start of the
# compressed data. Varints use 7 bits per byte, with the high bit
# indicating continuation.
# compressed data. The shared LEB128 codec from the networking layer
# handles encoding and decoding. The cap below bounds the prefix length
# to the 32-bit range defined by the Snappy format.

VARINT_CONTINUATION_BIT: Final = 0x80
"""High bit set in varint bytes to indicate more bytes follow."""
SNAPPY_VARINT_MAX_BYTES: Final = 5
"""Maximum byte count for the uncompressed length prefix.

VARINT_DATA_MASK: Final = 0x7F
"""Mask to extract the 7 data bits from a varint byte."""
Five bytes carry thirty-five data bits.
This is the smallest LEB128 length that covers the full 32-bit range
used by the Snappy format for the uncompressed payload size."""
11 changes: 8 additions & 3 deletions src/lean_spec/node/snappy/decompress.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@

from __future__ import annotations

from .encoding import decode_tag, decode_varint32
from lean_spec.node.networking.varint import VarintError, decode_varint

from .constants import SNAPPY_VARINT_MAX_BYTES
from .encoding import decode_tag


class SnappyDecompressionError(Exception):
Expand Down Expand Up @@ -100,8 +103,10 @@ def decompress(data: bytes) -> bytes:
#
# Example: data = [0x08, ...] -> length = 8
try:
uncompressed_length, varint_bytes = decode_varint32(data, 0)
except ValueError as e:
uncompressed_length, varint_bytes = decode_varint(
data, 0, max_bytes=SNAPPY_VARINT_MAX_BYTES
)
except VarintError as e:
raise SnappyDecompressionError(f"Invalid length varint: {e}") from e

# Length = 0 is valid: the original data was empty.
Expand Down
159 changes: 5 additions & 154 deletions src/lean_spec/node/snappy/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
This module provides the low-level encoding and decoding primitives used by
both the compressor and decompressor:

1. **Varint encoding**: Variable-length integer encoding for the uncompressed
length prefix. Small values use fewer bytes, saving space.

2. **Tag byte encoding**: Compact representation of literal and copy operations.
1. **Tag byte encoding**: Compact representation of literal and copy operations.
The tag byte format packs operation type and length into minimal space.

The uncompressed length prefix uses the shared LEB128 varint codec from the
networking layer with a five-byte cap, matching the 32-bit range used by the
Snappy format specification.

Reference: https://github.com/google/snappy/blob/main/format_description.txt
"""

Expand All @@ -31,161 +32,11 @@
MAX_COPY_2_OFFSET,
MAX_INLINE_LITERAL_LENGTH,
MIN_COPY_1_LENGTH,
VARINT_CONTINUATION_BIT,
VARINT_DATA_MASK,
)

type TagType = Literal["literal", "copy"]
"""Snappy tag type: either a literal (raw bytes) or a copy (back-reference)."""

# Varint Encoding
#
# Varints encode integers using as few bytes as possible.
# - Small values use fewer bytes.
# - Large values use more.
#
# Each byte has 8 bits:
# - Bit 7 (high): continuation flag.
# - 1 = more bytes follow,
# - 0 = this is the last byte.
# - Bits 0-6 (low): 7 bits of the integer value.
#
# Bytes are emitted least-significant chunk first.
#
# Byte count by value:
# 0 .. 127 -> 1 byte
# 128 .. 16,383 -> 2 bytes
# 16,384 .. 2,097,151 -> 3 bytes
# 2,097,152 .. 268,435,455 -> 4 bytes
# 268,435,456 .. 2^32 - 1 -> 5 bytes
#
# Example: encoding 300
#
# 300 in binary: 100101100 (9 bits, needs 2 chunks of 7 bits)
#
# Chunk 1 (bits 0-6): 0101100 = 44. More bits remain, so continuation = 1.
# Byte 1 = 0x80 | 44 = 0xAC
#
# Chunk 2 (bits 7+): 0000010 = 2. No more bits, so continuation = 0.
# Byte 2 = 0x00 | 2 = 0x02
#
# Encoded: [0xAC, 0x02]
#
# Example: decoding [0xAC, 0x02]
#
# For each byte: check bit 7 for continuation, mask with 0x7F to get data.
#
# Byte 1 = 0xAC = 10101100:
# bit 7 = 1 -> more bytes coming
# data = 0xAC & 0x7F = 0101100 = 44
# result = 44
#
# Byte 2 = 0x02 = 00000010:
# bit 7 = 0 -> done
# data = 0x02 & 0x7F = 0000010 = 2 (mask has no effect here)
# result = 44 | (2 << 7) = 44 + 256 = 300


def encode_varint32(value: int) -> bytes:
"""Encode a 32-bit integer as a variable-length byte sequence.

The varint format uses 7 bits per byte for data, with the high bit
indicating whether more bytes follow. This efficiently encodes small
values in fewer bytes.

Algorithm:
1. Take the lowest 7 bits of the value.
2. If more bits remain, set the continuation bit (0x80).
3. Repeat until all bits are encoded.

Args:
value: Non-negative integer to encode (must fit in 32 bits).

Returns:
Variable-length bytes encoding the integer (1-5 bytes).

Raises:
ValueError: If value is negative or exceeds 32 bits.
"""
# Validate input range.
# Varints in Snappy are unsigned 32-bit integers.
if value < 0:
raise ValueError(f"Varint value must be non-negative, got {value}")
if value > 0xFFFFFFFF:
raise ValueError(f"Varint value exceeds 32 bits: {value}")

# Build the encoding byte by byte.
# We accumulate bytes in a list for efficiency.
result: list[int] = []

while True:
# Extract the lowest 7 bits.
byte = value & VARINT_DATA_MASK

# Shift out the bits we just encoded.
value >>= 7

if value != 0:
# More bits remain: set continuation bit.
byte |= VARINT_CONTINUATION_BIT

result.append(byte)

if value == 0:
# All bits encoded.
break

return bytes(result)


def decode_varint32(data: bytes, offset: int = 0) -> tuple[int, int]:
"""Decode a varint from a byte sequence at the given offset.

Reads bytes starting at offset, accumulating 7 bits per byte into
the result. Stops when a byte without the continuation bit is found.

Args:
data: Byte sequence containing the varint.
offset: Position in data where the varint starts.

Returns:
Tuple of (decoded_value, bytes_consumed).

Raises:
ValueError: If the varint is malformed (too long or truncated).
"""
result = 0
shift = 0
bytes_read = 0

while True:
# Check bounds.
if offset + bytes_read >= len(data):
raise ValueError("Truncated varint: unexpected end of data")

# Read next byte.
byte = data[offset + bytes_read]
bytes_read += 1

# Accumulate the 7 data bits at the current shift position.
result |= (byte & VARINT_DATA_MASK) << shift
shift += 7

# Check if this is the last byte (no continuation bit).
if (byte & VARINT_CONTINUATION_BIT) == 0:
break

# Safety check: varints should not exceed 5 bytes for 32-bit values.
# (5 bytes * 7 bits = 35 bits, which covers 32-bit range)
if bytes_read >= 5:
raise ValueError("Varint too long: exceeds 5 bytes")

# Verify the result fits in 32 bits.
if result > 0xFFFFFFFF:
raise ValueError(f"Varint overflow: {result} exceeds 32 bits")

return result, bytes_read


# Tag Byte Encoding - Literals
#
Expand Down
2 changes: 1 addition & 1 deletion tests/lean_spec/node/networking/test_reqresp.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_varint_11_bytes_rejected(self) -> None:
malformed = bytes([0x80] * 10 + [0x01])
assert len(malformed) == 11

with pytest.raises(VarintError, match="too long"):
with pytest.raises(VarintError, match="exceeds 10 bytes"):
decode_varint(malformed)

def test_payload_at_max_size(self) -> None:
Expand Down
Loading
Loading