Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,19 @@ def __init__(self, config: Dict[str, Any]):
self.jobs_handler = jobs_handler

async def set_scale(self):
self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
# Concurrency modifier is user-provided and can return invalid values (e.g. None).
# Defensive validation prevents crashes like: TypeError: '<' not supported between 'int' and 'NoneType'
# when current_concurrency is used for queue sizing / task scheduling.
try:
new_concurrency = self.concurrency_modifier(self.current_concurrency)
except Exception as error:
log.warn(
f"JobScaler.set_scale | concurrency_modifier raised {type(error).__name__}: {error}. "
f"Keeping concurrency at {self.current_concurrency}."
)
new_concurrency = self.current_concurrency

self.current_concurrency = self._sanitize_concurrency(new_concurrency)

if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
# no need to resize
Expand All @@ -88,6 +100,34 @@ async def set_scale(self):
f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
)

@staticmethod
def _sanitize_concurrency(value: Any) -> int:
"""
Coerce a user-provided concurrency value into a safe integer >= 1.
"""
# Reject common footguns explicitly.
if value is None or isinstance(value, bool) or isinstance(value, float):
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

try:
v = int(value)
except Exception:
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

if v < 1:
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

return v

def start(self):
"""
This is required for the worker to be able to shut down gracefully
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio
from unittest import TestCase

from runpod.serverless.modules.rp_scale import JobScaler


class TestJobScalerConcurrencyValidation(TestCase):
def test_concurrency_modifier_none_defaults_to_one(self):
scaler = JobScaler({"concurrency_modifier": lambda _: None})

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[🟡 Medium] [🔵 Bug]

The new tests instantiate JobScaler and call set_scale() without clearing JobsProgress, but set_scale() waits in a loop while current_occupancy() > 0. Because JobsProgress is a singleton with file-backed persistence, stale in-flight jobs from prior tests/runs can make these tests hang indefinitely and create non-deterministic CI behavior. Add per-test setup that clears/resets JobsProgress (or monkeypatches JobScaler.job_progress) before invoking set_scale().

# tests/test_serverless/test_modules/test_rp_scale_concurrency_validation.py
scaler = JobScaler({"concurrency_modifier": lambda _: None})
asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 1)

asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 1)
Comment on lines +8 to +11

def test_concurrency_modifier_zero_defaults_to_one(self):
    # A modifier that returns 0 is expected to leave concurrency at 1.
    job_scaler = JobScaler({"concurrency_modifier": lambda _current: 0})

    asyncio.run(job_scaler.set_scale())

    self.assertEqual(job_scaler.current_concurrency, 1)

def test_concurrency_modifier_negative_defaults_to_one(self):
    # A modifier that returns a negative number is expected to leave concurrency at 1.
    job_scaler = JobScaler({"concurrency_modifier": lambda _current: -3})

    asyncio.run(job_scaler.set_scale())

    self.assertEqual(job_scaler.current_concurrency, 1)

def test_concurrency_modifier_valid_int_is_applied(self):
    # A modifier that returns a positive int is expected to be applied unchanged.
    job_scaler = JobScaler({"concurrency_modifier": lambda _current: 4})

    asyncio.run(job_scaler.set_scale())

    self.assertEqual(job_scaler.current_concurrency, 4)