Skip to content

Commit e8a3fd2

Browse files
Cooperative suspensions (#192)
* [WIP] Cooperative suspensions * Bump shared core
1 parent 2b2d297 commit e8a3fd2

8 files changed

Lines changed: 169 additions & 596 deletions

File tree

Cargo.lock

Lines changed: 11 additions & 506 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ doc = false
1414
[dependencies]
1515
pyo3 = { version = "0.25.1", features = ["extension-module"] }
1616
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
17-
restate-sdk-shared-core = { version = "=0.9.0", features = ["request_identity", "sha2_random_seed"] }
17+
restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", rev = "d8a42ecceab6e7874138b6316e128a09f2de76d1", features = ["request_identity", "sha2_random_seed"] }

python/restate/asyncio.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
from restate.exceptions import TerminalError
1717
from restate.context import RestateDurableFuture
1818
from restate.server_context import ServerDurableFuture, ServerInvocationContext
19+
from restate.vm import (
20+
AllCompletedUnresolvedFuture,
21+
FirstCompletedUnresolvedFuture,
22+
SingleUnresolvedFuture,
23+
)
1924

2025

2126
async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
@@ -24,9 +29,28 @@ async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFutu
2429
2530
Returns a list of all futures.
2631
"""
27-
async for _ in as_completed(*futures):
28-
pass
29-
return list(futures)
32+
context: ServerInvocationContext | None = None
33+
handles: List[int] = []
34+
futures_list = list(futures)
35+
36+
if not futures_list:
37+
return []
38+
for f in futures_list:
39+
if not isinstance(f, ServerDurableFuture):
40+
raise TerminalError("All futures must SDK created futures.")
41+
if context is None:
42+
context = f.context
43+
elif context is not f.context:
44+
raise TerminalError("All futures must be created by the same SDK context.")
45+
if not f.is_completed():
46+
handles.append(f.handle)
47+
48+
if handles:
49+
assert context is not None
50+
await context.create_poll_or_cancel_coroutine(
51+
AllCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles])
52+
)
53+
return futures_list
3054

3155

3256
async def select(**kws: RestateDurableFuture[Any]) -> List[Any]:
@@ -118,7 +142,9 @@ async def wait_completed(
118142
completed = []
119143
uncompleted = []
120144
assert context is not None
121-
await context.create_poll_or_cancel_coroutine(handles)
145+
await context.create_poll_or_cancel_coroutine(
146+
FirstCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles])
147+
)
122148

123149
for index, handle in enumerate(handles):
124150
future = futures[index]

python/restate/discovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,4 +429,4 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as: typing.Literal["
429429
protocol_mode = PROTOCOL_MODES[endpoint.protocol]
430430
else:
431431
protocol_mode = PROTOCOL_MODES[discovered_as]
432-
return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=6, services=services)
432+
return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=7, services=services)

python/restate/server_context.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@
6161
from restate.vm import (
6262
DoProgressAnyCompleted,
6363
DoProgressCancelSignalReceived,
64-
DoProgressReadFromInput,
64+
DoProgressWaitExternalProgress,
6565
DoProgressExecuteRun,
66-
DoWaitPendingRun,
66+
SingleUnresolvedFuture,
67+
UnresolvedFuture,
6768
)
6869

6970
logger = logging.getLogger(__name__)
@@ -193,7 +194,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
193194

194195
async def coro() -> str:
195196
if not context.vm.is_completed(handle):
196-
await context.create_poll_or_cancel_coroutine([handle])
197+
await context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
197198
invocation_id = await context.must_take_notification(handle)
198199
return typing.cast(str, invocation_id)
199200

@@ -235,7 +236,7 @@ def resolve(self, value: Any) -> Awaitable[None]:
235236

236237
async def await_point():
237238
if not self.server_context.vm.is_completed(handle):
238-
await self.server_context.create_poll_or_cancel_coroutine([handle])
239+
await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
239240
await self.server_context.must_take_notification(handle)
240241

241242
return ServerDurableFuture(self.server_context, handle, await_point)
@@ -248,7 +249,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
248249

249250
async def await_point():
250251
if not self.server_context.vm.is_completed(handle):
251-
await self.server_context.create_poll_or_cancel_coroutine([handle])
252+
await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
252253
await self.server_context.must_take_notification(handle)
253254

254255
return ServerDurableFuture(self.server_context, handle, await_point)
@@ -527,11 +528,11 @@ async def must_take_notification(self, handle):
527528
raise TerminalError(res.message, res.code)
528529
return res
529530

530-
async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
531-
"""Create a coroutine to poll the handle."""
531+
async def create_poll_or_cancel_coroutine(self, unresolved_future: UnresolvedFuture) -> None:
532+
"""Create a coroutine to poll the unresolved future."""
532533
while True:
533534
await self.take_and_send_output()
534-
do_progress_response = self.vm.do_progress(handles)
535+
do_progress_response = self.vm.do_progress(unresolved_future)
535536
if isinstance(do_progress_response, BaseException):
536537
logger.exception("Exception in do_progress", exc_info=do_progress_response)
537538
raise SdkInternalException() from do_progress_response
@@ -556,7 +557,7 @@ async def wrapper(f):
556557
task = asyncio.create_task(wrapper(fn))
557558
self.tasks.add(task)
558559
continue
559-
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
560+
if isinstance(do_progress_response, DoProgressWaitExternalProgress):
560561
chunk = await self.receive()
561562
if chunk.get("type") == "restate.run_completed":
562563
continue
@@ -574,7 +575,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N
574575

575576
async def fetch_result():
576577
if not self.vm.is_completed(handle):
577-
await self.create_poll_or_cancel_coroutine([handle])
578+
await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
578579
res = await self.must_take_notification(handle)
579580
if res is None or serde is None:
580581
return res
@@ -593,7 +594,7 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture:
593594

594595
async def transform():
595596
if not self.vm.is_completed(handle):
596-
await self.create_poll_or_cancel_coroutine([handle])
597+
await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
597598
await self.must_take_notification(handle)
598599

599600
return ServerDurableSleepFuture(self, handle, transform)
@@ -605,7 +606,7 @@ def create_call_future(
605606

606607
async def inv_id_factory():
607608
if not self.vm.is_completed(invocation_id_handle):
608-
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
609+
await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(invocation_id_handle))
609610
return await self.must_take_notification(invocation_id_handle)
610611

611612
return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)

python/restate/vm.py

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
# pylint: disable=E1101,R0917
1616
# pylint: disable=too-many-arguments
1717
# pylint: disable=too-few-public-methods
18-
from typing import Optional
18+
from typing import List, Optional, Union
1919
from datetime import timedelta
2020

21-
from dataclasses import dataclass
21+
from dataclasses import dataclass, field
2222
import typing
2323
from restate._internal import (
2424
PyVM,
@@ -30,10 +30,10 @@
3030
PyStateKeys,
3131
PyExponentialRetryConfig,
3232
PyDoProgressAnyCompleted,
33-
PyDoProgressReadFromInput,
33+
PyDoProgressWaitExternalProgress,
3434
PyDoProgressExecuteRun,
35-
PyDoWaitForPendingRun,
3635
PyDoProgressCancelSignalReceived,
36+
PyUnresolvedFuture,
3737
CANCEL_NOTIFICATION_HANDLE,
3838
) # pylint: disable=import-error,no-name-in-module,line-too-long
3939

@@ -105,9 +105,10 @@ class DoProgressAnyCompleted:
105105
"""
106106

107107

108-
class DoProgressReadFromInput:
108+
class DoProgressWaitExternalProgress:
109109
"""
110-
Represents a notification that the input needs to be read.
110+
Represents a notification that external progress is required
111+
(either new input from the server or a pending run proposal).
111112
"""
112113

113114

@@ -128,26 +129,57 @@ class DoProgressCancelSignalReceived:
128129
"""
129130

130131

131-
class DoWaitPendingRun:
132-
"""
133-
Represents a notification that a run is pending
134-
"""
135-
136-
137132
DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted()
138-
DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput()
133+
DO_PROGRESS_WAIT_EXTERNAL_PROGRESS = DoProgressWaitExternalProgress()
139134
DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived()
140-
DO_WAIT_PENDING_RUN = DoWaitPendingRun()
141135

142136
DoProgressResult = typing.Union[
143137
DoProgressAnyCompleted,
144-
DoProgressReadFromInput,
138+
DoProgressWaitExternalProgress,
145139
DoProgressExecuteRun,
146140
DoProgressCancelSignalReceived,
147-
DoWaitPendingRun,
148141
]
149142

150143

144+
@dataclass(frozen=True)
145+
class SingleUnresolvedFuture:
146+
"""A single leaf handle."""
147+
148+
handle: int
149+
150+
151+
@dataclass(frozen=True)
152+
class FirstCompletedUnresolvedFuture:
153+
"""first child to complete (success or failure) wins."""
154+
155+
children: List["UnresolvedFuture"] = field(default_factory=list)
156+
157+
158+
@dataclass(frozen=True)
159+
class AllCompletedUnresolvedFuture:
160+
"""wait for all children to complete."""
161+
162+
children: List["UnresolvedFuture"] = field(default_factory=list)
163+
164+
165+
UnresolvedFuture = Union[
166+
SingleUnresolvedFuture,
167+
FirstCompletedUnresolvedFuture,
168+
AllCompletedUnresolvedFuture,
169+
]
170+
171+
172+
def _unresolved_future_to_pyo3(uf: UnresolvedFuture) -> PyUnresolvedFuture:
173+
"""Recursively convert a Python-side UnresolvedFuture dataclass to its PyO3 pyclass."""
174+
if isinstance(uf, SingleUnresolvedFuture):
175+
return PyUnresolvedFuture.single(uf.handle)
176+
if isinstance(uf, FirstCompletedUnresolvedFuture):
177+
return PyUnresolvedFuture.first_completed([_unresolved_future_to_pyo3(c) for c in uf.children])
178+
if isinstance(uf, AllCompletedUnresolvedFuture):
179+
return PyUnresolvedFuture.all_completed([_unresolved_future_to_pyo3(c) for c in uf.children])
180+
raise TypeError(f"Unknown UnresolvedFuture variant: {type(uf).__name__}")
181+
182+
151183
# pylint: disable=too-many-public-methods
152184
class VMWrapper:
153185
"""
@@ -195,24 +227,22 @@ def is_completed(self, handle: int) -> bool:
195227
return self.vm.is_completed(handle)
196228

197229
# pylint: disable=R0911
198-
def do_progress(self, handles: list[int]) -> typing.Union[DoProgressResult, Exception, Suspended]:
230+
def do_progress(self, unresolved_future: UnresolvedFuture) -> typing.Union[DoProgressResult, Exception, Suspended]:
199231
"""Do progress with notifications."""
200232
try:
201-
result = self.vm.do_progress(handles)
233+
result = self.vm.do_progress(_unresolved_future_to_pyo3(unresolved_future))
202234
except VMException as e:
203235
return e
204236
if isinstance(result, PySuspended):
205237
return SUSPENDED
206238
if isinstance(result, PyDoProgressAnyCompleted):
207239
return DO_PROGRESS_ANY_COMPLETED
208-
if isinstance(result, PyDoProgressReadFromInput):
209-
return DO_PROGRESS_READ_FROM_INPUT
240+
if isinstance(result, PyDoProgressWaitExternalProgress):
241+
return DO_PROGRESS_WAIT_EXTERNAL_PROGRESS
210242
if isinstance(result, PyDoProgressExecuteRun):
211243
return DoProgressExecuteRun(result.handle)
212244
if isinstance(result, PyDoProgressCancelSignalReceived):
213245
return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
214-
if isinstance(result, PyDoWaitForPendingRun):
215-
return DO_WAIT_PENDING_RUN
216246
return ValueError(f"Unknown progress type: {result}")
217247

218248
def take_notification(self, handle: int) -> typing.Union[NotificationType, Exception, Suspended]:
@@ -343,9 +373,8 @@ def sys_call(
343373
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
344374
):
345375
"""Call a service"""
346-
if headers:
347-
headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
348-
return self.vm.sys_call(service, handler, parameter, key, idempotency_key, headers)
376+
py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None
377+
return self.vm.sys_call(service, handler, parameter, key, idempotency_key, py_headers)
349378

350379
# pylint: disable=too-many-arguments
351380
def sys_send(
@@ -362,9 +391,8 @@ def sys_send(
362391
send an invocation to a service, and return the handle
363392
to the promise that will resolve with the invocation id
364393
"""
365-
if headers:
366-
headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
367-
return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, headers)
394+
py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None
395+
return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, py_headers)
368396

369397
def sys_run(self, name: str) -> int:
370398
"""
@@ -391,17 +419,11 @@ def sys_reject_awakeable(self, name: str, failure: Failure):
391419
py_failure = PyFailure(failure.code, failure.message)
392420
self.vm.sys_complete_awakeable_failure(name, py_failure)
393421

394-
def propose_run_completion_success(self, handle: int, output: bytes) -> int:
422+
def propose_run_completion_success(self, handle: int, output: bytes) -> None:
395423
"""
396-
Exit a side effect
397-
398-
Args:
399-
output: The output of the side effect.
400-
401-
Returns:
402-
handle
424+
Exit a side effect with a success value.
403425
"""
404-
return self.vm.propose_run_completion_success(handle, output)
426+
self.vm.propose_run_completion_success(handle, output)
405427

406428
def sys_get_promise(self, name: str) -> int:
407429
"""Returns the promise handle"""
@@ -420,16 +442,12 @@ def sys_complete_promise_failure(self, name: str, failure: Failure) -> int:
420442
res = PyFailure(failure.code, failure.message)
421443
return self.vm.sys_complete_promise_failure(name, res)
422444

423-
def propose_run_completion_failure(self, handle: int, output: Failure) -> int:
445+
def propose_run_completion_failure(self, handle: int, output: Failure) -> None:
424446
"""
425-
Exit a side effect
426-
427-
Args:
428-
name: The name of the side effect.
429-
output: The output of the side effect.
447+
Exit a side effect with a terminal failure.
430448
"""
431449
res = PyFailure(output.code, output.message)
432-
return self.vm.propose_run_completion_failure(handle, res)
450+
self.vm.propose_run_completion_failure(handle, res)
433451

434452
# pylint: disable=line-too-long
435453
def propose_run_completion_transient(

0 commit comments

Comments
 (0)