From 54886ab66c7fcd5405bb66c4bffc669f129bf187 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sun, 26 Apr 2026 19:20:34 +0200 Subject: [PATCH 1/2] add: multy Topic support --- README.md | 56 ++++++++- taskiq_aio_kafka/__init__.py | 4 +- taskiq_aio_kafka/broker.py | 175 +++++++++++++++++++++++----- taskiq_aio_kafka/topic.py | 39 +++++++ tests/test_broker_multi_topic.py | 189 +++++++++++++++++++++++++++++++ 5 files changed, 430 insertions(+), 33 deletions(-) create mode 100644 taskiq_aio_kafka/topic.py create mode 100644 tests/test_broker_multi_topic.py diff --git a/README.md b/README.md index 3e8da0a..9a3d155 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,64 @@ broker.configure_producer(request_timeout_ms=100000) broker.configure_consumer(group_id="the best group ever.") ``` +## Multiple topics + +By default `AioKafkaBroker` sends all tasks to `kafka_topic`. +You can also configure the broker to listen to multiple topics and bind +different tasks to different default topics. + +```python +from taskiq_aio_kafka import AioKafkaBroker +from taskiq_aio_kafka.topic import Topic + +broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=[ + Topic("emails"), + Topic("reports"), + ], +) + + +@broker.task(topic="emails") +async def send_email(user_id: int) -> None: + print(f"Send email to {user_id}") + + +@broker.task(topic="reports") +async def build_report(report_id: int) -> None: + print(f"Build report {report_id}") +``` + +In this example the worker listens to `default-topic`, `emails`, and `reports`. +When you call `send_email.kiq(...)`, the task is sent to `emails` by default. +When you call `build_report.kiq(...)`, the task is sent to `reports` by default. + +You can override a task topic for a single kick with `kicker().with_topic(...)`: + +```python +await send_email.kicker().with_topic("reports").kiq(user_id=1) +``` + +Tasks without a custom topic keep the old behavior and are sent to `kafka_topic`. + +```python +@broker.task +async def regular_task() -> None: + print("This task goes to default-topic.") + + +await regular_task.kiq() +``` + ## Configuration AioKafkaBroker parameters: * `bootstrap_servers` - url to kafka nodes. Can be either string or list of strings. -* `kafka_topic` - custom topic in kafka. +* `kafka_topic` - default topic in kafka. +* `kafka_topics` - additional topics that worker should listen to. * `result_backend` - custom result backend. * `task_id_generator` - custom task_id genertaor. * `kafka_admin_client` - custom `kafka` admin client. -* `delete_topic_on_shutdown` - flag to delete topic on broker shutdown. +* `delete_topic_on_shutdown` - flag to delete topics on broker shutdown. diff --git a/taskiq_aio_kafka/__init__.py b/taskiq_aio_kafka/__init__.py index ec46428..a452231 100644 --- a/taskiq_aio_kafka/__init__.py +++ b/taskiq_aio_kafka/__init__.py @@ -1,5 +1,5 @@ """Taskiq integration with aiokafka.""" -from taskiq_aio_kafka.broker import AioKafkaBroker +__all__ = ("AioKafkaBroker",) -__all__ = ["AioKafkaBroker"] +from taskiq_aio_kafka.broker import AioKafkaBroker diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index 8f0fbd4..2b0290f 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -1,7 +1,9 @@ +__all__ = ("AioKafkaBroker",) + import asyncio -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable, Iterable from logging import getLogger -from typing import Any, TypeVar +from typing import Any, TypeAlias, TypeVar from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from kafka.admin import KafkaAdminClient, NewTopic @@ -9,23 +11,68 @@ from kafka.partitioner.default import DefaultPartitioner from taskiq import AsyncResultBackend, BrokerMessage from taskiq.abc.broker import AsyncBroker +from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.kicker import AsyncKicker +from typing_extensions import ParamSpec from taskiq_aio_kafka.exceptions import WrongAioKafkaBrokerParametersError from taskiq_aio_kafka.models import KafkaConsumerParameters, KafkaProducerParameters +from taskiq_aio_kafka.topic import Topic _T = TypeVar("_T") +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") +TopicType: TypeAlias = str | NewTopic | Topic +TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic" logger = getLogger("taskiq.kafka_broker") +def _get_topic_name(topic: TopicType) -> str: + if isinstance(topic, str): + return topic + return topic.name + + +class AioKafkaKicker(AsyncKicker[_FuncParams, _ReturnType]): + """Kicker that can override kafka topic for a task call.""" + + def with_topic( + self, + topic: TopicType, + ) -> "AioKafkaKicker[_FuncParams, _ReturnType]": + """Set kafka topic for current kick.""" + self.labels = { + **self.labels, + TASK_TOPIC_LABEL: _get_topic_name(topic), + } + return self + + +class AioKafkaDecoratedTask(AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]): + """Taskiq decorated task with kafka-specific kicker.""" + + def kicker(self) -> AioKafkaKicker[_FuncParams, _ReturnType]: + """Return kafka-aware kicker.""" + return AioKafkaKicker( + task_name=self.task_name, + broker=self.broker, + labels=self.labels, + return_type=self.return_type, + ) + + class AioKafkaBroker(AsyncBroker): """Broker that works with Kafka.""" + task_topic_label = TASK_TOPIC_LABEL + def __init__( # noqa: PLR0913 self, bootstrap_servers: str | list[str] | None, - kafka_topic: NewTopic | None = None, + kafka_topic: TopicType | None = None, + kafka_topics: Iterable[TopicType] | None = None, result_backend: AsyncResultBackend[_T] | None = None, task_id_generator: Callable[[], str] | None = None, kafka_admin_client: KafkaAdminClient | None = None, @@ -35,7 +82,8 @@ def __init__( # noqa: PLR0913 """Construct a new broker. :param bootstrap_servers: string with url to kafka or list with urls. - :param kafka_topic: kafka topic. + :param kafka_topic: default kafka topic. + :param kafka_topics: all kafka topics to listen. :param result_backend: custom result backend. :param task_id_generator: custom task_id generator. :param kafka_admin_client: configured KafkaAdminClient. @@ -46,6 +94,7 @@ def __init__( # noqa: PLR0913 aiokafka_consumer were specified but bootstrap_servers wasn't specified. """ super().__init__(result_backend, task_id_generator) + self.decorator_class = AioKafkaDecoratedTask if kafka_admin_client and not bootstrap_servers: raise WrongAioKafkaBrokerParametersError @@ -54,11 +103,16 @@ def __init__( # noqa: PLR0913 self._loop: asyncio.AbstractEventLoop | None = loop - self._kafka_topic: NewTopic = kafka_topic or NewTopic( - name="taskiq_topic", - num_partitions=1, - replication_factor=1, - ) + self._kafka_topic: NewTopic = self._normalize_default_topic(kafka_topic) + self._kafka_topics: dict[str, TopicType] = { + self._kafka_topic.name: self._kafka_topic, + } + if kafka_topics is not None: + for topic in kafka_topics: + self._kafka_topics.setdefault( + self._get_topic_name(topic), + topic, + ) self._aiokafka_producer_params: KafkaProducerParameters = ( KafkaProducerParameters() @@ -83,6 +137,48 @@ def __init__( # noqa: PLR0913 self._is_producer_started = False self._is_consumer_started = False + @staticmethod + def _get_topic_name(topic: TopicType) -> str: + return _get_topic_name(topic) + + @classmethod + def _normalize_default_topic( + cls, + kafka_topic: TopicType | None, + ) -> NewTopic: + if kafka_topic is None: + return NewTopic( + name="taskiq_topic", + num_partitions=1, + replication_factor=1, + ) + if isinstance(kafka_topic, str): + return NewTopic( + name=kafka_topic, + num_partitions=1, + replication_factor=1, + ) + if isinstance(kafka_topic, Topic): + if kafka_topic.topic_config.declare: + return kafka_topic.new_topic() + return NewTopic( + name=kafka_topic.name, + num_partitions=1, + replication_factor=1, + ) + return kafka_topic + + @classmethod + def _get_declaration_topic( + cls, + topic: TopicType, + ) -> NewTopic | None: + if isinstance(topic, NewTopic): + return topic + if isinstance(topic, Topic) and topic.topic_config.declare: + return topic.new_topic() + return None + def configure_producer(self, **producer_parameters: Any) -> None: """Configure kafka producer. @@ -107,21 +203,41 @@ def configure_consumer(self, **consumer_parameters: Any) -> None: **consumer_parameters, ) + def task( # type: ignore[override] + self, + task_name: str | Callable[..., Any] | None = None, + *, + topic: TopicType | None = None, + **labels: Any, + ) -> Any: + """Decorate function and bind it to a kafka topic by default.""" + if topic is not None: + topic_name = self._get_topic_name(topic) + self._kafka_topics.setdefault(topic_name, topic) + labels[self.task_topic_label] = topic_name + + return super().task( + task_name=task_name, # type: ignore[arg-type] + **labels, + ) + async def startup(self) -> None: """Setup AIOKafkaProducer, AIOKafkaConsumer and kafka topics. - We will have 2 topics for default and high priority. - Also we need to create AIOKafkaProducer and AIOKafkaConsumer if there are no producer and consumer passed. """ await super().startup() - available_condition: bool = ( - self._kafka_topic.name not in self._kafka_admin_client.list_topics() - ) - if available_condition: + existed_topic_names = set(self._kafka_admin_client.list_topics()) + new_topics = [ + new_topic + for topic in self._kafka_topics.values() + if (new_topic := self._get_declaration_topic(topic)) is not None + and new_topic.name not in existed_topic_names + ] + if new_topics: self._kafka_admin_client.create_topics( - new_topics=[self._kafka_topic], + new_topics=new_topics, validate_only=False, ) @@ -145,7 +261,7 @@ async def startup(self) -> None: partition_assignment_strategy ) self._aiokafka_consumer = AIOKafkaConsumer( - self._kafka_topic.name, + *self._kafka_topics, bootstrap_servers=self._bootstrap_servers, loop=self._loop, **consumer_kwargs, @@ -166,18 +282,16 @@ async def shutdown(self) -> None: if self._is_consumer_started: await self._aiokafka_consumer.stop() - topic_delete_condition: bool = all( - ( - self._delete_topic_on_shutdown, - self._kafka_topic.name in self._kafka_admin_client.list_topics(), - ), - ) - if self._kafka_admin_client: - if topic_delete_condition: - self._kafka_admin_client.delete_topics( - [self._kafka_topic.name], - ) + if self._delete_topic_on_shutdown: + existed_topic_names = set(self._kafka_admin_client.list_topics()) + topic_names = [ + topic_name + for topic_name in self._kafka_topics + if topic_name in existed_topic_names + ] + if topic_names: + self._kafka_admin_client.delete_topics(topic_names) self._kafka_admin_client.close() async def kick(self, message: BrokerMessage) -> None: @@ -194,7 +308,10 @@ async def kick(self, message: BrokerMessage) -> None: if not self._is_producer_started: raise ValueError("Please run startup before kicking.") - topic_name: str = self._kafka_topic.name + topic_name: str = message.labels.get( + self.task_topic_label, + self._kafka_topic.name, + ) await self._aiokafka_producer.send( topic=topic_name, diff --git a/taskiq_aio_kafka/topic.py b/taskiq_aio_kafka/topic.py new file mode 100644 index 0000000..f4b5a09 --- /dev/null +++ b/taskiq_aio_kafka/topic.py @@ -0,0 +1,39 @@ +__all__ = ( + "Topic", + "TopicConfig", +) + +import dataclasses + +from kafka.admin import NewTopic + + +@dataclasses.dataclass +class TopicConfig: + """Kafka topic declaration settings.""" + + declare: bool = False + num_partitions: int = 1 + replication_factor: int = 1 + replica_assignments: dict[int, list[int]] = dataclasses.field(default_factory=dict) + topic_configs: dict[str, str] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class Topic: + """Taskiq kafka topic.""" + + name: str + topic_config: TopicConfig = dataclasses.field(default_factory=TopicConfig) + + def new_topic(self) -> NewTopic: + """Create kafka-python NewTopic instance.""" + if not self.topic_config.declare: + raise ValueError("Topic declaration is disabled for this topic.") + return NewTopic( + name=self.name, + num_partitions=self.topic_config.num_partitions, + replication_factor=self.topic_config.replication_factor, + replica_assignments=self.topic_config.replica_assignments, + topic_configs=self.topic_config.topic_configs, + ) diff --git a/tests/test_broker_multi_topic.py b/tests/test_broker_multi_topic.py new file mode 100644 index 0000000..a252dbb --- /dev/null +++ b/tests/test_broker_multi_topic.py @@ -0,0 +1,189 @@ +import pickle +from unittest.mock import Mock + +import pytest +from kafka.admin import KafkaAdminClient, NewTopic +from taskiq import BrokerMessage + +from taskiq_aio_kafka.broker import AioKafkaBroker +from taskiq_aio_kafka.topic import Topic + + +class _ProducerMock: + """Kafka producer mock.""" + + def __init__(self) -> None: + self.messages: list[tuple[str, bytes]] = [] + + async def send(self, topic: str, value: bytes) -> None: + """Store produced message.""" + self.messages.append((topic, value)) + + +class _ProducerStartStopMock: + """Kafka producer lifecycle mock.""" + + async def start(self) -> None: + """Start producer.""" + + async def stop(self) -> None: + """Stop producer.""" + + +class _ConsumerStartStopMock: + """Kafka consumer lifecycle mock.""" + + async def start(self) -> None: + """Start consumer.""" + + async def stop(self) -> None: + """Stop consumer.""" + + +def get_admin_client_mock() -> KafkaAdminClient: + """Get kafka admin client mock.""" + admin_client = Mock(spec=KafkaAdminClient) + admin_client.list_topics.return_value = [] + return admin_client + + +async def test_task_topic_is_used_for_kick() -> None: + """Test that task is sent to its declared topic.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=[Topic("extra-topic")], + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task(topic="extra-topic") + async def test_task() -> None: + return None + + await test_task.kiq() + + assert producer.messages[0][0] == "extra-topic" + + +async def test_kicker_topic_overrides_task_default_topic() -> None: + """Test that kicker can override task default topic.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=["extra-topic", "override-topic"], + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task(topic="extra-topic") + async def test_task() -> None: + return None + + await test_task.kicker().with_topic("override-topic").kiq() + await test_task.kiq() + + assert producer.messages[0][0] == "override-topic" + assert producer.messages[1][0] == "extra-topic" + + +async def test_kicker_topic_is_used_without_task_default_topic() -> None: + """Test that kicker can set topic for task without declared topic.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=["override-topic"], + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task + async def test_task() -> None: + return None + + await test_task.kicker().with_topic("override-topic").kiq() + + assert producer.messages[0][0] == "override-topic" + + +async def test_kick_uses_default_topic_without_task_topic() -> None: + """Test that default topic is used when message has no topic label.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + message = BrokerMessage( + task_id="task-id", + task_name="task-name", + message=pickle.dumps("message"), + labels={}, + ) + + await broker.kick(message) + + assert producer.messages == [("default-topic", message.message)] + + +def test_broker_collects_topic_names() -> None: + """Test that broker listens to default and task topics.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic=NewTopic( + name="default-topic", + num_partitions=1, + replication_factor=1, + ), + kafka_topics=[Topic("extra-topic")], + kafka_admin_client=get_admin_client_mock(), + ) + + assert set(broker._kafka_topics) == {"default-topic", "extra-topic"} + + +async def test_startup_subscribes_consumer_to_all_topics( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that worker consumer subscribes to all broker topics.""" + consumer_topics: tuple[str, ...] = () + + def create_producer(**_kwargs: object) -> _ProducerStartStopMock: + return _ProducerStartStopMock() + + def create_consumer( + *topics: str, + **_kwargs: object, + ) -> _ConsumerStartStopMock: + nonlocal consumer_topics + consumer_topics = topics + return _ConsumerStartStopMock() + + monkeypatch.setattr( + "taskiq_aio_kafka.broker.AIOKafkaProducer", + create_producer, + ) + monkeypatch.setattr( + "taskiq_aio_kafka.broker.AIOKafkaConsumer", + create_consumer, + ) + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=[Topic("extra-topic")], + kafka_admin_client=get_admin_client_mock(), + ) + broker.is_worker_process = True + + await broker.startup() + await broker.shutdown() + + assert consumer_topics == ("default-topic", "extra-topic") From 20fbd5e8b6aee21d70a9fa7927f2c3e4b195c451 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Tue, 28 Apr 2026 11:06:21 +0200 Subject: [PATCH 2/2] decomposition: 1 class 1 file update: task_with_topic interface update: tests --- README.md | 5 +- taskiq_aio_kafka/broker.py | 136 ++++++++++++++++------------- taskiq_aio_kafka/constants.py | 3 + taskiq_aio_kafka/decorated_task.py | 24 +++++ taskiq_aio_kafka/kicker.py | 28 ++++++ taskiq_aio_kafka/topic.py | 11 +-- taskiq_aio_kafka/topic_config.py | 14 +++ taskiq_aio_kafka/types.py | 9 ++ taskiq_aio_kafka/utils.py | 10 +++ tests/test_broker_multi_topic.py | 69 ++++++++++++++- 10 files changed, 234 insertions(+), 75 deletions(-) create mode 100644 taskiq_aio_kafka/constants.py create mode 100644 taskiq_aio_kafka/decorated_task.py create mode 100644 taskiq_aio_kafka/kicker.py create mode 100644 taskiq_aio_kafka/topic_config.py create mode 100644 taskiq_aio_kafka/types.py create mode 100644 taskiq_aio_kafka/utils.py diff --git a/README.md b/README.md index 9a3d155..79a052f 100644 --- a/README.md +++ b/README.md @@ -55,12 +55,12 @@ broker = AioKafkaBroker( ) -@broker.task(topic="emails") +@broker.task_with_topic("emails") async def send_email(user_id: int) -> None: print(f"Send email to {user_id}") -@broker.task(topic="reports") +@broker.task_with_topic("reports") async def build_report(report_id: int) -> None: print(f"Build report {report_id}") ``` @@ -76,6 +76,7 @@ await send_email.kicker().with_topic("reports").kiq(user_id=1) ``` Tasks without a custom topic keep the old behavior and are sent to `kafka_topic`. +The regular `@broker.task` decorator keeps the standard taskiq labels behavior. ```python @broker.task diff --git a/taskiq_aio_kafka/broker.py b/taskiq_aio_kafka/broker.py index 2b0290f..6c45c7c 100644 --- a/taskiq_aio_kafka/broker.py +++ b/taskiq_aio_kafka/broker.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import AsyncGenerator, Callable, Iterable from logging import getLogger -from typing import Any, TypeAlias, TypeVar +from typing import Any, TypeVar, overload from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from kafka.admin import KafkaAdminClient, NewTopic @@ -11,58 +11,24 @@ from kafka.partitioner.default import DefaultPartitioner from taskiq import AsyncResultBackend, BrokerMessage from taskiq.abc.broker import AsyncBroker -from taskiq.decor import AsyncTaskiqDecoratedTask -from taskiq.kicker import AsyncKicker from typing_extensions import ParamSpec -from taskiq_aio_kafka.exceptions import WrongAioKafkaBrokerParametersError -from taskiq_aio_kafka.models import KafkaConsumerParameters, KafkaProducerParameters -from taskiq_aio_kafka.topic import Topic +from .constants import TASK_TOPIC_LABEL +from .decorated_task import AioKafkaDecoratedTask +from .exceptions import WrongAioKafkaBrokerParametersError +from .models import KafkaConsumerParameters, KafkaProducerParameters +from .topic import Topic +from .types import TopicType +from .utils import get_topic_name _T = TypeVar("_T") _FuncParams = ParamSpec("_FuncParams") _ReturnType = TypeVar("_ReturnType") -TopicType: TypeAlias = str | NewTopic | Topic -TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic" logger = getLogger("taskiq.kafka_broker") -def _get_topic_name(topic: TopicType) -> str: - if isinstance(topic, str): - return topic - return topic.name - - -class AioKafkaKicker(AsyncKicker[_FuncParams, _ReturnType]): - """Kicker that can override kafka topic for a task call.""" - - def with_topic( - self, - topic: TopicType, - ) -> "AioKafkaKicker[_FuncParams, _ReturnType]": - """Set kafka topic for current kick.""" - self.labels = { - **self.labels, - TASK_TOPIC_LABEL: _get_topic_name(topic), - } - return self - - -class AioKafkaDecoratedTask(AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]): - """Taskiq decorated task with kafka-specific kicker.""" - - def kicker(self) -> AioKafkaKicker[_FuncParams, _ReturnType]: - """Return kafka-aware kicker.""" - return AioKafkaKicker( - task_name=self.task_name, - broker=self.broker, - labels=self.labels, - return_type=self.return_type, - ) - - class AioKafkaBroker(AsyncBroker): """Broker that works with Kafka.""" @@ -110,7 +76,7 @@ def __init__( # noqa: PLR0913 if kafka_topics is not None: for topic in kafka_topics: self._kafka_topics.setdefault( - self._get_topic_name(topic), + get_topic_name(topic), topic, ) @@ -137,10 +103,6 @@ def __init__( # noqa: PLR0913 self._is_producer_started = False self._is_consumer_started = False - @staticmethod - def _get_topic_name(topic: TopicType) -> str: - return _get_topic_name(topic) - @classmethod def _normalize_default_topic( cls, @@ -203,21 +165,72 @@ def configure_consumer(self, **consumer_parameters: Any) -> None: **consumer_parameters, ) - def task( # type: ignore[override] + @overload + def task( + self, + task_name: Callable[_FuncParams, _ReturnType], + **labels: Any, + ) -> AioKafkaDecoratedTask[_FuncParams, _ReturnType]: ... + + @overload + def task( + self, + task_name: str | None = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AioKafkaDecoratedTask[_FuncParams, _ReturnType], + ]: ... + + def task( + self, + task_name: str | Callable[..., Any] | None = None, + **labels: Any, + ) -> Any: + """Decorate function.""" + if callable(task_name): + return super().task(task_name, **labels) + + return super().task( + task_name=task_name, + **labels, + ) + + @overload + def task_with_topic( + self, + topic: TopicType, + task_name: Callable[_FuncParams, _ReturnType], + **labels: Any, + ) -> AioKafkaDecoratedTask[_FuncParams, _ReturnType]: ... + + @overload + def task_with_topic( + self, + topic: TopicType, + task_name: str | None = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AioKafkaDecoratedTask[_FuncParams, _ReturnType], + ]: ... + + def task_with_topic( self, + topic: TopicType, task_name: str | Callable[..., Any] | None = None, - *, - topic: TopicType | None = None, **labels: Any, ) -> Any: """Decorate function and bind it to a kafka topic by default.""" - if topic is not None: - topic_name = self._get_topic_name(topic) - self._kafka_topics.setdefault(topic_name, topic) - labels[self.task_topic_label] = topic_name + topic_name = get_topic_name(topic) + self._kafka_topics.setdefault(topic_name, topic) + labels[self.task_topic_label] = topic_name + + if callable(task_name): + return super().task(task_name, **labels) return super().task( - task_name=task_name, # type: ignore[arg-type] + task_name=task_name, **labels, ) @@ -229,12 +242,13 @@ async def startup(self) -> None: """ await super().startup() existed_topic_names = set(self._kafka_admin_client.list_topics()) - new_topics = [ - new_topic - for topic in self._kafka_topics.values() - if (new_topic := self._get_declaration_topic(topic)) is not None - and new_topic.name not in existed_topic_names - ] + + new_topics = [] + for topic in self._kafka_topics.values(): + new_topic = self._get_declaration_topic(topic) + if new_topic is not None and new_topic.name not in existed_topic_names: + new_topics.append(new_topic) + if new_topics: self._kafka_admin_client.create_topics( new_topics=new_topics, diff --git a/taskiq_aio_kafka/constants.py b/taskiq_aio_kafka/constants.py new file mode 100644 index 0000000..71f5fad --- /dev/null +++ b/taskiq_aio_kafka/constants.py @@ -0,0 +1,3 @@ +__all__ = ("TASK_TOPIC_LABEL",) + +TASK_TOPIC_LABEL = "taskiq_aio_kafka_topic" diff --git a/taskiq_aio_kafka/decorated_task.py b/taskiq_aio_kafka/decorated_task.py new file mode 100644 index 0000000..e174dcd --- /dev/null +++ b/taskiq_aio_kafka/decorated_task.py @@ -0,0 +1,24 @@ +__all__ = ("AioKafkaDecoratedTask",) + +from typing import TypeVar + +from taskiq.decor import AsyncTaskiqDecoratedTask +from typing_extensions import ParamSpec + +from .kicker import AioKafkaKicker + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +class AioKafkaDecoratedTask(AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]): + """Taskiq decorated task with kafka-specific kicker.""" + + def kicker(self) -> AioKafkaKicker[_FuncParams, _ReturnType]: + """Return kafka-aware kicker.""" + return AioKafkaKicker( + task_name=self.task_name, + broker=self.broker, + labels=self.labels, + return_type=self.return_type, + ) diff --git a/taskiq_aio_kafka/kicker.py b/taskiq_aio_kafka/kicker.py new file mode 100644 index 0000000..042e6cc --- /dev/null +++ b/taskiq_aio_kafka/kicker.py @@ -0,0 +1,28 @@ +__all__ = ("AioKafkaKicker",) + +from typing import TypeVar + +from taskiq.kicker import AsyncKicker +from typing_extensions import ParamSpec + +from .constants import TASK_TOPIC_LABEL +from .types import TopicType +from .utils import get_topic_name + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +class AioKafkaKicker(AsyncKicker[_FuncParams, _ReturnType]): + """Kicker that can override kafka topic for a task call.""" + + def with_topic( + self, + topic: TopicType, + ) -> "AioKafkaKicker[_FuncParams, _ReturnType]": + """Set kafka topic for current kick.""" + self.labels = { + **self.labels, + TASK_TOPIC_LABEL: get_topic_name(topic), + } + return self diff --git a/taskiq_aio_kafka/topic.py b/taskiq_aio_kafka/topic.py index f4b5a09..cf5efd4 100644 --- a/taskiq_aio_kafka/topic.py +++ b/taskiq_aio_kafka/topic.py @@ -7,16 +7,7 @@ from kafka.admin import NewTopic - -@dataclasses.dataclass -class TopicConfig: - """Kafka topic declaration settings.""" - - declare: bool = False - num_partitions: int = 1 - replication_factor: int = 1 - replica_assignments: dict[int, list[int]] = dataclasses.field(default_factory=dict) - topic_configs: dict[str, str] = dataclasses.field(default_factory=dict) +from .topic_config import TopicConfig @dataclasses.dataclass diff --git a/taskiq_aio_kafka/topic_config.py b/taskiq_aio_kafka/topic_config.py new file mode 100644 index 0000000..0e0d201 --- /dev/null +++ b/taskiq_aio_kafka/topic_config.py @@ -0,0 +1,14 @@ +__all__ = ("TopicConfig",) + +import dataclasses + + +@dataclasses.dataclass +class TopicConfig: + """Kafka topic declaration settings.""" + + declare: bool = False + num_partitions: int = 1 + replication_factor: int = 1 + replica_assignments: dict[int, list[int]] = dataclasses.field(default_factory=dict) + topic_configs: dict[str, str] = dataclasses.field(default_factory=dict) diff --git a/taskiq_aio_kafka/types.py b/taskiq_aio_kafka/types.py new file mode 100644 index 0000000..b63291f --- /dev/null +++ b/taskiq_aio_kafka/types.py @@ -0,0 +1,9 @@ +__all__ = ("TopicType",) + +from typing import TypeAlias + +from kafka.admin import NewTopic + +from .topic import Topic + +TopicType: TypeAlias = str | NewTopic | Topic diff --git a/taskiq_aio_kafka/utils.py b/taskiq_aio_kafka/utils.py new file mode 100644 index 0000000..542baa6 --- /dev/null +++ b/taskiq_aio_kafka/utils.py @@ -0,0 +1,10 @@ +__all__ = ("get_topic_name",) + +from .types import TopicType + + +def get_topic_name(topic: TopicType) -> str: + """Get kafka topic name.""" + if isinstance(topic, str): + return topic + return topic.name diff --git a/tests/test_broker_multi_topic.py b/tests/test_broker_multi_topic.py index a252dbb..fb5e850 100644 --- a/tests/test_broker_multi_topic.py +++ b/tests/test_broker_multi_topic.py @@ -59,7 +59,7 @@ async def test_task_topic_is_used_for_kick() -> None: broker._aiokafka_producer = producer broker._is_producer_started = True - @broker.task(topic="extra-topic") + @broker.task_with_topic("extra-topic") async def test_task() -> None: return None @@ -68,6 +68,28 @@ async def test_task() -> None: assert producer.messages[0][0] == "extra-topic" +async def test_task_topic_object_is_used_for_kick() -> None: + """Test that task can be bound to Topic object.""" + extra_topic = Topic("extra-topic") + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=[extra_topic], + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task_with_topic(extra_topic) + async def test_task() -> None: + return None + + await test_task.kiq() + + assert producer.messages[0][0] == extra_topic.name + + async def test_kicker_topic_overrides_task_default_topic() -> None: """Test that kicker can override task default topic.""" broker = AioKafkaBroker( @@ -80,7 +102,7 @@ async def test_kicker_topic_overrides_task_default_topic() -> None: broker._aiokafka_producer = producer broker._is_producer_started = True - @broker.task(topic="extra-topic") + @broker.task_with_topic("extra-topic") async def test_task() -> None: return None @@ -91,6 +113,49 @@ async def test_task() -> None: assert producer.messages[1][0] == "extra-topic" +async def test_kicker_topic_object_overrides_task_default_topic() -> None: + """Test that kicker can override task default topic with Topic object.""" + override_topic = Topic("override-topic") + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_topics=["extra-topic", override_topic], + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task_with_topic("extra-topic") + async def test_task() -> None: + return None + + await test_task.kicker().with_topic(override_topic).kiq() + + assert producer.messages[0][0] == override_topic.name + + +async def test_task_topic_label_keeps_default_broker_topic() -> None: + """Test that regular task topic label doesn't override kafka topic.""" + broker = AioKafkaBroker( + bootstrap_servers="localhost", + kafka_topic="default-topic", + kafka_admin_client=get_admin_client_mock(), + ) + producer = _ProducerMock() + broker._aiokafka_producer = producer + broker._is_producer_started = True + + @broker.task(topic="regular-label") + async def test_task() -> None: + return None + + await test_task.kiq() + + assert test_task.labels["topic"] == "regular-label" + assert producer.messages[0][0] == "default-topic" + + async def test_kicker_topic_is_used_without_task_default_topic() -> None: """Test that kicker can set topic for task without declared topic.""" broker = AioKafkaBroker(