Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion plugboard-schemas/plugboard_schemas/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from ._validator_registry import validator


_SYSTEM_STOP_EVENT = "system_stop"


def _build_component_graph(
connectors: dict[str, dict[str, _t.Any]],
) -> dict[str, set[str]]:
Expand Down Expand Up @@ -98,9 +101,11 @@ def validate_all_inputs_connected(
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
all_inputs = set(io.get("inputs", []))
input_events = set(io.get("input_events", []))
has_non_system_input_events = bool(input_events - {_SYSTEM_STOP_EVENT})
connected = connected_inputs.get(comp_name, set())
unconnected = all_inputs - connected
if unconnected:
if unconnected and not has_non_system_input_events:
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
return errors

Expand Down
14 changes: 12 additions & 2 deletions plugboard/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ async def _wrapper() -> None:
with self._job_id_ctx():
await self._set_status(Status.RUNNING, publish=not self._is_running)
await self._io_read_with_status_check()
# IO can close here once all producers for the component's event-only inputs have
# finished emitting. Return before rebinding inputs so the last event-populated
# field values are not replayed as if they were fresh inputs in another step.
if self.io.is_closed:
return
await self._handle_events()
self._bind_inputs()
if self._can_step:
Expand All @@ -365,6 +370,11 @@ async def _wrapper() -> None:
def _has_field_inputs(self) -> bool:
return len(self.io.inputs) > 0

@property
def _has_connected_field_inputs(self) -> bool:
"""Whether any declared field inputs are connected via input channels."""
return self.io.has_connected_field_inputs

@cached_property
def _has_event_inputs(self) -> bool:
input_events = set([evt.safe_type() for evt in self.io.input_events])
Expand Down Expand Up @@ -409,7 +419,7 @@ async def _io_read_with_status_check(self) -> None:
task.cancel()
for task in done:
exc = task.exception()
if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0:
if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs:
await self.io.close() # Call close for final wait and flush event buffer
elif exc is not None:
raise exc
Expand All @@ -422,7 +432,7 @@ async def _periodic_status_check(self) -> None:
# TODO : Eventually producer graph update will be event driven. For now,
# : the update is performed periodically, so it's called here along
# : with the status check.
if len(self.io.inputs) == 0:
if not self._has_connected_field_inputs:
await self._update_producer_graph()

async def _status_check(self) -> None:
Expand Down
11 changes: 6 additions & 5 deletions plugboard/component/io_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def is_closed(self) -> bool:
"""Returns `True` if the `IOController` is closed, `False` otherwise."""
return self._is_closed

@cached_property
def _has_field_inputs(self) -> bool:
@property
def has_connected_field_inputs(self) -> bool:
"""Returns whether any field inputs are connected via channels."""
return len(self._input_channels) > 0

@cached_property
Expand All @@ -96,7 +97,7 @@ def _has_event_inputs(self) -> bool:

@cached_property
def _has_inputs(self) -> bool:
return self._has_field_inputs or self._has_event_inputs
return self.has_connected_field_inputs or self._has_event_inputs

async def read(self, timeout: float | None = None) -> None:
"""Reads data and/or events from input channels.
Expand Down Expand Up @@ -139,7 +140,7 @@ async def read(self, timeout: float | None = None) -> None:

def _set_read_tasks(self) -> list[asyncio.Task]:
read_tasks: list[asyncio.Task] = []
if self._has_field_inputs:
if self.has_connected_field_inputs:
if _fields_read_task not in self._read_tasks:
read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task)
self._read_tasks[_fields_read_task] = read_fields_task
Expand Down Expand Up @@ -374,7 +375,7 @@ def _add_channel_for_event(

def _create_input_field_group_tasks(self) -> None:
"""Groups input field channels by field name and launches read tasks for group inputs."""
if not self._has_field_inputs:
if not self.has_connected_field_inputs:
return
field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list)
for key, chan in self._input_channels.items():
Expand Down
80 changes: 80 additions & 0 deletions tests/integration/test_process_with_components_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from plugboard.events import Event
from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError
from plugboard.library import FileWriter
from plugboard.process import LocalProcess, Process, RayProcess
from plugboard.schemas import ConnectorSpec, Status
from tests.conftest import ComponentTestHelper, zmq_connector_cls
Expand Down Expand Up @@ -459,6 +460,85 @@ async def test_event_driven_process_shutdown(
await process.destroy()


class MessageEventData(BaseModel):
"""Data for a message event."""

message: str


class MessageEvent(Event):
"""Event carrying a file-writer message."""

type: _t.ClassVar[str] = "message_event"
data: MessageEventData


class MessageEventGenerator(ComponentTestHelper):
"""Produces a fixed number of message events."""

io = IO(output_events=[MessageEvent])

def __init__(self, iters: int, *args: _t.Any, **kwargs: _t.Any) -> None:
super().__init__(*args, **kwargs)
self._iters = iters

async def init(self) -> None:
await super().init()
self._seq = iter(range(self._iters))

async def step(self) -> None:
try:
idx = next(self._seq)
except StopIteration:
await self.io.close()
else:
evt = MessageEvent(
source=self.name,
data=MessageEventData(message=f"Message {idx}"),
)
self.io.queue_event(evt)
await super().step()


class EventReaderFileWriter(FileWriter):
"""`FileWriter` variant that adds event handling instead of a connector for `message`."""

io = IO(input_events=[MessageEvent])

@MessageEvent.handler
async def handle_message(self, event: MessageEvent) -> None:
self.message = event.data.message


@pytest.mark.asyncio
async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None:
"""Test that field-input components can be reused in event-driven processes."""
output_path = tmp_path / "output_messages.csv"
components = [
MessageEventGenerator(iters=3, name="message_event_generator"),
EventReaderFileWriter(
path=output_path,
name="event_reader_file_writer",
field_names=["message"],
),
]
event_connectors = AsyncioConnector.builder().build_event_connectors(components)
process = LocalProcess(components=components, connectors=event_connectors)

await process.init()
await process.run()

assert process.status == Status.COMPLETED
assert output_path.read_text().splitlines() == [
"message",
"Message 0",
"Message 1",
"Message 2",
]

await process.destroy()


_SHORT_TIMEOUT = 0.1


Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_process_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,21 @@ def test_no_inputs_no_errors(self) -> None:
errors = validate_all_inputs_connected(pd)
assert errors == []

def test_missing_inputs_allowed_for_event_driven_component_reuse(self) -> None:
"""Unconnected inputs are allowed when non-system input events can populate them."""
pd = _make_process_dict(
components={
"producer": _make_component("producer", output_events=["message_event"]),
"writer": _make_component(
"writer",
inputs=["message"],
input_events=["system_stop", "message_event"],
),
},
)
errors = validate_all_inputs_connected(pd)
assert errors == []


# ---------------------------------------------------------------------------
# Tests for validate_input_events
Expand Down
Loading