diff --git a/sql/schema.sql b/sql/schema.sql
index fc678bc..ac2f42e 100644
--- a/sql/schema.sql
+++ b/sql/schema.sql
@@ -611,6 +611,33 @@ begin
 end;
 $$;
 
+-- Counts unclaimed runs that are ready to be claimed.
+-- Uses the same candidate logic as claim_task but without locking or updating.
+create function durable.count_unclaimed_ready_tasks (
+    p_queue_name text
+)
+    returns bigint
+    language plpgsql
+as $$
+declare
+    v_now timestamptz := durable.current_time();
+    v_count bigint;
+begin
+    -- Per-queue run/task tables are named r_<queue> / t_<queue>; %I quoting
+    -- guards the dynamically-built identifiers against injection.
+    execute format(
+        'select count(*)
+         from durable.%1$I r
+         join durable.%2$I t on t.task_id = r.task_id
+         where r.state in (''pending'', ''sleeping'')
+           and t.state in (''pending'', ''sleeping'', ''running'')
+           and r.available_at <= $1',
+        'r_' || p_queue_name,
+        't_' || p_queue_name
+    ) into v_count using v_now;
+
+    return v_count;
+end;
+$$;
+
 -- Marks a run as completed
 create function durable.complete_run (
     p_queue_name text,
diff --git a/src/client.rs b/src/client.rs
index 58d0346..4da6686 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -721,6 +721,22 @@ where
         Ok(())
     }
 
+    /// Count unclaimed tasks that are ready to be claimed in a queue.
+    ///
+    /// All of these tasks can be claimed by a running worker on the provided queue.
+    /// Falls back to this client's default queue when `queue_name` is `None`.
+    pub async fn count_unclaimed_ready_tasks(
+        &self,
+        queue_name: Option<&str>,
+    ) -> DurableResult<i64> {
+        let queue = queue_name.unwrap_or(&self.queue_name);
+        let query = "SELECT durable.count_unclaimed_ready_tasks($1)";
+        let (count,): (i64,) = sqlx::query_as(query)
+            .bind(queue)
+            .fetch_one(&self.pool)
+            .await?;
+        Ok(count)
+    }
+
     /// Cancel a task by ID. Running tasks will be cancelled at
     /// their next checkpoint or heartbeat.
pub async fn cancel_task(&self, task_id: Uuid, queue_name: Option<&str>) -> DurableResult<()> { diff --git a/src/postgres/migrations/20260317153556_add_count_unclaimed_ready_tasks.sql b/src/postgres/migrations/20260317153556_add_count_unclaimed_ready_tasks.sql new file mode 100644 index 0000000..e01a9c4 --- /dev/null +++ b/src/postgres/migrations/20260317153556_add_count_unclaimed_ready_tasks.sql @@ -0,0 +1,26 @@ +-- Add function to count unclaimed runs that are ready to be claimed. +-- Uses the same candidate logic as claim_task but without locking or updating. +create or replace function durable.count_unclaimed_ready_tasks ( + p_queue_name text +) + returns bigint + language plpgsql +as $$ +declare + v_now timestamptz := durable.current_time(); + v_count bigint; +begin + execute format( + 'select count(*) + from durable.%1$I r + join durable.%2$I t on t.task_id = r.task_id + where r.state in (''pending'', ''sleeping'') + and t.state in (''pending'', ''sleeping'', ''running'') + and r.available_at <= $1', + 'r_' || p_queue_name, + 't_' || p_queue_name + ) into v_count using v_now; + + return v_count; +end; +$$; diff --git a/tests/count_unclaimed_test.rs b/tests/count_unclaimed_test.rs new file mode 100644 index 0000000..bad2a13 --- /dev/null +++ b/tests/count_unclaimed_test.rs @@ -0,0 +1,223 @@ +#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] + +mod common; + +use common::tasks::{EchoParams, EchoTask}; +use durable::{Durable, DurableBuilder, MIGRATOR, WorkerOptions}; +use sqlx::{AssertSqlSafe, PgPool}; +use std::time::Duration; + +/// Helper to create a DurableBuilder from the test pool. 
+fn create_client(pool: PgPool, queue_name: &str) -> DurableBuilder {
+    Durable::builder().pool(pool).queue_name(queue_name)
+}
+
+// ============================================================================
+// count_unclaimed_ready_tasks Tests
+// ============================================================================
+
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_count_unclaimed_empty_queue(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "count_empty")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    client.create_queue(None).await.unwrap();
+
+    let count = client
+        .count_unclaimed_ready_tasks(None)
+        .await
+        .expect("Failed to count unclaimed tasks");
+    assert_eq!(count, 0, "Empty queue should have 0 unclaimed tasks");
+
+    Ok(())
+}
+
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_count_unclaimed_after_spawning(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "count_spawn")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    client.create_queue(None).await.unwrap();
+
+    // Spawn 3 tasks
+    for i in 0..3 {
+        client
+            .spawn::<EchoTask>(EchoParams {
+                message: format!("task {i}"),
+            })
+            .await
+            .expect("Failed to spawn task");
+    }
+
+    let count = client
+        .count_unclaimed_ready_tasks(None)
+        .await
+        .expect("Failed to count unclaimed tasks");
+    assert_eq!(count, 3, "Should have 3 unclaimed tasks after spawning 3");
+
+    Ok(())
+}
+
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_count_unclaimed_decreases_after_claim(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "count_claim")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    client.create_queue(None).await.unwrap();
+
+    // Spawn 3 tasks
+    for i in 0..3 {
+        client
+            .spawn::<EchoTask>(EchoParams {
+                message: format!("task {i}"),
+            })
+            .await
+            .expect("Failed to spawn task");
+    }
+
+    assert_eq!(
+        client.count_unclaimed_ready_tasks(None).await.unwrap(),
+        3,
+        "Should start with 3 unclaimed"
+    );
+
+    // Start a worker that will claim and complete tasks
+    let worker = client
+        .start_worker(WorkerOptions {
+            concurrency: 3,
+            poll_interval: Duration::from_millis(100),
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    // Wait for the worker to process all tasks
+    tokio::time::sleep(Duration::from_secs(2)).await;
+
+    let count = client
+        .count_unclaimed_ready_tasks(None)
+        .await
+        .expect("Failed to count unclaimed tasks");
+    assert_eq!(
+        count, 0,
+        "Should have 0 unclaimed tasks after worker claims them"
+    );
+
+    worker.shutdown().await;
+
+    Ok(())
+}
+
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_count_unclaimed_with_explicit_queue_name(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "default")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+
+    // Create two queues
+    client.create_queue(Some("queue_a")).await.unwrap();
+    client.create_queue(Some("queue_b")).await.unwrap();
+
+    // Spawn tasks into queue_a
+    let client_a = create_client(pool.clone(), "queue_a")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    for i in 0..2 {
+        client_a
+            .spawn::<EchoTask>(EchoParams {
+                message: format!("a-{i}"),
+            })
+            .await
+            .unwrap();
+    }
+
+    // Spawn tasks into queue_b
+    let client_b = create_client(pool.clone(), "queue_b")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    for i in 0..5 {
+        client_b
+            .spawn::<EchoTask>(EchoParams {
+                message: format!("b-{i}"),
+            })
+            .await
+            .unwrap();
+    }
+
+    // Count using explicit queue names
+    let count_a = client
+        .count_unclaimed_ready_tasks(Some("queue_a"))
+        .await
+        .unwrap();
+    let count_b = client
+        .count_unclaimed_ready_tasks(Some("queue_b"))
+        .await
+        .unwrap();
+
+    assert_eq!(count_a, 2, "queue_a should have 2 unclaimed tasks");
+    assert_eq!(count_b, 5, "queue_b should have 5 unclaimed tasks");
+
+    Ok(())
+}
+
+#[sqlx::test(migrator = "MIGRATOR")]
+async fn test_count_unclaimed_excludes_future_tasks(pool: PgPool) -> sqlx::Result<()> {
+    let client = create_client(pool.clone(), "count_future")
+        .register::<EchoTask>()
+        .unwrap()
+        .build()
+        .await
+        .unwrap();
+    client.create_queue(None).await.unwrap();
+
+    // Spawn two tasks (both ready now)
+    client
+        .spawn::<EchoTask>(EchoParams {
+            message: "ready now".to_string(),
+        })
+        .await
+        .unwrap();
+    let delayed = client
+        .spawn::<EchoTask>(EchoParams {
+            message: "will be delayed".to_string(),
+        })
+        .await
+        .unwrap();
+
+    // Push one run's available_at into the future via direct SQL
+    sqlx::query(AssertSqlSafe(
+        "UPDATE durable.r_count_future SET available_at = now() + interval '1 hour' WHERE task_id = $1".to_string()
+    ))
+    .bind(delayed.task_id)
+    .execute(&pool)
+    .await?;
+
+    let count = client
+        .count_unclaimed_ready_tasks(None)
+        .await
+        .expect("Failed to count unclaimed tasks");
+    assert_eq!(
+        count, 1,
+        "Should only count the immediately-ready task, not the delayed one"
+    );
+
+    Ok(())
+}