diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py
index 49fb378eabd..1ce71c8c102 100644
--- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py
+++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py
@@ -14,7 +14,7 @@
 #
 from logging import getLogger
 from re import split
-from typing import Iterable, List, Mapping, Optional, Set
+from typing import Generator, Iterable, Mapping, Optional, Set
 from urllib.parse import quote_plus, unquote_plus
 
 from opentelemetry.baggage import _is_valid_pair, get_all, set_baggage
@@ -26,6 +26,65 @@
 _logger = getLogger(__name__)
 
 
+def _apply_baggage_limits(
+    entries: Iterable[str],
+    max_pairs: int,
+    max_pair_length: int,
+    max_header_length: int,
+) -> Generator[str, None, None]:
+    """Apply W3C Baggage size limits to a sequence of baggage entries.
+
+    Empty entries are ignored. An entry longer than ``max_pair_length`` is
+    dropped with a warning and does not count towards the other limits.
+    Iteration stops, with a warning, at the first entry that would exceed
+    ``max_pairs`` entries or ``max_header_length`` for the comma-joined
+    result.
+
+    Args:
+        entries: Candidate ``key=value`` baggage entries, in order.
+        max_pairs: Maximum number of entries to yield.
+        max_pair_length: Maximum length of a single entry.
+        max_header_length: Maximum length of the yielded entries once
+            joined with ``,`` separators.
+
+    Yields:
+        The entries that fit within all three limits, in input order.
+    """
+    count = 0
+    total_length = 0
+
+    for entry in entries:
+        if not entry:
+            continue
+
+        # Check the entry count first so that nothing past the limit is
+        # inspected (or logged) once the maximum has been reached.
+        if count >= max_pairs:
+            _logger.warning(
+                "Baggage exceeded the maximum number of list-members",
+            )
+            break
+
+        if len(entry) > max_pair_length:
+            _logger.warning(
+                "Baggage entry `%s` exceeded the maximum number of bytes per list-member",
+                entry,
+            )
+            continue
+
+        # Account for comma separator between entries
+        added_length = len(entry) + (1 if count > 0 else 0)
+        if total_length + added_length > max_header_length:
+            _logger.warning(
+                "Baggage exceeded the maximum number of bytes per baggage-string",
+            )
+            break
+
+        count += 1
+        total_length += added_length
+        yield entry
+
+
 class W3CBaggagePropagator(textmap.TextMapPropagator):
     """Extracts and injects Baggage which is used to annotate telemetry."""
 
@@ -63,8 +122,7 @@ def extract(
             )
             return context
 
-        baggage_entries: List[str] = split(_DELIMITER_PATTERN, header)
-        total_baggage_entries = self._MAX_PAIRS
+        baggage_entries = split(_DELIMITER_PATTERN, header)
 
         if len(baggage_entries) > self._MAX_PAIRS:
             _logger.warning(
@@ -72,15 +130,12 @@ def extract(
                 header,
             )
 
-        for entry in baggage_entries:
-            if len(entry) > self._MAX_PAIR_LENGTH:
-                _logger.warning(
-                    "Baggage entry `%s` exceeded the maximum number of bytes per list-member",
-                    entry,
-                )
-                continue
-            if not entry:  # empty string
-                continue
+        for entry in _apply_baggage_limits(
+            baggage_entries,
+            max_pairs=self._MAX_PAIRS,
+            max_pair_length=self._MAX_PAIR_LENGTH,
+            max_header_length=self._MAX_HEADER_LENGTH,
+        ):
             try:
                 name, value = entry.split("=", 1)
             except Exception:  # pylint: disable=broad-exception-caught
@@ -101,9 +156,6 @@ def extract(
                     value,
                     context=context,
                 )
-                total_baggage_entries -= 1
-                if total_baggage_entries == 0:
-                    break
 
         return context
 
@@ -122,8 +174,17 @@ def inject(
         if not baggage_entries:
             return
 
-        baggage_string = _format_baggage(baggage_entries)
-        setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string)
+        baggage_string = ",".join(
+            _apply_baggage_limits(
+                _encode_baggage_pairs(baggage_entries),
+                max_pairs=self._MAX_PAIRS,
+                max_pair_length=self._MAX_PAIR_LENGTH,
+                max_header_length=self._MAX_HEADER_LENGTH,
+            )
+        )
+
+        if baggage_string:
+            setter.set(carrier, self._BAGGAGE_HEADER_NAME, baggage_string)
 
     @property
     def fields(self) -> Set[str]:
@@ -132,10 +193,15 @@ def fields(self) -> Set[str]:
 
 
 def _format_baggage(baggage_entries: Mapping[str, object]) -> str:
-    return ",".join(
-        quote_plus(str(key)) + "=" + quote_plus(str(value))
-        for key, value in baggage_entries.items()
-    )
+    return ",".join(_encode_baggage_pairs(baggage_entries))
+
+
+def _encode_baggage_pairs(
+    baggage_entries: Mapping[str, object],
+) -> Generator[str, None, None]:
+    """Yield URL-encoded 'key=value' pairs from baggage entries."""
+    for key, value in baggage_entries.items():
+        yield quote_plus(str(key)) + "=" + quote_plus(str(value))
 
 
 def _extract_first_element(
diff --git a/opentelemetry-api/tests/propagators/test_w3cbaggagepropagator.py b/opentelemetry-api/tests/propagators/test_w3cbaggagepropagator.py
index 46db45f4d34..7c52c59cf12 100644
--- a/opentelemetry-api/tests/propagators/test_w3cbaggagepropagator.py
+++ b/opentelemetry-api/tests/propagators/test_w3cbaggagepropagator.py
@@ -99,15 +99,21 @@ def test_header_too_long(self):
         long_value = "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1)
         header = f"key1={long_value}"
         expected = {}
-        self.assertEqual(self._extract(header), expected)
+        with self.assertLogs(level=WARNING) as warning:
+            self.assertEqual(self._extract(header), expected)
+            self.assertIn(
+                "exceeded the maximum number of bytes per baggage-string",
+                warning.output[0],
+            )
 
     def test_header_contains_too_many_entries(self):
         header = ",".join(
             [f"key{k}=val" for k in range(W3CBaggagePropagator._MAX_PAIRS + 1)]
         )
-        self.assertEqual(
-            len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS
-        )
+        with self.assertLogs(level=WARNING):
+            self.assertEqual(
+                len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS
+            )
 
     def test_header_contains_pair_too_long(self):
         long_value = "s" * (W3CBaggagePropagator._MAX_PAIR_LENGTH + 1)
@@ -130,6 +136,7 @@ def test_extract_unquote_plus(self):
         )
 
     def test_header_max_entries_skip_invalid_entry(self):
+        # 181 entries where index 2 is too long: skipping it leaves exactly 180 valid entries
         with self.assertLogs(level=WARNING) as warning:
             self.assertEqual(
                 self._extract(
@@ -155,11 +162,17 @@ def test_header_max_entries_skip_invalid_entry(self):
                     if index != 2
                 },
             )
-            self.assertIn(
-                "exceeded the maximum number of list-members",
-                warning.output[0],
+            self.assertTrue(
+                any(
+                    "exceeded the maximum number of bytes per list-member"
+                    in msg
+                    for msg in warning.output
+                )
             )
 
+        # 181 entries where index 2 is malformed (no '='): _apply_baggage_limits
+        # accepts the first 180 entries (indices 0-179), then the malformed entry
+        # at index 2 is skipped during parsing
         with self.assertLogs(level=WARNING) as warning:
             self.assertEqual(
                 self._extract(
@@ -178,13 +191,20 @@ def test_header_max_entries_skip_invalid_entry(self):
                 ),
                 {
                     f"key{index}": f"value{index}"
-                    for index in range(W3CBaggagePropagator._MAX_PAIRS + 1)
+                    for index in range(W3CBaggagePropagator._MAX_PAIRS)
                     if index != 2
                 },
             )
-            self.assertIn(
-                "exceeded the maximum number of list-members",
-                warning.output[0],
+            self.assertTrue(
+                any(
+                    "exceeded the maximum number of list-members" in msg
+                    for msg in warning.output
+                )
+            )
+            self.assertTrue(
+                any(
+                    "doesn't match the format" in msg for msg in warning.output
+                )
             )
 
     def test_inject_no_baggage_entries(self):
@@ -224,8 +244,8 @@ def test_inject_non_string_values(self):
         self.assertIn("key3=123.567", output)
 
     @patch("opentelemetry.baggage.propagation.get_all")
-    @patch("opentelemetry.baggage.propagation._format_baggage")
-    def test_fields(self, mock_format_baggage, mock_baggage):
+    def test_fields(self, mock_baggage):
+        mock_baggage.return_value = {"key": "value"}
         mock_setter = Mock()
 
         self.propagator.inject({}, setter=mock_setter)
@@ -246,6 +266,78 @@ def test__format_baggage(self):
             "key%2Fkey=value%2Fvalue",
         )
 
+    def test_inject_too_many_entries(self):
+        """Inject should drop entries exceeding _MAX_PAIRS."""
+        values = {
+            f"key{i}": f"val{i}"
+            for i in range(self.propagator._MAX_PAIRS + 10)
+        }
+        ctx = get_current()
+        for k, v in values.items():
+            ctx = set_baggage(k, v, context=ctx)
+        output = {}
+        with self.assertLogs(level=WARNING) as warning:
+            self.propagator.inject(output, context=ctx)
+        self.assertIn(
+            "exceeded the maximum number of list-members",
+            warning.output[0],
+        )
+        # The header itself must have been truncated, not just logged about
+        self.assertEqual(
+            len(output["baggage"].split(",")), self.propagator._MAX_PAIRS
+        )
+
+    def test_inject_entry_too_long(self):
+        """Inject should skip individual entries that exceed _MAX_PAIR_LENGTH."""
+        long_value = "x" * self.propagator._MAX_PAIR_LENGTH
+        values = {"key1": "val1", "big": long_value, "key3": "val3"}
+        ctx = get_current()
+        for k, v in values.items():
+            ctx = set_baggage(k, v, context=ctx)
+        output = {}
+        with self.assertLogs(level=WARNING) as warning:
+            self.propagator.inject(output, context=ctx)
+        self.assertIn(
+            "exceeded the maximum number of bytes per list-member",
+            warning.output[0],
+        )
+        baggage_str = output.get("baggage", "")
+        self.assertIn("key1=val1", baggage_str)
+        self.assertIn("key3=val3", baggage_str)
+        self.assertNotIn("big=", baggage_str)
+
+    def test_inject_total_header_too_long(self):
+        """Inject should stop adding entries when total header would exceed _MAX_HEADER_LENGTH."""
+        # Create entries that individually fit but collectively exceed the max header length
+        # Each entry "kNNN=" with value ~200 chars; 50 of these > 8192
+        value = "v" * 200
+        values = {f"k{i:03d}": value for i in range(50)}
+        ctx = get_current()
+        for k, v in values.items():
+            ctx = set_baggage(k, v, context=ctx)
+        output = {}
+        with self.assertLogs(level=WARNING) as warning:
+            self.propagator.inject(output, context=ctx)
+        self.assertIn(
+            "exceeded the maximum number of bytes per baggage-string",
+            warning.output[0],
+        )
+        baggage_str = output.get("baggage", "")
+        self.assertLessEqual(
+            len(baggage_str), self.propagator._MAX_HEADER_LENGTH
+        )
+
+    def test_inject_empty_after_all_dropped(self):
+        """If all entries are too long, nothing should be injected."""
+        long_value = "x" * self.propagator._MAX_PAIR_LENGTH
+        values = {"big1": long_value, "big2": long_value}
+        ctx = get_current()
+        for k, v in values.items():
+            ctx = set_baggage(k, v, context=ctx)
+        output = {}
+        self.propagator.inject(output, context=ctx)
+        self.assertNotIn("baggage", output)
+
     @patch("opentelemetry.baggage._BAGGAGE_KEY", new="abc")
     def test_inject_extract(self):
         carrier = {}