diff --git a/queue.go b/queue.go
index 0fe4161..0b709fe 100644
--- a/queue.go
+++ b/queue.go
@@ -30,6 +30,7 @@ type (
 	metric       *metric       // Metrics collector for tracking queue and worker stats
 	logger       Logger        // Logger for queue events and errors
 	workerCount  int64         // Number of worker goroutines to process jobs
+	activeSlots  int64         // Reserved worker slots; see tryReserveSlot/releaseSlot
 	routineGroup *routineGroup // Group to manage and wait for goroutines
 	quit         chan struct{} // Channel to signal shutdown to all goroutines
 	ready        chan struct{} // Channel to signal worker readiness
@@ -185,11 +186,11 @@ func (q *Queue) work(task core.TaskMessage) {
 	// Defer block to handle panics, update metrics, and run afterFn callback.
 	defer func() {
 		q.metric.DecBusyWorker()
+		q.releaseSlot()
 		e := recover()
 		if e != nil {
 			q.logger.Fatalf("panic error: %v", e)
 		}
-		q.schedule()
 
 		// Update success or failure metrics based on execution result.
 		if err == nil && e == nil {
@@ -313,21 +314,47 @@ func (q *Queue) UpdateWorkerCount(num int64) {
 	q.schedule()
 }
 
-// schedule checks if more workers can be started based on the current busy count.
+// schedule checks if more workers can be started based on reserved slots.
 // If so, it signals readiness to start a new worker.
 func (q *Queue) schedule() {
 	q.Lock()
 	defer q.Unlock()
 
-	if q.BusyWorkers() >= q.workerCount {
+	q.signalReadyLocked()
+}
+
+// signalReadyLocked sends a non-blocking ready signal if a slot is available.
+// Caller must hold q.Lock().
+func (q *Queue) signalReadyLocked() {
+	if q.activeSlots >= q.workerCount {
 		return
 	}
-
 	select {
 	case q.ready <- struct{}{}:
 	default:
 	}
 }
 
+// tryReserveSlot reserves a worker slot if one is available under the mutex,
+// closing the TOCTOU gap between schedule() and dispatch.
+func (q *Queue) tryReserveSlot() bool {
+	q.Lock()
+	defer q.Unlock()
+	if q.activeSlots >= q.workerCount {
+		return false
+	}
+	q.activeSlots++
+	return true
+}
+
+// releaseSlot frees a reserved slot and signals readiness in one critical
+// section, saving a lock round-trip versus separate decrement + schedule().
+func (q *Queue) releaseSlot() {
+	q.Lock()
+	defer q.Unlock()
+	q.activeSlots--
+	q.signalReadyLocked()
+}
+
 /*
 start launches the main worker loop, which manages job scheduling and execution.
@@ -351,6 +378,10 @@ func (q *Queue) start() {
 			return
 		}
 
+		if !q.tryReserveSlot() {
+			continue
+		}
+
 		// Request a task from the worker in a background goroutine.
 		q.routineGroup.Run(func() {
 			for {
@@ -386,6 +417,7 @@
 
 				task, ok := <-tasks
 				if !ok {
+					q.releaseSlot()
 					return
 				}
 
diff --git a/ring_test.go b/ring_test.go
index b264c21..4b4af22 100644
--- a/ring_test.go
+++ b/ring_test.go
@@ -6,6 +6,8 @@ import (
 	"fmt"
 	"log"
 	"runtime"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -550,3 +552,74 @@ func BenchmarkRingQueue(b *testing.B) {
 	}
 	})
 }
+
+func TestBusyWorkersNeverExceedsWorkerCount(t *testing.T) {
+	const workerCount = 4
+	const totalTasks = 100
+
+	var maxObserved int64
+	var wg sync.WaitGroup
+	wg.Add(totalTasks)
+	gate := make(chan struct{})
+
+	w := NewRing(
+		WithFn(func(ctx context.Context, m core.TaskMessage) error {
+			defer wg.Done()
+			select {
+			case <-gate:
+				return nil
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}),
+	)
+	q, err := NewQueue(
+		WithWorker(w),
+		WithWorkerCount(workerCount),
+	)
+	assert.NoError(t, err)
+
+	q.Start()
+	for i := 0; i < totalTasks; i++ {
+		assert.NoError(t, q.Queue(mockMessage{message: "task"}))
+	}
+
+	// Continuously monitor BusyWorkers while tasks execute.
+	stop := make(chan struct{})
+	monitorDone := make(chan struct{})
+	go func() {
+		defer close(monitorDone)
+		ticker := time.NewTicker(100 * time.Microsecond)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-stop:
+				return
+			case <-ticker.C:
+				busy := q.BusyWorkers()
+				for {
+					old := atomic.LoadInt64(&maxObserved)
+					if busy <= old || atomic.CompareAndSwapInt64(&maxObserved, old, busy) {
+						break
+					}
+				}
+			}
+		}
+	}()
+
+	// Release tasks with a timeout to prevent hanging on regression.
+	timeout := time.After(10 * time.Second)
+	for i := 0; i < totalTasks; i++ {
+		select {
+		case gate <- struct{}{}:
+		case <-timeout:
+			t.Fatal("timed out sending gate tokens — possible scheduling deadlock")
+		}
+	}
+	wg.Wait()
+	close(stop)
+	<-monitorDone
+	q.Release()
+
+	assert.LessOrEqual(t, maxObserved, int64(workerCount))
+}