diff --git a/tests/integration/_utils.py b/tests/integration/_utils.py
index 2d941af7..dd6cca92 100644
--- a/tests/integration/_utils.py
+++ b/tests/integration/_utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import time
 from typing import TYPE_CHECKING, Literal, TypeVar
 
 from crawlee._utils.crypto import crypto_random_object_id
@@ -48,6 +49,32 @@ async def call_with_exp_backoff(
     raise ValueError(f'Invalid rq_access_mode: {rq_access_mode}')
 
 
+async def poll_until_condition(
+    fn: Callable[[], Awaitable[T]],
+    condition: Callable[[T], bool],
+    *,
+    timeout: float = 60,
+    poll_interval: float = 5,
+) -> T:
+    """Poll `fn` until `condition(result)` is True or the timeout expires.
+
+    Polls `fn` at `poll_interval`-second intervals until `condition` is satisfied or `timeout` seconds have elapsed.
+    Returns the last polled result regardless of whether the condition was met.
+
+    Use this instead of a fixed `asyncio.sleep` when waiting for eventually-consistent API state (e.g. request queue
+    stats) that may take a variable amount of time to propagate.
+    """
+    deadline = time.monotonic() + timeout
+    result = await fn()
+    while not condition(result):
+        remaining = deadline - time.monotonic()
+        if remaining <= 0:
+            break
+        await asyncio.sleep(min(poll_interval, remaining))
+        result = await fn()
+    return result
+
+
 def generate_unique_resource_name(label: str) -> str:
     """Generates a unique resource name, which will contain the given label."""
     name_template = 'python-sdk-tests-{}-generated-{}'
diff --git a/tests/integration/test_request_queue.py b/tests/integration/test_request_queue.py
index b38e4ae5..25f4caf6 100644
--- a/tests/integration/test_request_queue.py
+++ b/tests/integration/test_request_queue.py
@@ -12,7 +12,7 @@
 from crawlee import service_locator
 from crawlee.crawlers import BasicCrawler
 
-from ._utils import call_with_exp_backoff, generate_unique_resource_name
+from ._utils import call_with_exp_backoff, generate_unique_resource_name, poll_until_condition
 from apify import Actor, Request
 from apify.storage_clients import ApifyStorageClient
 from apify.storage_clients._apify import ApifyRequestQueueClient
@@ -856,10 +856,9 @@ async def test_request_queue_metadata_another_client(
     api_client = apify_client_async.request_queue(request_queue_id=rq.id, client_key=None)
     await api_client.add_request(Request.from_url('http://example.com/1').model_dump(by_alias=True, exclude={'id'}))
 
-    # Wait to be sure that the API has updated the global metadata
-    await asyncio.sleep(10)
-
-    assert (await rq.get_metadata()).total_request_count == 1
+    # Poll until the API has propagated the metadata change.
+    metadata = await poll_until_condition(rq.get_metadata, lambda m: m.total_request_count >= 1)
+    assert metadata.total_request_count == 1
 
 
 async def test_request_queue_had_multiple_clients(
@@ -950,12 +949,18 @@ async def default_handler(context: BasicCrawlingContext) -> None:
     assert crawler.statistics.state.requests_finished == requests
 
     try:
-        # Check the request queue stats
-        await asyncio.sleep(10)  # Wait to be sure that metadata are updated
+        # Poll until request queue stats are propagated by the API.
+        expected_write_count = requests * expected_write_count_per_request
+
+        async def _get_rq_metadata() -> ApifyRequestQueueMetadata:
+            return cast('ApifyRequestQueueMetadata', await rq.get_metadata())
 
-        metadata = cast('ApifyRequestQueueMetadata', await rq.get_metadata())
+        metadata = await poll_until_condition(
+            _get_rq_metadata,
+            lambda m: m.stats.write_count >= expected_write_count,
+        )
         Actor.log.info(f'{metadata.stats=}')
-        assert metadata.stats.write_count == requests * expected_write_count_per_request
+        assert metadata.stats.write_count == expected_write_count
     finally:
         await rq.drop()
 
@@ -1009,13 +1014,16 @@ async def test_request_queue_has_stats(request_queue_apify: RequestQueue) -> Non
 
     await rq.add_requests([Request.from_url(f'http://example.com/{i}') for i in range(add_request_count)])
 
-    # Wait for stats to become stable
-    await asyncio.sleep(10)
+    # Poll until stats are propagated by the API.
+    async def _get_rq_metadata() -> ApifyRequestQueueMetadata:
+        return cast('ApifyRequestQueueMetadata', await rq.get_metadata())
 
-    metadata = await rq.get_metadata()
+    apify_metadata = await poll_until_condition(
+        _get_rq_metadata,
+        lambda m: m.stats.write_count >= add_request_count,
+    )
 
-    assert hasattr(metadata, 'stats')
-    apify_metadata = cast('ApifyRequestQueueMetadata', metadata)
+    assert hasattr(apify_metadata, 'stats')
     assert apify_metadata.stats.write_count == add_request_count
 
 
@@ -1153,10 +1161,15 @@ def return_unprocessed_requests(requests: list[dict], *_: Any, **__: Any) -> dic
     # This will succeed.
     await request_queue_apify.add_requests(['http://example.com/1'])
 
-    await asyncio.sleep(10)  # Wait to be sure that metadata are updated
-    _rq = await rq_client.get()
-    assert _rq
-    stats_after = _rq.get('stats', {})
+    # Poll until stats reflect the successful write.
+    async def _get_rq_stats() -> dict:
+        result = await rq_client.get()
+        return (result or {}).get('stats', {})
+
+    stats_after = await poll_until_condition(
+        _get_rq_stats,
+        lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 1,
+    )
     Actor.log.info(stats_after)
 
     assert (stats_after['writeCount'] - stats_before['writeCount']) == 1
@@ -1256,10 +1269,15 @@ async def test_request_queue_deduplication(
     await rq.add_request(request1)
     await rq.add_request(request2)
 
-    await asyncio.sleep(10)  # Wait to be sure that metadata are updated
-    _rq = await rq_client.get()
-    assert _rq
-    stats_after = _rq.get('stats', {})
+    # Poll until stats reflect the write.
+    async def _get_rq_stats() -> dict:
+        result = await rq_client.get()
+        return (result or {}).get('stats', {})
+
+    stats_after = await poll_until_condition(
+        _get_rq_stats,
+        lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 1,
+    )
 
     assert (stats_after['writeCount'] - stats_before['writeCount']) == 1
 
@@ -1283,10 +1301,15 @@ async def test_request_queue_deduplication_use_extended_unique_key(
     await rq.add_request(request1)
     await rq.add_request(request2)
 
-    await asyncio.sleep(10)  # Wait to be sure that metadata are updated
-    _rq = await rq_client.get()
-    assert _rq
-    stats_after = _rq.get('stats', {})
+    # Poll until stats reflect both writes.
+    async def _get_rq_stats() -> dict:
+        result = await rq_client.get()
+        return (result or {}).get('stats', {})
+
+    stats_after = await poll_until_condition(
+        _get_rq_stats,
+        lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= 2,
+    )
 
     assert (stats_after['writeCount'] - stats_before['writeCount']) == 2
 
@@ -1316,10 +1339,15 @@ async def add_requests_worker() -> None:
     add_requests_workers = [asyncio.create_task(add_requests_worker()) for _ in range(worker_count)]
     await asyncio.gather(*add_requests_workers)
 
-    await asyncio.sleep(10)  # Wait to be sure that metadata are updated
-    _rq = await rq_client.get()
-    assert _rq
-    stats_after = _rq.get('stats', {})
+    # Poll until stats reflect all written requests.
+    async def _get_rq_stats() -> dict:
+        result = await rq_client.get()
+        return (result or {}).get('stats', {})
+
+    stats_after = await poll_until_condition(
+        _get_rq_stats,
+        lambda s: s.get('writeCount', 0) - stats_before.get('writeCount', 0) >= len(requests),
+    )
 
     assert (stats_after['writeCount'] - stats_before['writeCount']) == len(requests)