Skip to content

Commit 8df11f8

Browse files
author
devteamaegis
committed
fix(prediction): output_iterator silently iterates string output character-by-character
Both output_iterator and async_output_iterator used ``self.output or []`` to coerce the model's output to a list. This idiom is only safe when output is None (no output produced yet) or an actual list. For any truthy non-array model output (plain string, URL, dict), the ``or`` fallback never fires and the raw value is returned unchanged. ``yield from "hello world"`` then iterates over 11 individual characters instead of yielding the intended single value, and ``output[len(previous_output):]`` on a string produces a string slice rather than a list slice, causing subtly wrong behaviour with no error. Fix: replace ``value or []`` with an explicit check that treats None as an empty list (meaning "no output yet"), passes a list through unchanged, and raises a descriptive ValueError for anything else so callers know to use ``prediction.output`` directly for non-array outputs. Also fixes the variable shadowing bug in the final loop of async_output_iterator, where the loop variable was named ``output`` instead of ``item`` and rebound the outer ``output`` reference on each iteration. Additionally, "aborted" is now treated as a terminal status: it is added to the ``status`` Literal, it stops the polling loops in wait, async_wait, and both iterators, and the iterators raise ModelError for it just as they do for "failed". Resolves the long-standing ``# TODO: check output is list`` comments.
1 parent d2956ff commit 8df11f8

2 files changed

Lines changed: 135 additions & 17 deletions

File tree

replicate/prediction.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class Prediction(Resource):
5555
version: str
5656
"""An identifier for the version of the model used to create the prediction."""
5757

58-
status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
58+
status: Literal["starting", "processing", "succeeded", "failed", "canceled", "aborted"]
5959
"""The status of the prediction."""
6060

6161
input: Optional[Dict[str, Any]]
@@ -141,7 +141,7 @@ def wait(self) -> None:
141141
Wait for prediction to finish.
142142
"""
143143

144-
while self.status not in ["succeeded", "failed", "canceled"]:
144+
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
145145
time.sleep(self._client.poll_interval)
146146
self.reload()
147147

@@ -150,7 +150,7 @@ async def async_wait(self) -> None:
150150
Wait for prediction to finish asynchronously.
151151
"""
152152

153-
while self.status not in ["succeeded", "failed", "canceled"]:
153+
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
154154
await asyncio.sleep(self._client.poll_interval)
155155
await self.async_reload()
156156

@@ -249,20 +249,39 @@ def output_iterator(self) -> Iterator[Any]:
249249
Return an iterator of the prediction output.
250250
"""
251251

252-
# TODO: check output is list
253-
previous_output = self.output or []
254-
while self.status not in ["succeeded", "failed", "canceled"]:
255-
output = self.output or []
252+
def _as_list(value: Any) -> list:
253+
"""Coerce output to a list.
254+
255+
``None`` means the model has not produced any output yet; treat it
256+
as an empty list so the polling loop can start cleanly. Any other
257+
non-list value (e.g. a plain string returned by a non-streaming
258+
model) indicates a model whose output schema is not an array — in
259+
that case we raise a ``ValueError`` rather than silently iterating
260+
over the characters of a string or the keys of a dict.
261+
"""
262+
if value is None:
263+
return []
264+
if isinstance(value, list):
265+
return value
266+
raise ValueError(
267+
f"output_iterator requires an array output type, "
268+
f"but the model returned a {type(value).__name__!r}. "
269+
f"Use prediction.output directly for non-array outputs."
270+
)
271+
272+
previous_output = _as_list(self.output)
273+
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
274+
output = _as_list(self.output)
256275
new_output = output[len(previous_output) :]
257276
yield from new_output
258277
previous_output = output
259278
time.sleep(self._client.poll_interval) # pylint: disable=no-member
260279
self.reload()
261280

262-
if self.status == "failed":
281+
if self.status in ("failed", "aborted"):
263282
raise ModelError(self)
264283

265-
output = self.output or []
284+
output = _as_list(self.output)
266285
new_output = output[len(previous_output) :]
267286
yield from new_output
268287

@@ -271,24 +290,35 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
271290
Return an asynchronous iterator of the prediction output.
272291
"""
273292

274-
# TODO: check output is list
275-
previous_output = self.output or []
276-
while self.status not in ["succeeded", "failed", "canceled"]:
277-
output = self.output or []
293+
def _as_list(value: Any) -> list:
294+
"""Coerce output to a list (see sync variant for rationale)."""
295+
if value is None:
296+
return []
297+
if isinstance(value, list):
298+
return value
299+
raise ValueError(
300+
f"async_output_iterator requires an array output type, "
301+
f"but the model returned a {type(value).__name__!r}. "
302+
f"Use prediction.output directly for non-array outputs."
303+
)
304+
305+
previous_output = _as_list(self.output)
306+
while self.status not in ["succeeded", "failed", "canceled", "aborted"]:
307+
output = _as_list(self.output)
278308
new_output = output[len(previous_output) :]
279309
for item in new_output:
280310
yield item
281311
previous_output = output
282312
await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member
283313
await self.async_reload()
284314

285-
if self.status == "failed":
315+
if self.status in ("failed", "aborted"):
286316
raise ModelError(self)
287317

288-
output = self.output or []
318+
output = _as_list(self.output)
289319
new_output = output[len(previous_output) :]
290-
for output in new_output:
291-
yield output
320+
for item in new_output:
321+
yield item
292322

293323

294324
class Predictions(Namespace):

tests/test_prediction.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import respx
44

55
import replicate
6+
from replicate.prediction import Prediction
67

78

89
@pytest.mark.vcr("predictions-create.yaml")
@@ -540,3 +541,90 @@ async def test_predictions_stream(async_flag):
540541
# assert progress.current == 5
541542
# assert progress.total == 5
542543
# assert progress.percentage == 1.0
544+
545+
546+
# ---------------------------------------------------------------------------
547+
# Unit tests: output_iterator / async_output_iterator non-list guard
548+
# ---------------------------------------------------------------------------
549+
550+
551+
def _make_prediction(output, status="succeeded"):
552+
"""Build a minimal Prediction with a mock client (no HTTP calls needed)."""
553+
p = Prediction(
554+
id="p1",
555+
model="owner/model",
556+
version="v1",
557+
urls={
558+
"get": "https://api.replicate.com/v1/predictions/p1",
559+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
560+
},
561+
created_at="2024-01-01T00:00:00.000000Z",
562+
source="api",
563+
status=status,
564+
input={"prompt": "hello"},
565+
output=output,
566+
error=None,
567+
logs="",
568+
)
569+
return p
570+
571+
572+
def test_output_iterator_completed_with_list_output_yields_nothing():
573+
"""output_iterator yields only items arriving *after* the iterator starts.
574+
575+
When called on an already-completed prediction, all output tokens were
576+
present at start-time so no new items are yielded. This documents the
577+
intended "streaming" contract: call output_iterator while the prediction
578+
is still running, not after it has completed.
579+
"""
580+
p = _make_prediction(output=["token1", "token2", "token3"], status="succeeded")
581+
# The full list is the "previous_output" baseline, so nothing is yielded.
582+
assert list(p.output_iterator()) == []
583+
584+
585+
def test_output_iterator_none_output_yields_nothing():
586+
"""output_iterator must handle None output gracefully (empty sequence)."""
587+
p = _make_prediction(output=None)
588+
assert list(p.output_iterator()) == []
589+
590+
591+
def test_output_iterator_string_output_raises():
592+
"""output_iterator must raise ValueError when output is a plain string.
593+
594+
Before the fix, ``self.output or []`` returned the string intact, causing
595+
``yield from "hello world"`` to silently iterate over individual characters
596+
instead of raising a clear error.
597+
"""
598+
p = _make_prediction(output="hello world")
599+
with pytest.raises(ValueError, match="array output type"):
600+
list(p.output_iterator())
601+
602+
603+
def test_output_iterator_dict_output_raises():
604+
"""output_iterator must raise ValueError when output is a dict."""
605+
p = _make_prediction(output={"url": "https://example.com/file.png"})
606+
with pytest.raises(ValueError, match="array output type"):
607+
list(p.output_iterator())
608+
609+
610+
@pytest.mark.asyncio
611+
async def test_async_output_iterator_none_output_yields_nothing():
612+
"""async_output_iterator must handle None output gracefully."""
613+
p = _make_prediction(output=None)
614+
results = []
615+
async for item in p.async_output_iterator():
616+
results.append(item)
617+
assert results == []
618+
619+
620+
@pytest.mark.asyncio
621+
async def test_async_output_iterator_string_output_raises():
622+
"""async_output_iterator must raise ValueError for non-list outputs.
623+
624+
Before the fix, ``self.output or []`` returned the string intact,
625+
causing iteration over individual characters silently.
626+
"""
627+
p = _make_prediction(output="some string")
628+
with pytest.raises(ValueError, match="array output type"):
629+
async for _ in p.async_output_iterator():
630+
pass

0 commit comments

Comments
 (0)