Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 86 additions & 57 deletions taskiq_aio_pika/broker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import AsyncGenerator, Callable
from collections.abc import AsyncGenerator, Callable, Iterable
from datetime import timedelta
from logging import getLogger
from typing import Any, TypeVar
Expand Down Expand Up @@ -59,11 +59,11 @@ def __init__(
exchange: Exchange | None = None,
task_queues: list[Queue] | None = None,
dead_letter_queue: Queue | None = None,
delay_queue: Queue | None = None,
delayed_message_exchange_plugin: bool = False,
delayed_message_exchange: Exchange | None = None,
label_for_routing: str = "queue_name",
label_for_priority: str = "priority",
worker_queues_to_consume: Iterable[str] | None = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -79,7 +79,6 @@ def __init__(
:param task_queues: parameters of queues
that will be used to get incoming messages.
:param dead_letter_queue: parameters of dead-letter queue.
:param delay_queue: parameters of queue for simple delay implementation.
:param delayed_message_exchange_plugin: turn on or disable
delayed-message-exchange rabbitmq plugin.
:param delayed_message_exchange: parameters of exchange
Expand All @@ -88,6 +87,8 @@ def __init__(
:param label_for_priority: label name to use for message priority.
:param connection_kwargs: additional keyword arguments,
for connect_robust method of aio-pika.
:param worker_queues_to_consume: if provided, worker will only consume messages
from the specified queue names. By default, all task queues are consumed.
"""
super().__init__(result_backend, task_id_generator)

Expand All @@ -96,15 +97,19 @@ def __init__(
self._conn_kwargs = connection_kwargs
self._exchange = exchange or Exchange()
self._qos = qos
self._task_queues = task_queues or []
self._task_queues = task_queues or [Queue()]
self._task_queues_by_routing_key = {
queue.queue_routing_key: queue for queue in self._task_queues
}
self._worker_queues_to_consume = worker_queues_to_consume

self._dead_letter_queue = dead_letter_queue or Queue(name="taskiq.dead_letter")

self._label_for_routing = label_for_routing
self._label_for_priority = label_for_priority

self._delay_queue = delay_queue

self._delayed_message_exchange_plugin = delayed_message_exchange_plugin

if self._delayed_message_exchange_plugin:
self._delayed_message_exchange = delayed_message_exchange or Exchange(
name=f"{self._exchange.name}.plugin_delay",
Expand Down Expand Up @@ -176,29 +181,31 @@ async def _declare_exchanges(
f"was not declared and does not exist.",
) from error

if self._delayed_message_exchange_plugin:
if self._delayed_message_exchange.declare:
await self.write_channel.declare_exchange(
if not self._delayed_message_exchange_plugin:
return

if self._delayed_message_exchange.declare:
await self.write_channel.declare_exchange(
name=self._delayed_message_exchange.name,
type=self._delayed_message_exchange.type,
durable=self._delayed_message_exchange.durable,
auto_delete=self._delayed_message_exchange.auto_delete,
internal=self._delayed_message_exchange.internal,
passive=self._delayed_message_exchange.passive,
arguments=self._delayed_message_exchange.arguments,
timeout=self._delayed_message_exchange.timeout,
)
else:
try:
await self.write_channel.get_exchange(
name=self._delayed_message_exchange.name,
type=self._delayed_message_exchange.type,
durable=self._delayed_message_exchange.durable,
auto_delete=self._delayed_message_exchange.auto_delete,
internal=self._delayed_message_exchange.internal,
passive=self._delayed_message_exchange.passive,
arguments=self._delayed_message_exchange.arguments,
timeout=self._delayed_message_exchange.timeout,
ensure=True,
)
else:
try:
await self.write_channel.get_exchange(
name=self._delayed_message_exchange.name,
ensure=True,
)
except aiormq.exceptions.ChannelNotFoundEntity as error:
raise ExchangeNotDeclaredError(
f"Exchange '{self._delayed_message_exchange.name}' "
f"was not declared and does not exist.",
) from error
except aiormq.exceptions.ChannelNotFoundEntity as error:
raise ExchangeNotDeclaredError(
f"Exchange '{self._delayed_message_exchange.name}' "
f"was not declared and does not exist.",
) from error

async def _declare_dead_letter_queue(
self,
Expand Down Expand Up @@ -253,29 +260,19 @@ async def _declare_queues(
declared_queues = []
queue_default_arguments: FieldTable = {
"x-dead-letter-exchange": "",
"x-dead-letter-routing-key": (
self._dead_letter_queue.routing_key or self._dead_letter_queue.name
),
"x-dead-letter-routing-key": self._dead_letter_queue.queue_routing_key,
}
if not self._task_queues: # add default queue if user didn't provide any
self._task_queues.append(Queue())

queues = self._task_queues.copy()
if not self._delayed_message_exchange_plugin and self._delay_queue:
queues.append(self._delay_queue)

for queue in filter(lambda queue: queue.declare, queues):
per_queue_arguments: FieldTable = queue_default_arguments.copy()

if queue.max_priority is not None:
per_queue_arguments["x-max-priority"] = queue.max_priority

per_queue_arguments["x-queue-type"] = queue.type.value
if self._delay_queue and queue.name == self._delay_queue.name:
per_queue_arguments["x-dead-letter-exchange"] = self._exchange.name
per_queue_arguments["x-dead-letter-routing-key"] = (
self._delay_queue.routing_key
or queues[0].routing_key
or queues[0].name
)

per_queue_arguments.update(
queue.arguments if queue.arguments is not None else {},
)
Expand All @@ -288,22 +285,47 @@ async def _declare_queues(
arguments=per_queue_arguments,
timeout=queue.timeout,
)

logger.debug(
"Bind queue to exchange with routing key '%s'",
queue.routing_key or queue.name,
)
if not self._delay_queue or queue.name != self._delay_queue.name:
await declared_queue.bind(
exchange=self._exchange.name,
routing_key=queue.routing_key or queue.name,
arguments=queue.bind_arguments,
timeout=queue.bind_timeout,
)
await declared_queue.bind(
exchange=self._exchange.name,
routing_key=queue.routing_key or queue.name,
arguments=queue.bind_arguments,
timeout=queue.bind_timeout,
)

if self._delayed_message_exchange_plugin:
await declared_queue.bind(
exchange=self._delayed_message_exchange.name,
routing_key=queue.routing_key or queue.name,
)
else:
# Declare delay queue with x-dead-letter-routing-key to queue
per_queue_arguments["x-dead-letter-exchange"] = self._exchange.name
per_queue_arguments["x-dead-letter-routing-key"] = (
queue.queue_routing_key
)

logger.debug(
"Declare delay queue '%s' with 'x-dead-letter-routing-key': %s"
" and exchange '%s'",
queue.delay_queue_name,
queue.queue_routing_key,
self._exchange.name,
)
await channel.declare_queue(
name=queue.delay_queue_name,
durable=queue.durable,
exclusive=queue.exclusive,
passive=queue.passive,
auto_delete=queue.auto_delete,
arguments=per_queue_arguments,
timeout=queue.timeout,
)

declared_queues.append((declared_queue, queue.consumer_arguments))

for queue in filter(lambda queue: not queue.declare, queues):
Expand All @@ -330,6 +352,8 @@ def with_queue(self, queue: Queue) -> Self:
:return: self.
"""
self._task_queues.append(queue)
self._task_queues_by_routing_key[queue.queue_routing_key] = queue

return self

def with_queues(self, *queues: Queue) -> Self:
Expand All @@ -340,6 +364,10 @@ def with_queues(self, *queues: Queue) -> Self:
:return: self.
"""
self._task_queues = list(queues)
self._task_queues_by_routing_key = {
queue.queue_routing_key: queue for queue in self._task_queues
}

return self

async def kick(self, message: BrokerMessage) -> None:
Expand Down Expand Up @@ -373,9 +401,8 @@ async def kick(self, message: BrokerMessage) -> None:
delay = parse_val(float, message.labels.get("delay"))

if len(self._task_queues) == 1:
routing_key_name = (
self._task_queues[0].routing_key or self._task_queues[0].name
)
queue: Queue | None = self._task_queues[0]
routing_key_name = queue.queue_routing_key
else:
routing_key_name = (
parse_val(
Expand All @@ -384,9 +411,10 @@ async def kick(self, message: BrokerMessage) -> None:
)
or ""
)
if self._exchange.type == ExchangeType.DIRECT and routing_key_name not in {
queue.routing_key or queue.name for queue in self._task_queues
}:

queue: Queue | None = self._task_queues_by_routing_key.get(routing_key_name)

if self._exchange.type == ExchangeType.DIRECT and queue is None:
raise IncorrectRoutingKeyError(
f"Routing key '{routing_key_name}' is not valid. "
f"Check routing keys and queue names in broker queues.",
Expand All @@ -404,11 +432,11 @@ async def kick(self, message: BrokerMessage) -> None:
self._delayed_message_exchange.name,
)
await exchange.publish(rmq_message, routing_key=routing_key_name)
elif self._delay_queue:
elif queue is not None:
rmq_message.expiration = timedelta(seconds=delay)
await self.write_channel.default_exchange.publish(
rmq_message,
routing_key=self._delay_queue.routing_key or self._delay_queue.name,
routing_key=queue.delay_queue_routing_key,
)
else:
raise IncorrectRoutingKeyError(
Expand Down Expand Up @@ -450,7 +478,8 @@ async def body(
*[
body(queue, consumer_args)
for queue, consumer_args in queue_with_consumer_args_list
if not self._delay_queue or queue.name != self._delay_queue.name
if self._worker_queues_to_consume is None
or queue.name in self._worker_queues_to_consume
],
)

Expand Down
15 changes: 15 additions & 0 deletions taskiq_aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,18 @@ class Queue:

# will be used during message consumption
consumer_arguments: FieldTable = field(default_factory=dict)

@property
def delay_queue_name(self) -> str:
"""Return the name of the delay queue for this queue."""
return f"{self.name}.delay"

@property
def delay_queue_routing_key(self) -> str:
"""Return the routing key used to publish messages to the delay queue."""
return self.delay_queue_name

@property
def queue_routing_key(self) -> str:
"""Return the effective routing key for this queue."""
return self.routing_key or self.name
18 changes: 1 addition & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,6 @@ def queue_name() -> str:
return uuid4().hex + "_queue"


@pytest.fixture
def delay_queue_name() -> str:
"""
Generated name for delay queue.

:return: random exchange name.
"""
return uuid4().hex + "_delay_queue"


@pytest.fixture
def dead_queue_name() -> str:
"""
Expand Down Expand Up @@ -141,7 +131,6 @@ async def _cleanup_amqp_resources(
async def broker(
amqp_url: str,
queue_name: str,
delay_queue_name: str,
dead_queue_name: str,
exchange_name: str,
test_channel: Channel,
Expand All @@ -157,11 +146,6 @@ async def broker(
declare=True,
type=QueueType.CLASSIC,
),
delay_queue=Queue(
name=delay_queue_name,
declare=True,
type=QueueType.CLASSIC,
),
task_queues=[
Queue(
name=queue_name,
Expand All @@ -181,7 +165,7 @@ async def broker(
await _cleanup_amqp_resources(
amqp_url,
[exchange_name],
[queue_name, delay_queue_name, dead_queue_name],
[queue_name, f"{queue_name}.delay", dead_queue_name],
)


Expand Down
3 changes: 1 addition & 2 deletions tests/test_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ async def test_when_delayed_message_queue_exists__then_send_with_delay_must_work
broker: AioPikaBroker,
test_channel: Channel,
queue_name: str,
delay_queue_name: str,
) -> None:
delay_queue = await test_channel.get_queue(delay_queue_name)
delay_queue = await test_channel.get_queue(f"{queue_name}.delay")
main_queue = await test_channel.get_queue(queue_name)
broker_msg = BrokerMessage(
task_id="1",
Expand Down
10 changes: 5 additions & 5 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ async def cleanup_class_broker(self, amqp_url: str) -> AsyncGenerator[None, None
yield
if self.broker is not None:
await self.broker.shutdown()
queue_names = [queue.name for queue in self.broker._task_queues] + [
self.broker._dead_letter_queue.name,
]
if self.broker._delay_queue is not None:
queue_names.append(self.broker._delay_queue.name)
queue_names = (
[queue.name for queue in self.broker._task_queues]
+ [queue.delay_queue_name for queue in self.broker._task_queues]
+ [self.broker._dead_letter_queue.name]
)
await _cleanup_amqp_resources(
amqp_url,
[self.broker._exchange.name],
Expand Down
Loading