diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py
index 1652220dfc..cee329bb0a 100644
--- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py
+++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py
@@ -369,6 +369,63 @@ def get_description(self):
         return f"ParentBased{{root:{self._root.get_description()},remoteParentSampled:{self._remote_parent_sampled.get_description()},remoteParentNotSampled:{self._remote_parent_not_sampled.get_description()},localParentSampled:{self._local_parent_sampled.get_description()},localParentNotSampled:{self._local_parent_not_sampled.get_description()}}}"
 
 
+class AlwaysRecordSampler(Sampler):
+    """
+    This sampler will return the sampling result of the provided `root`, unless the
+    sampling result contains the sampling decision `Decision.DROP`, in which case, a
+    new sampling result will be returned that is functionally equivalent to the original, except that
+    it contains the sampling decision `Decision.RECORD_ONLY`. This ensures that all
+    spans are recorded, with no change to sampling.
+
+    The intended use case of this sampler is to provide a means of sending all spans to a
+    processor without having an impact on the sampling rate. This may be desirable if a user wishes
+    to count or otherwise measure all spans produced in a service, without incurring the cost of 100%
+    sampling.
+
+    Args:
+        root: Sampler whose sampling result is used for every span.
+
+    Raises:
+        ValueError: If `root` is None.
+    """
+
+    def __init__(self, root: Sampler):
+        if root is None:
+            raise ValueError("root must not be None")
+        self._root = root
+
+    def should_sample(
+        self,
+        parent_context: Context | None,
+        trace_id: int,
+        name: str,
+        kind: SpanKind | None = None,
+        attributes: Attributes = None,
+        links: Sequence[Link] | None = None,
+        trace_state: TraceState | None = None,
+    ) -> SamplingResult:
+        result: SamplingResult = self._root.should_sample(
+            parent_context,
+            trace_id,
+            name,
+            kind,
+            attributes,
+            links,
+            trace_state,
+        )
+        # Only DROP is rewritten; the root's attributes and trace state
+        # are carried over unchanged.
+        if result.decision is Decision.DROP:
+            result = SamplingResult(
+                Decision.RECORD_ONLY, result.attributes, result.trace_state
+            )
+
+        return result
+
+    def get_description(self):
+        return f"AlwaysRecordSampler{{{self._root.get_description()}}}"
+
+
 DEFAULT_OFF = ParentBased(ALWAYS_OFF)
 """Sampler that respects its parent span's sampling decision, but otherwise never samples."""
 
diff --git a/opentelemetry-sdk/tests/trace/test_sampling.py b/opentelemetry-sdk/tests/trace/test_sampling.py
index 1d33a1a2c2..6a5c830f5e 100644
--- a/opentelemetry-sdk/tests/trace/test_sampling.py
+++ b/opentelemetry-sdk/tests/trace/test_sampling.py
@@ -4,6 +4,7 @@
 import contextlib
 import sys
 import unittest
+import unittest.mock
 
 from opentelemetry import context as context_api
 from opentelemetry import trace
@@ -524,3 +525,77 @@ def implicit_parent_context(span: trace.Span):
             context_api.detach(token)
 
         self.exec_parent_based(implicit_parent_context)
+
+
+class TestAlwaysRecordSampler(unittest.TestCase):
+    def setUp(self):
+        self.mock_sampler: sampling.Sampler = unittest.mock.MagicMock()
+        self.sampler: sampling.Sampler = sampling.AlwaysRecordSampler(
+            self.mock_sampler
+        )
+
+    def test_root_must_not_be_none(self):
+        with self.assertRaises(ValueError):
+            sampling.AlwaysRecordSampler(None)
+
+    def test_get_description(self):
+        static_sampler: sampling.Sampler = sampling.StaticSampler(
+            sampling.Decision.DROP
+        )
+        test_sampler: sampling.Sampler = sampling.AlwaysRecordSampler(
+            static_sampler
+        )
+        self.assertEqual(
+            "AlwaysRecordSampler{AlwaysOffSampler}",
+            test_sampler.get_description(),
+        )
+
+    def test_record_and_sample_sampling_decision(self):
+        self.validate_should_sample(
+            sampling.Decision.RECORD_AND_SAMPLE,
+            sampling.Decision.RECORD_AND_SAMPLE,
+        )
+
+    def test_record_only_sampling_decision(self):
+        self.validate_should_sample(
+            sampling.Decision.RECORD_ONLY, sampling.Decision.RECORD_ONLY
+        )
+
+    def test_drop_sampling_decision(self):
+        self.validate_should_sample(
+            sampling.Decision.DROP, sampling.Decision.RECORD_ONLY
+        )
+
+    def validate_should_sample(
+        self,
+        root_decision: sampling.Decision,
+        expected_decision: sampling.Decision,
+    ):
+        """Check the decision rewrite for one root sampler decision.
+
+        The mocked root returns `root_decision`; the wrapper's result must
+        carry `expected_decision` with the root's attributes and trace state.
+        """
+        # TraceState.add() returns a new TraceState rather than mutating.
+        trace_state: trace.TraceState = trace.TraceState().add(
+            "key", root_decision.name
+        )
+        root_result: sampling.SamplingResult = sampling.SamplingResult(
+            attributes={"key": root_decision.name},
+            decision=root_decision,
+            trace_state=trace_state,
+        )
+        self.mock_sampler.should_sample.return_value = root_result
+
+        actual_result: sampling.SamplingResult = self.sampler.should_sample(
+            parent_context=context_api.Context(),
+            trace_id=0,
+            name="name",
+            kind=trace.SpanKind.CLIENT,
+            attributes={"key": root_decision.name},
+            trace_state=trace.TraceState(),
+        )
+
+        self.assertEqual(actual_result.decision, expected_decision)
+        self.assertEqual(actual_result.attributes, root_result.attributes)
+        self.assertEqual(actual_result.trace_state, root_result.trace_state)