From 0a0d4d07db8da3318aebc4c2989f893fe532aedb Mon Sep 17 00:00:00 2001 From: roveil Date: Thu, 5 Mar 2026 15:00:17 +0500 Subject: [PATCH 1/2] use multiple delayed queues for each queue instead of single --- taskiq_aio_pika/broker.py | 135 ++++++++++++++++++++++---------------- taskiq_aio_pika/queue.py | 15 +++++ tests/conftest.py | 18 +---- tests/test_delay.py | 3 +- tests/test_routing.py | 10 +-- tests/test_startup.py | 26 +++----- 6 files changed, 110 insertions(+), 97 deletions(-) diff --git a/taskiq_aio_pika/broker.py b/taskiq_aio_pika/broker.py index 0cc9601..9b25562 100644 --- a/taskiq_aio_pika/broker.py +++ b/taskiq_aio_pika/broker.py @@ -59,7 +59,6 @@ def __init__( exchange: Exchange | None = None, task_queues: list[Queue] | None = None, dead_letter_queue: Queue | None = None, - delay_queue: Queue | None = None, delayed_message_exchange_plugin: bool = False, delayed_message_exchange: Exchange | None = None, label_for_routing: str = "queue_name", @@ -79,7 +78,6 @@ def __init__( :param task_queues: parameters of queues that will be used to get incoming messages. :param dead_letter_queue: parameters of dead-letter queue. - :param delay_queue: parameters of queue for simple delay implementation. :param delayed_message_exchange_plugin: turn on or disable delayed-message-exchange rabbitmq plugin. :param delayed_message_exchange: parameters of exchange @@ -96,15 +94,18 @@ def __init__( self._conn_kwargs = connection_kwargs self._exchange = exchange or Exchange() self._qos = qos - self._task_queues = task_queues or [] + self._task_queues = task_queues or [Queue()] + self._task_queues_by_routing_key = { + queue.queue_routing_key: queue for queue in self._task_queues + } + self._dead_letter_queue = dead_letter_queue or Queue(name="taskiq.dead_letter") self._label_for_routing = label_for_routing self._label_for_priority = label_for_priority - self._delay_queue = delay_queue - self._delayed_message_exchange_plugin = delayed_message_exchange_plugin + if self._delayed_message_exchange_plugin: self._delayed_message_exchange = delayed_message_exchange or Exchange( name=f"{self._exchange.name}.plugin_delay", @@ -176,29 +177,31 @@ async def _declare_exchanges( f"was not declared and does not exist.", ) from error - if self._delayed_message_exchange_plugin: - if self._delayed_message_exchange.declare: - await self.write_channel.declare_exchange( + if not self._delayed_message_exchange_plugin: + return + + if self._delayed_message_exchange.declare: + await self.write_channel.declare_exchange( + name=self._delayed_message_exchange.name, + type=self._delayed_message_exchange.type, + durable=self._delayed_message_exchange.durable, + auto_delete=self._delayed_message_exchange.auto_delete, + internal=self._delayed_message_exchange.internal, + passive=self._delayed_message_exchange.passive, + arguments=self._delayed_message_exchange.arguments, + timeout=self._delayed_message_exchange.timeout, + ) + else: + try: + await self.write_channel.get_exchange( name=self._delayed_message_exchange.name, - type=self._delayed_message_exchange.type, - durable=self._delayed_message_exchange.durable, - auto_delete=self._delayed_message_exchange.auto_delete, - internal=self._delayed_message_exchange.internal, - passive=self._delayed_message_exchange.passive, - arguments=self._delayed_message_exchange.arguments, - timeout=self._delayed_message_exchange.timeout, + ensure=True, ) - else: - try: - await self.write_channel.get_exchange( - name=self._delayed_message_exchange.name, - ensure=True, - ) - except aiormq.exceptions.ChannelNotFoundEntity as error: - raise ExchangeNotDeclaredError( - f"Exchange '{self._delayed_message_exchange.name}' " - f"was not declared and does not exist.", - ) from error + except aiormq.exceptions.ChannelNotFoundEntity as error: + raise ExchangeNotDeclaredError( + f"Exchange '{self._delayed_message_exchange.name}' " + f"was not declared and does not exist.", + ) from error async def _declare_dead_letter_queue( self, @@ -253,29 +256,19 @@ async def _declare_queues( declared_queues = [] queue_default_arguments: FieldTable = { "x-dead-letter-exchange": "", - "x-dead-letter-routing-key": ( - self._dead_letter_queue.routing_key or self._dead_letter_queue.name - ), + "x-dead-letter-routing-key": self._dead_letter_queue.queue_routing_key, } - if not self._task_queues: # add default queue if user didn't provide any - self._task_queues.append(Queue()) queues = self._task_queues.copy() - if not self._delayed_message_exchange_plugin and self._delay_queue: - queues.append(self._delay_queue) for queue in filter(lambda queue: queue.declare, queues): per_queue_arguments: FieldTable = queue_default_arguments.copy() + if queue.max_priority is not None: per_queue_arguments["x-max-priority"] = queue.max_priority + per_queue_arguments["x-queue-type"] = queue.type.value - if self._delay_queue and queue.name == self._delay_queue.name: - per_queue_arguments["x-dead-letter-exchange"] = self._exchange.name - per_queue_arguments["x-dead-letter-routing-key"] = ( - self._delay_queue.routing_key - or queues[0].routing_key - or queues[0].name - ) + per_queue_arguments.update( queue.arguments if queue.arguments is not None else {}, ) @@ -288,22 +281,47 @@ async def _declare_queues( arguments=per_queue_arguments, timeout=queue.timeout, ) + logger.debug( "Bind queue to exchange with routing key '%s'", queue.routing_key or queue.name, ) - if not self._delay_queue or queue.name != self._delay_queue.name: - await declared_queue.bind( - exchange=self._exchange.name, - routing_key=queue.routing_key or queue.name, - arguments=queue.bind_arguments, - timeout=queue.bind_timeout, - ) + await declared_queue.bind( + exchange=self._exchange.name, + routing_key=queue.routing_key or queue.name, + arguments=queue.bind_arguments, + timeout=queue.bind_timeout, + ) + if self._delayed_message_exchange_plugin: await declared_queue.bind( exchange=self._delayed_message_exchange.name, routing_key=queue.routing_key or queue.name, ) + else: + # Declare delay queue with x-dead-letter-routing-key to queue + per_queue_arguments["x-dead-letter-exchange"] = self._exchange.name + per_queue_arguments["x-dead-letter-routing-key"] = ( + queue.queue_routing_key + ) + + logger.debug( + "Declare delay queue '%s' with 'x-dead-letter-routing-key': %s" + " and exchange '%s'", + queue.delay_queue_name, + queue.queue_routing_key, + self._exchange.name, + ) + await channel.declare_queue( + name=queue.delay_queue_name, + durable=queue.durable, + exclusive=queue.exclusive, + passive=queue.passive, + auto_delete=queue.auto_delete, + arguments=per_queue_arguments, + timeout=queue.timeout, + ) + declared_queues.append((declared_queue, queue.consumer_arguments)) for queue in filter(lambda queue: not queue.declare, queues): @@ -330,6 +348,8 @@ def with_queue(self, queue: Queue) -> Self: :return: self. """ self._task_queues.append(queue) + self._task_queues_by_routing_key[queue.queue_routing_key] = queue + return self def with_queues(self, *queues: Queue) -> Self: @@ -340,6 +360,10 @@ def with_queues(self, *queues: Queue) -> Self: :return: self. """ self._task_queues = list(queues) + self._task_queues_by_routing_key = { + queue.queue_routing_key: queue for queue in self._task_queues + } + return self async def kick(self, message: BrokerMessage) -> None: @@ -373,9 +397,8 @@ async def kick(self, message: BrokerMessage) -> None: delay = parse_val(float, message.labels.get("delay")) if len(self._task_queues) == 1: - routing_key_name = ( - self._task_queues[0].routing_key or self._task_queues[0].name - ) + queue: Queue | None = self._task_queues[0] + routing_key_name = queue.queue_routing_key else: routing_key_name = ( parse_val( @@ -384,9 +407,10 @@ async def kick(self, message: BrokerMessage) -> None: ) or "" ) - if self._exchange.type == ExchangeType.DIRECT and routing_key_name not in { - queue.routing_key or queue.name for queue in self._task_queues - }: + + queue: Queue | None = self._task_queues_by_routing_key.get(routing_key_name) + + if self._exchange.type == ExchangeType.DIRECT and queue is None: raise IncorrectRoutingKeyError( f"Routing key '{routing_key_name}' is not valid. " f"Check routing keys and queue names in broker queues.", @@ -404,11 +428,11 @@ async def kick(self, message: BrokerMessage) -> None: self._delayed_message_exchange.name, ) await exchange.publish(rmq_message, routing_key=routing_key_name) - elif self._delay_queue: + elif queue is not None: rmq_message.expiration = timedelta(seconds=delay) await self.write_channel.default_exchange.publish( rmq_message, - routing_key=self._delay_queue.routing_key or self._delay_queue.name, + routing_key=queue.delay_queue_routing_key, ) else: raise IncorrectRoutingKeyError( @@ -450,7 +474,6 @@ async def body( *[ body(queue, consumer_args) for queue, consumer_args in queue_with_consumer_args_list - if not self._delay_queue or queue.name != self._delay_queue.name ], ) diff --git a/taskiq_aio_pika/queue.py b/taskiq_aio_pika/queue.py index e3beafc..356c6a0 100644 --- a/taskiq_aio_pika/queue.py +++ b/taskiq_aio_pika/queue.py @@ -56,3 +56,18 @@ class Queue: # will be used during message consumption consumer_arguments: FieldTable = field(default_factory=dict) + + @property + def delay_queue_name(self) -> str: + """Return the name of the delay queue for this queue.""" + return f"{self.name}.delay" + + @property + def delay_queue_routing_key(self) -> str: + """Return the routing key used to publish messages to the delay queue.""" + return self.delay_queue_name + + @property + def queue_routing_key(self) -> str: + """Return the effective routing key for this queue.""" + return self.routing_key or self.name diff --git a/tests/conftest.py b/tests/conftest.py index e164b75..c41c1f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,16 +45,6 @@ def queue_name() -> str: return uuid4().hex + "_queue" -@pytest.fixture -def delay_queue_name() -> str: - """ - Generated name for delay queue. - - :return: random exchange name. - """ - return uuid4().hex + "_delay_queue" - - @pytest.fixture def dead_queue_name() -> str: """ @@ -141,7 +131,6 @@ async def _cleanup_amqp_resources( async def broker( amqp_url: str, queue_name: str, - delay_queue_name: str, dead_queue_name: str, exchange_name: str, test_channel: Channel, @@ -157,11 +146,6 @@ async def broker( declare=True, type=QueueType.CLASSIC, ), - delay_queue=Queue( - name=delay_queue_name, - declare=True, - type=QueueType.CLASSIC, - ), task_queues=[ Queue( name=queue_name, @@ -181,7 +165,7 @@ async def broker( await _cleanup_amqp_resources( amqp_url, [exchange_name], - [queue_name, delay_queue_name, dead_queue_name], + [queue_name, f"{queue_name}.delay", dead_queue_name], ) diff --git a/tests/test_delay.py b/tests/test_delay.py index 2c8b62f..74d5729 100644 --- a/tests/test_delay.py +++ b/tests/test_delay.py @@ -12,9 +12,8 @@ async def test_when_delayed_message_queue_exists__then_send_with_delay_must_work broker: AioPikaBroker, test_channel: Channel, queue_name: str, - delay_queue_name: str, ) -> None: - delay_queue = await test_channel.get_queue(delay_queue_name) + delay_queue = await test_channel.get_queue(f"{queue_name}.delay") main_queue = await test_channel.get_queue(queue_name) broker_msg = BrokerMessage( task_id="1", diff --git a/tests/test_routing.py b/tests/test_routing.py index 1a21852..e34204b 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -23,11 +23,11 @@ async def cleanup_class_broker(self, amqp_url: str) -> AsyncGenerator[None, None yield if self.broker is not None: await self.broker.shutdown() - queue_names = [queue.name for queue in self.broker._task_queues] + [ - self.broker._dead_letter_queue.name, - ] - if self.broker._delay_queue is not None: - queue_names.append(self.broker._delay_queue.name) + queue_names = ( + [queue.name for queue in self.broker._task_queues] + + [queue.delay_queue_name for queue in self.broker._task_queues] + + [self.broker._dead_letter_queue.name] + ) await _cleanup_amqp_resources( amqp_url, [self.broker._exchange.name], diff --git a/tests/test_startup.py b/tests/test_startup.py index 36cbd18..49f4459 100644 --- a/tests/test_startup.py +++ b/tests/test_startup.py @@ -8,7 +8,7 @@ from taskiq_aio_pika import AioPikaBroker from taskiq_aio_pika.exceptions import ExchangeNotDeclaredError, QueueNotDeclaredError from taskiq_aio_pika.exchange import Exchange -from taskiq_aio_pika.queue import Queue, QueueType +from taskiq_aio_pika.queue import Queue from tests.conftest import _cleanup_amqp_resources @@ -20,11 +20,11 @@ async def cleanup_class_broker(self, amqp_url: str) -> AsyncGenerator[None, None yield if self.broker is not None: await self.broker.shutdown() - queue_names = [queue.name for queue in self.broker._task_queues] + [ - self.broker._dead_letter_queue.name, - ] - if self.broker._delay_queue is not None: - queue_names.append(self.broker._delay_queue.name) + queue_names = ( + [queue.name for queue in self.broker._task_queues] + + [queue.delay_queue_name for queue in self.broker._task_queues] + + [self.broker._dead_letter_queue.name] + ) await _cleanup_amqp_resources( amqp_url, [self.broker._exchange.name], @@ -70,7 +70,6 @@ async def test_when_declare_flag_not_passed_to_queue__broker_does_not_declare_qu amqp_url: str, test_channel: Channel, exchange_name: str, - delay_queue_name: str, ) -> None: # given self.broker = AioPikaBroker( @@ -80,12 +79,6 @@ async def test_when_declare_flag_not_passed_to_queue__broker_does_not_declare_qu declare=True, durable=False, ), - delay_queue=Queue( - name=delay_queue_name, - type=QueueType.CLASSIC, - declare=True, - durable=False, - ), ) not_declared_queue_name = "not_declared_queue" + uuid.uuid4().hex @@ -229,7 +222,7 @@ async def test_when_delayed_message_exchange_plugin_enabled_and_custom_exchange_ ): await self.broker.startup() - async def test_when_delay_queue_not_specified__broker_does_not_create_delay_queue( + async def test_when_broker_starts__delay_queue_is_created_for_each_task_queue( self, amqp_url: str, test_channel: Channel, @@ -248,6 +241,5 @@ async def test_when_delay_queue_not_specified__broker_does_not_create_delay_queu await self.broker.startup() # then - assert self.broker._delay_queue is None - with pytest.raises(aiormq.exceptions.ChannelNotFoundEntity): - await test_channel.get_queue("taskiq.delay", ensure=True) + delay_queue = await test_channel.get_queue("taskiq.delay", ensure=True) + assert delay_queue.name == "taskiq.delay" From 5ab66c748c3420d905d61a2d4dd18dc2d6694fcc Mon Sep 17 00:00:00 2001 From: roveil Date: Thu, 5 Mar 2026 15:43:15 +0500 Subject: [PATCH 2/2] Add worker_queues_to_consume parameter. Allow to provide worker queues to consume --- taskiq_aio_pika/broker.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/taskiq_aio_pika/broker.py b/taskiq_aio_pika/broker.py index 9b25562..5901ecf 100644 --- a/taskiq_aio_pika/broker.py +++ b/taskiq_aio_pika/broker.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable, Iterable from datetime import timedelta from logging import getLogger from typing import Any, TypeVar @@ -63,6 +63,7 @@ def __init__( delayed_message_exchange: Exchange | None = None, label_for_routing: str = "queue_name", label_for_priority: str = "priority", + worker_queues_to_consume: Iterable[str] | None = None, **connection_kwargs: Any, ) -> None: """ @@ -86,6 +87,8 @@ def __init__( :param label_for_priority: label name to use for message priority. :param connection_kwargs: additional keyword arguments, for connect_robust method of aio-pika. + :param worker_queues_to_consume: if provided, worker will only consume messages + from the specified queue names. By default, all task queues are consumed. """ super().__init__(result_backend, task_id_generator) @@ -98,6 +101,7 @@ def __init__( self._task_queues_by_routing_key = { queue.queue_routing_key: queue for queue in self._task_queues } + self._worker_queues_to_consume = worker_queues_to_consume self._dead_letter_queue = dead_letter_queue or Queue(name="taskiq.dead_letter") @@ -474,6 +478,8 @@ async def body( *[ body(queue, consumer_args) for queue, consumer_args in queue_with_consumer_args_list + if self._worker_queues_to_consume is None + or queue.name in self._worker_queues_to_consume ], )