Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions databento/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class MappingIntervalDict(TypedDict):


RecordCallback = Callable[[databento_dbn.DBNRecord], None]
RawRecordCallback = Callable[[bytes], None]
ExceptionCallback = Callable[[Exception], None]
ReconnectCallback = Callable[[pd.Timestamp, pd.Timestamp], None]

Expand Down Expand Up @@ -262,3 +263,19 @@ def _warn(self, msg: str) -> None:
BentoWarning,
stacklevel=3,
)


class ClientRawRecordCallback(ClientRecordCallback):
def __init__(
self,
fn: RawRecordCallback,
exc_fn: ExceptionCallback | None = None,
max_warnings: int = 10,
) -> None:
super().__init__(fn=fn, exc_fn=exc_fn, max_warnings=max_warnings) # type: ignore [arg-type]

def call(self, raw: bytes) -> None: # type: ignore [override]
"""
Execute the callback, passing raw bytes.
"""
super().call(raw) # type: ignore [arg-type]
37 changes: 37 additions & 0 deletions databento/live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from databento.common.error import BentoError
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
from databento.common.publishers import Dataset
from databento.common.types import ClientRawRecordCallback
from databento.common.types import ClientRecordCallback
from databento.common.types import ClientStream
from databento.common.types import ExceptionCallback
from databento.common.types import RawRecordCallback
from databento.common.types import ReconnectCallback
from databento.common.types import RecordCallback
from databento.common.validation import validate_enum
Expand Down Expand Up @@ -354,6 +356,41 @@ def add_callback(
logger.info("adding user callback %s", client_callback.callback_name)
self._session._user_callbacks.append(client_callback)

def add_raw_callback(
self,
record_callback: RawRecordCallback,
exception_callback: ExceptionCallback | None = None,
) -> None:
"""
Add a callback for handling records as raw bytes.

Unlike `add_callback`, this receives each record as raw `bytes`.
No Python objects are created, avoiding overhead and memory issues.

Parameters
----------
record_callback : Callable[[bytes], None]
A callback to register for handling live records as raw bytes.
exception_callback : Callable[[Exception], None], optional
An error handling callback for exceptions raised in `record_callback`.

Raises
------
ValueError
If `record_callback` is not callable.

See Also
--------
Live.add_callback

"""
client_callback = ClientRawRecordCallback(
fn=record_callback,
exc_fn=exception_callback,
)
logger.info("adding raw callback %s", client_callback.callback_name)
self._session._raw_callbacks.append(client_callback)

def add_stream(
self,
stream: IO[bytes] | PathLike[str] | str,
Expand Down
66 changes: 40 additions & 26 deletions databento/live/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,32 +383,46 @@ def _process_dbn(self, data: bytes) -> None:
raise
else:
for record in records:
logger.debug("dispatching %s", type(record).__name__)
if isinstance(record, databento_dbn.Metadata):
self.received_metadata(record)
continue
if isinstance(record, databento_dbn.ErrorMsg):
logger.error(
"gateway error code=%s err='%s'",
record.code,
record.err,
)
self._error_msgs.append(record.err)
elif isinstance(record, databento_dbn.SystemMsg):
if record.is_heartbeat():
logger.debug("gateway heartbeat")
else:
if record.code == SystemCode.END_OF_INTERVAL:
system_msg_level = logging.DEBUG
else:
system_msg_level = logging.INFO
logger.log(
system_msg_level,
"system message code=%s msg='%s'",
record.code,
record.msg,
)
self.received_record(record)
self._dispatch_decoded_record(record)

def _dispatch_decoded_record(self, record: DBNRecord | Metadata) -> None:
"""
Route a single decoded record to the appropriate handler.
"""
logger.debug("dispatching %s", type(record).__name__)
if isinstance(record, databento_dbn.Metadata):
self.received_metadata(record)
else:
self._handle_control_record(record)
self.received_record(record)

def _handle_control_record(self, record: DBNRecord) -> None:
"""
Process control record side effects: logging and error tracking.

Called for ErrorMsg and SystemMsg before received_record().
"""
if isinstance(record, databento_dbn.ErrorMsg):
logger.error(
"gateway error code=%s err='%s'",
record.code,
record.err,
)
self._error_msgs.append(record.err)
elif isinstance(record, databento_dbn.SystemMsg):
if record.is_heartbeat():
logger.debug("gateway heartbeat")
else:
if record.code == SystemCode.END_OF_INTERVAL:
system_msg_level = logging.DEBUG
else:
system_msg_level = logging.INFO
logger.log(
system_msg_level,
"system message code=%s msg='%s'",
record.code,
record.msg,
)

def _process_gateway(self, data: bytes) -> None:
try:
Expand Down
40 changes: 40 additions & 0 deletions databento/live/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
from databento.common.enums import SlowReaderBehavior
from databento.common.error import BentoError
from databento.common.publishers import Dataset
from databento.common.types import ClientRawRecordCallback
from databento.common.types import ClientRecordCallback
from databento.common.types import ClientStream
from databento.common.types import ExceptionCallback
from databento.common.types import RawRecordCallback
from databento.common.types import ReconnectCallback
from databento.live.gateway import SubscriptionRequest
from databento.live.protocol import DatabentoLiveProtocol
Expand Down Expand Up @@ -209,6 +211,7 @@ def __init__(
heartbeat_interval_s: int | None = None,
slow_reader_behavior: SlowReaderBehavior | str | None = None,
compression: Compression = Compression.NONE,
raw_callbacks: list[ClientRawRecordCallback] | None = None,
):
super().__init__(
api_key,
Expand All @@ -224,6 +227,7 @@ def __init__(
self._metadata: SessionMetadata = metadata
self._user_callbacks = user_callbacks
self._user_streams = user_streams
self._raw_callbacks: list[ClientRawRecordCallback] = raw_callbacks if raw_callbacks is not None else []
self._last_ts_event: int | None = None
self._last_msg_loop_time: float = math.inf

Expand Down Expand Up @@ -253,6 +257,40 @@ def received_record(self, record: DBNRecord) -> None:

return super().received_record(record)

def _process_dbn(self, data: bytes) -> None:
if not self._raw_callbacks:
return super()._process_dbn(data)

try:
self._dbn_decoder.write(bytes(data))
records = self._dbn_decoder.decode_raw()
except Exception:
logger.exception("error decoding DBN record")
self.transport.close()
raise

for record in records:
if isinstance(record, bytes):
# Data record as raw bytes, no Python object creation.
logger.debug("dispatching raw data record")
self._dispatch_raw_callbacks(record)
# ts_event lives at RecordHeader offset 8 (u64 LE).
self._last_ts_event = struct.unpack_from("<Q", record, 8)[0]
self._last_msg_loop_time = self._loop.time()
else:
self._dispatch_decoded_record(record)

def _dispatch_raw_callbacks(self, raw: bytes) -> None:
for callback in self._raw_callbacks:
try:
callback.call(raw)
except Exception as exc:
logger.error(
"error dispatching raw record to `%s` callback",
callback.callback_name,
exc_info=exc,
)

def _dispatch_callbacks(self, record: DBNRecord) -> None:
for callback in self._user_callbacks:
try:
Expand Down Expand Up @@ -336,6 +374,7 @@ def __init__(
self._user_gateway: str | None = user_gateway
self._user_streams: list[ClientStream] = []
self._user_callbacks: list[ClientRecordCallback] = []
self._raw_callbacks: list[ClientRawRecordCallback] = []
self._user_reconnect_callbacks: list[tuple[ReconnectCallback, ExceptionCallback | None]] = (
[]
)
Expand Down Expand Up @@ -598,6 +637,7 @@ def _create_protocol(self, dataset: Dataset | str) -> _SessionProtocol:
heartbeat_interval_s=self.heartbeat_interval_s,
slow_reader_behavior=self._slow_reader_behavior,
compression=self._compression,
raw_callbacks=self._raw_callbacks,
)

def _connect(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_live_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,42 @@ def callback(record: DBNRecord) -> None:
assert isinstance(records[3], databento_dbn.MBOMsg)


async def test_live_raw_callback(
live_client: client.Live,
) -> None:
"""
Test raw callback dispatch of DBN records as bytes.

Mirrors test_live_callback but uses add_raw_callback. Data records
should arrive as raw bytes; control records still go to add_callback.
"""
# Arrange
live_client.subscribe(
dataset=Dataset.GLBX_MDP3,
schema=Schema.MBO,
stype_in=SType.RAW_SYMBOL,
symbols="TEST",
)
raw_records: list[bytes] = []

def raw_callback(raw: bytes) -> None:
raw_records.append(raw)

# Act
live_client.add_raw_callback(raw_callback)

live_client.start()

await live_client.wait_for_close()

# Assert — same 4 MBO records, but as raw bytes
assert len(raw_records) == 4
mbo_size = len(bytes(databento_dbn.MBOMsg(0x01, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0)))
for raw in raw_records:
assert isinstance(raw, bytes)
assert len(raw) == mbo_size


@pytest.mark.parametrize(
"dataset",
[
Expand Down