1515# pylint: disable=E1101,R0917
1616# pylint: disable=too-many-arguments
1717# pylint: disable=too-few-public-methods
18- from typing import Optional
18+ from typing import List , Optional , Union
1919from datetime import timedelta
2020
21- from dataclasses import dataclass
21+ from dataclasses import dataclass , field
2222import typing
2323from restate ._internal import (
2424 PyVM ,
3030 PyStateKeys ,
3131 PyExponentialRetryConfig ,
3232 PyDoProgressAnyCompleted ,
33- PyDoProgressReadFromInput ,
33+ PyDoProgressWaitExternalProgress ,
3434 PyDoProgressExecuteRun ,
35- PyDoWaitForPendingRun ,
3635 PyDoProgressCancelSignalReceived ,
36+ PyUnresolvedFuture ,
3737 CANCEL_NOTIFICATION_HANDLE ,
3838) # pylint: disable=import-error,no-name-in-module,line-too-long
3939
@@ -105,9 +105,10 @@ class DoProgressAnyCompleted:
105105 """
106106
107107
108- class DoProgressReadFromInput :
108+ class DoProgressWaitExternalProgress :
109109 """
110- Represents a notification that the input needs to be read.
110+ Represents a notification that external progress is required
111+ (either new input from the server or a pending run proposal).
111112 """
112113
113114
@@ -128,26 +129,57 @@ class DoProgressCancelSignalReceived:
128129 """
129130
130131
131- class DoWaitPendingRun :
132- """
133- Represents a notification that a run is pending
134- """
135-
136-
137132DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted ()
138- DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput ()
133+ DO_PROGRESS_WAIT_EXTERNAL_PROGRESS = DoProgressWaitExternalProgress ()
139134DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived ()
140- DO_WAIT_PENDING_RUN = DoWaitPendingRun ()
141135
142136DoProgressResult = typing .Union [
143137 DoProgressAnyCompleted ,
144- DoProgressReadFromInput ,
138+ DoProgressWaitExternalProgress ,
145139 DoProgressExecuteRun ,
146140 DoProgressCancelSignalReceived ,
147- DoWaitPendingRun ,
148141]
149142
150143
144+ @dataclass (frozen = True )
145+ class SingleUnresolvedFuture :
146+ """A single leaf handle."""
147+
148+ handle : int
149+
150+
151+ @dataclass (frozen = True )
152+ class FirstCompletedUnresolvedFuture :
153+ """first child to complete (success or failure) wins."""
154+
155+ children : List ["UnresolvedFuture" ] = field (default_factory = list )
156+
157+
158+ @dataclass (frozen = True )
159+ class AllCompletedUnresolvedFuture :
160+ """wait for all children to complete."""
161+
162+ children : List ["UnresolvedFuture" ] = field (default_factory = list )
163+
164+
165+ UnresolvedFuture = Union [
166+ SingleUnresolvedFuture ,
167+ FirstCompletedUnresolvedFuture ,
168+ AllCompletedUnresolvedFuture ,
169+ ]
170+
171+
172+ def _unresolved_future_to_pyo3 (uf : UnresolvedFuture ) -> PyUnresolvedFuture :
173+ """Recursively convert a Python-side UnresolvedFuture dataclass to its PyO3 pyclass."""
174+ if isinstance (uf , SingleUnresolvedFuture ):
175+ return PyUnresolvedFuture .single (uf .handle )
176+ if isinstance (uf , FirstCompletedUnresolvedFuture ):
177+ return PyUnresolvedFuture .first_completed ([_unresolved_future_to_pyo3 (c ) for c in uf .children ])
178+ if isinstance (uf , AllCompletedUnresolvedFuture ):
179+ return PyUnresolvedFuture .all_completed ([_unresolved_future_to_pyo3 (c ) for c in uf .children ])
180+ raise TypeError (f"Unknown UnresolvedFuture variant: { type (uf ).__name__ } " )
181+
182+
151183# pylint: disable=too-many-public-methods
152184class VMWrapper :
153185 """
@@ -195,24 +227,22 @@ def is_completed(self, handle: int) -> bool:
195227 return self .vm .is_completed (handle )
196228
197229 # pylint: disable=R0911
198- def do_progress (self , handles : list [ int ] ) -> typing .Union [DoProgressResult , Exception , Suspended ]:
230+ def do_progress (self , unresolved_future : UnresolvedFuture ) -> typing .Union [DoProgressResult , Exception , Suspended ]:
199231 """Do progress with notifications."""
200232 try :
201- result = self .vm .do_progress (handles )
233+ result = self .vm .do_progress (_unresolved_future_to_pyo3 ( unresolved_future ) )
202234 except VMException as e :
203235 return e
204236 if isinstance (result , PySuspended ):
205237 return SUSPENDED
206238 if isinstance (result , PyDoProgressAnyCompleted ):
207239 return DO_PROGRESS_ANY_COMPLETED
208- if isinstance (result , PyDoProgressReadFromInput ):
209- return DO_PROGRESS_READ_FROM_INPUT
240+ if isinstance (result , PyDoProgressWaitExternalProgress ):
241+ return DO_PROGRESS_WAIT_EXTERNAL_PROGRESS
210242 if isinstance (result , PyDoProgressExecuteRun ):
211243 return DoProgressExecuteRun (result .handle )
212244 if isinstance (result , PyDoProgressCancelSignalReceived ):
213245 return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
214- if isinstance (result , PyDoWaitForPendingRun ):
215- return DO_WAIT_PENDING_RUN
216246 return ValueError (f"Unknown progress type: { result } " )
217247
218248 def take_notification (self , handle : int ) -> typing .Union [NotificationType , Exception , Suspended ]:
@@ -343,9 +373,8 @@ def sys_call(
343373 headers : typing .Optional [typing .List [typing .Tuple [str , str ]]] = None ,
344374 ):
345375 """Call a service"""
346- if headers :
347- headers = [PyHeader (key = h [0 ], value = h [1 ]) for h in headers ]
348- return self .vm .sys_call (service , handler , parameter , key , idempotency_key , headers )
376+ py_headers = [PyHeader (key = h [0 ], value = h [1 ]) for h in headers ] if headers else None
377+ return self .vm .sys_call (service , handler , parameter , key , idempotency_key , py_headers )
349378
350379 # pylint: disable=too-many-arguments
351380 def sys_send (
@@ -362,9 +391,8 @@ def sys_send(
362391 send an invocation to a service, and return the handle
363392 to the promise that will resolve with the invocation id
364393 """
365- if headers :
366- headers = [PyHeader (key = h [0 ], value = h [1 ]) for h in headers ]
367- return self .vm .sys_send (service , handler , parameter , key , delay , idempotency_key , headers )
394+ py_headers = [PyHeader (key = h [0 ], value = h [1 ]) for h in headers ] if headers else None
395+ return self .vm .sys_send (service , handler , parameter , key , delay , idempotency_key , py_headers )
368396
369397 def sys_run (self , name : str ) -> int :
370398 """
@@ -391,17 +419,11 @@ def sys_reject_awakeable(self, name: str, failure: Failure):
391419 py_failure = PyFailure (failure .code , failure .message )
392420 self .vm .sys_complete_awakeable_failure (name , py_failure )
393421
394- def propose_run_completion_success (self , handle : int , output : bytes ) -> int :
422+ def propose_run_completion_success (self , handle : int , output : bytes ) -> None :
395423 """
396- Exit a side effect
397-
398- Args:
399- output: The output of the side effect.
400-
401- Returns:
402- handle
424+ Exit a side effect with a success value.
403425 """
404- return self .vm .propose_run_completion_success (handle , output )
426+ self .vm .propose_run_completion_success (handle , output )
405427
406428 def sys_get_promise (self , name : str ) -> int :
407429 """Returns the promise handle"""
@@ -420,16 +442,12 @@ def sys_complete_promise_failure(self, name: str, failure: Failure) -> int:
420442 res = PyFailure (failure .code , failure .message )
421443 return self .vm .sys_complete_promise_failure (name , res )
422444
423- def propose_run_completion_failure (self , handle : int , output : Failure ) -> int :
445+ def propose_run_completion_failure (self , handle : int , output : Failure ) -> None :
424446 """
425- Exit a side effect
426-
427- Args:
428- name: The name of the side effect.
429- output: The output of the side effect.
447+ Exit a side effect with a terminal failure.
430448 """
431449 res = PyFailure (output .code , output .message )
432- return self .vm .propose_run_completion_failure (handle , res )
450+ self .vm .propose_run_completion_failure (handle , res )
433451
434452 # pylint: disable=line-too-long
435453 def propose_run_completion_transient (
0 commit comments