Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
from logging import getLogger
from re import split
from typing import Iterable, List, Mapping, Optional, Set
from typing import Generator, Iterable, Mapping, Optional, Set
from urllib.parse import quote_plus, unquote_plus

from opentelemetry.baggage import _is_valid_pair, get_all, set_baggage
Expand All @@ -26,6 +26,50 @@
_logger = getLogger(__name__)


def _apply_baggage_limits(
entries: Iterable[str],
max_pairs: int,
max_pair_length: int,
max_header_length: int,
) -> Generator[str, None, None]:
"""Apply W3C Baggage size limits to a sequence of baggage entries.

Yields entries that fit within the W3C specification limits.
Logs warnings when entries are dropped.
"""
count = 0
total_length = 0

for entry in entries:
if not entry:
continue

if len(entry) > max_pair_length:
_logger.warning(
"Baggage entry `%s` exceeded the maximum number of bytes per list-member",
entry,
)
continue

if count >= max_pairs:
_logger.warning(
"Baggage exceeded the maximum number of list-members",
)
break

# Account for comma separator between entries
added_length = len(entry) + (1 if count > 0 else 0)
if total_length + added_length > max_header_length:
_logger.warning(
"Baggage exceeded the maximum number of bytes per baggage-string",
)
break

count += 1
total_length += added_length
yield entry


class W3CBaggagePropagator(textmap.TextMapPropagator):
"""Extracts and injects Baggage which is used to annotate telemetry."""

Expand Down Expand Up @@ -63,24 +107,20 @@ def extract(
)
return context

baggage_entries: List[str] = split(_DELIMITER_PATTERN, header)
total_baggage_entries = self._MAX_PAIRS
baggage_entries = split(_DELIMITER_PATTERN, header)

if len(baggage_entries) > self._MAX_PAIRS:
_logger.warning(
"Baggage header `%s` exceeded the maximum number of list-members",
header,
)

for entry in baggage_entries:
if len(entry) > self._MAX_PAIR_LENGTH:
_logger.warning(
"Baggage entry `%s` exceeded the maximum number of bytes per list-member",
entry,
)
continue
if not entry: # empty string
continue
for entry in _apply_baggage_limits(
baggage_entries,
max_pairs=self._MAX_PAIRS,
max_pair_length=self._MAX_PAIR_LENGTH,
max_header_length=self._MAX_HEADER_LENGTH,
):
try:
name, value = entry.split("=", 1)
except Exception: # pylint: disable=broad-exception-caught
Expand All @@ -101,9 +141,6 @@ def extract(
value,
context=context,
)
total_baggage_entries -= 1
if total_baggage_entries == 0:
break

return context

Expand All @@ -122,8 +159,17 @@ def inject(
if not baggage_entries:
return

baggage_string = _format_baggage(baggage_entries)
setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string)
baggage_string = ",".join(
_apply_baggage_limits(
_encode_baggage_pairs(baggage_entries),
max_pairs=self._MAX_PAIRS,
max_pair_length=self._MAX_PAIR_LENGTH,
max_header_length=self._MAX_HEADER_LENGTH,
)
)

if baggage_string:
setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string)

@property
def fields(self) -> Set[str]:
Expand All @@ -132,10 +178,15 @@ def fields(self) -> Set[str]:


def _format_baggage(baggage_entries: Mapping[str, object]) -> str:
return ",".join(
quote_plus(str(key)) + "=" + quote_plus(str(value))
for key, value in baggage_entries.items()
)
return ",".join(_encode_baggage_pairs(baggage_entries))


def _encode_baggage_pairs(
baggage_entries: Mapping[str, object],
) -> Generator[str, None, None]:
"""Yield URL-encoded 'key=value' pairs from baggage entries."""
for key, value in baggage_entries.items():
yield quote_plus(str(key)) + "=" + quote_plus(str(value))


def _extract_first_element(
Expand Down
114 changes: 101 additions & 13 deletions opentelemetry-api/tests/propagators/test_w3cbaggagepropagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,21 @@ def test_header_too_long(self):
long_value = "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1)
header = f"key1={long_value}"
expected = {}
self.assertEqual(self._extract(header), expected)
with self.assertLogs(level=WARNING) as warning:
self.assertEqual(self._extract(header), expected)
self.assertIn(
"exceeded the maximum number of bytes per baggage-string",
warning.output[0],
)

def test_header_contains_too_many_entries(self):
header = ",".join(
[f"key{k}=val" for k in range(W3CBaggagePropagator._MAX_PAIRS + 1)]
)
self.assertEqual(
len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS
)
with self.assertLogs(level=WARNING):
self.assertEqual(
len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS
)

def test_header_contains_pair_too_long(self):
long_value = "s" * (W3CBaggagePropagator._MAX_PAIR_LENGTH + 1)
Expand All @@ -130,6 +136,7 @@ def test_extract_unquote_plus(self):
)

def test_header_max_entries_skip_invalid_entry(self):
# 181 entries where index 2 is too long: skipping it leaves exactly 180 valid entries
with self.assertLogs(level=WARNING) as warning:
self.assertEqual(
self._extract(
Expand All @@ -155,11 +162,17 @@ def test_header_max_entries_skip_invalid_entry(self):
if index != 2
},
)
self.assertIn(
"exceeded the maximum number of list-members",
warning.output[0],
self.assertTrue(
any(
"exceeded the maximum number of bytes per list-member"
in msg
for msg in warning.output
)
)

# 181 entries where index 2 is malformed (no '='): _apply_baggage_limits
# accepts the first 180 entries (indices 0-179), then the malformed entry
# at index 2 is skipped during parsing
with self.assertLogs(level=WARNING) as warning:
self.assertEqual(
self._extract(
Expand All @@ -178,13 +191,20 @@ def test_header_max_entries_skip_invalid_entry(self):
),
{
f"key{index}": f"value{index}"
for index in range(W3CBaggagePropagator._MAX_PAIRS + 1)
for index in range(W3CBaggagePropagator._MAX_PAIRS)
if index != 2
},
)
self.assertIn(
"exceeded the maximum number of list-members",
warning.output[0],
self.assertTrue(
any(
"exceeded the maximum number of list-members" in msg
for msg in warning.output
)
)
self.assertTrue(
any(
"doesn't match the format" in msg for msg in warning.output
)
)

def test_inject_no_baggage_entries(self):
Expand Down Expand Up @@ -224,8 +244,8 @@ def test_inject_non_string_values(self):
self.assertIn("key3=123.567", output)

@patch("opentelemetry.baggage.propagation.get_all")
@patch("opentelemetry.baggage.propagation._format_baggage")
def test_fields(self, mock_format_baggage, mock_baggage):
def test_fields(self, mock_baggage):
mock_baggage.return_value = {"key": "value"}
mock_setter = Mock()

self.propagator.inject({}, setter=mock_setter)
Expand All @@ -246,6 +266,74 @@ def test__format_baggage(self):
"key%2Fkey=value%2Fvalue",
)

def test_inject_too_many_entries(self):
"""Inject should drop entries exceeding _MAX_PAIRS."""
values = {
f"key{i}": f"val{i}"
for i in range(self.propagator._MAX_PAIRS + 10)
}
ctx = get_current()
for k, v in values.items():
ctx = set_baggage(k, v, context=ctx)
output = {}
with self.assertLogs(level=WARNING) as warning:
self.propagator.inject(output, context=ctx)
self.assertIn(
"exceeded the maximum number of list-members",
warning.output[0],
)

def test_inject_entry_too_long(self):
"""Inject should skip individual entries that exceed _MAX_PAIR_LENGTH."""
long_value = "x" * self.propagator._MAX_PAIR_LENGTH
values = {"key1": "val1", "big": long_value, "key3": "val3"}
ctx = get_current()
for k, v in values.items():
ctx = set_baggage(k, v, context=ctx)
output = {}
with self.assertLogs(level=WARNING) as warning:
self.propagator.inject(output, context=ctx)
self.assertIn(
"exceeded the maximum number of bytes per list-member",
warning.output[0],
)
baggage_str = output.get("baggage", "")
self.assertIn("key1=val1", baggage_str)
self.assertIn("key3=val3", baggage_str)
self.assertNotIn("big=", baggage_str)

def test_inject_total_header_too_long(self):
"""Inject should stop adding entries when total header would exceed _MAX_HEADER_LENGTH."""
# Create entries that individually fit but collectively exceed the max header length
# Each entry "kNNN=<value>" with value ~200 chars; 50 of these > 8192
value = "v" * 200
values = {f"k{i:03d}": value for i in range(50)}
ctx = get_current()
for k, v in values.items():
ctx = set_baggage(k, v, context=ctx)
output = {}
with self.assertLogs(level=WARNING) as warning:
self.propagator.inject(output, context=ctx)
self.assertIn(
"exceeded the maximum number of bytes per baggage-string",
warning.output[0],
)
baggage_str = output.get("baggage", "")
self.assertLessEqual(
len(baggage_str), self.propagator._MAX_HEADER_LENGTH
)

def test_inject_empty_after_all_dropped(self):
"""If all entries are too long, nothing should be injected."""
long_value = "x" * self.propagator._MAX_PAIR_LENGTH
values = {"big1": long_value, "big2": long_value}
ctx = get_current()
for k, v in values.items():
ctx = set_baggage(k, v, context=ctx)
output = {}
self.propagator.inject(output, context=ctx)
self.assertNotIn("baggage", output)

@patch("opentelemetry.baggage._BAGGAGE_KEY", new="abc")
def test_inject_extract(self):
carrier = {}
Expand Down
Loading