From 4288aabec3188888b58d025b4f36c6da35b3f0b6 Mon Sep 17 00:00:00 2001 From: Dmitrii Amelin Date: Mon, 13 Apr 2026 01:28:05 +0200 Subject: [PATCH] Decorator Manager; see #795 --- README.md | 4 +- custom_components/pyscript/__init__.py | 9 +- custom_components/pyscript/config_flow.py | 11 +- custom_components/pyscript/const.py | 1 + custom_components/pyscript/decorator.py | 283 +++++++ custom_components/pyscript/decorator_abc.py | 283 +++++++ .../pyscript/decorators/__init__.py | 21 + custom_components/pyscript/decorators/base.py | 59 ++ .../pyscript/decorators/event.py | 59 ++ custom_components/pyscript/decorators/mqtt.py | 69 ++ .../pyscript/decorators/service.py | 140 ++++ .../pyscript/decorators/state.py | 321 ++++++++ custom_components/pyscript/decorators/task.py | 40 + .../pyscript/decorators/timing.py | 163 ++++ .../pyscript/decorators/webhook.py | 84 ++ custom_components/pyscript/eval.py | 15 +- custom_components/pyscript/global_ctx.py | 39 +- custom_components/pyscript/strings.json | 6 +- .../pyscript/translations/de.json | 6 +- .../pyscript/translations/en.json | 6 +- .../pyscript/translations/sk.json | 6 +- .../pyscript/translations/tr.json | 6 +- docs/configuration.rst | 12 +- docs/reference.rst | 17 +- tests/test_apps_modules.py | 3 + tests/test_config_flow.py | 99 ++- tests/test_decorator_errors.py | 5 +- tests/test_decorator_manager.py | 750 ++++++++++++++++++ tests/test_init.py | 5 +- tests/test_unit_eval.py | 2 + 30 files changed, 2480 insertions(+), 44 deletions(-) create mode 100644 custom_components/pyscript/decorator.py create mode 100644 custom_components/pyscript/decorator_abc.py create mode 100644 custom_components/pyscript/decorators/__init__.py create mode 100644 custom_components/pyscript/decorators/base.py create mode 100644 custom_components/pyscript/decorators/event.py create mode 100644 custom_components/pyscript/decorators/mqtt.py create mode 100644 custom_components/pyscript/decorators/service.py create mode 100644 custom_components/pyscript/decorators/state.py create mode 100644 custom_components/pyscript/decorators/task.py create mode 100644 custom_components/pyscript/decorators/timing.py create mode 100644 custom_components/pyscript/decorators/webhook.py create mode 100644 tests/test_decorator_manager.py diff --git a/README.md b/README.md index a562a65e..5f01d952 100644 --- a/README.md +++ b/README.md @@ -62,12 +62,14 @@ this [README](https://github.com/craigbarratt/hass-pyscript-jupyter/blob/master/ ## Configuration -* Go to the Integrations menu in the Home Assistant Configuration UI and add `Pyscript Python scripting` from there. Alternatively, add `pyscript:` to `/configuration.yaml`; pyscript has two optional configuration parameters that allow any python package to be imported if set and to expose `hass` as a variable; both default to `false`: +* Go to the Integrations menu in the Home Assistant Configuration UI and add `Pyscript Python scripting` from there. Alternatively, add `pyscript:` to `/configuration.yaml`; pyscript has three optional configuration parameters that allow any python package to be imported if set, expose `hass` as a variable, and temporarily switch back to the legacy decorator subsystem; all three default to `false`: ```yaml pyscript: allow_all_imports: true hass_is_global: true + legacy_decorators: true ``` + Starting with version `2.0.0`, pyscript uses the new decorator subsystem by default. If you find a problem in the new implementation, you can temporarily set `legacy_decorators: true` to switch back to the legacy one. If you do, please also file a bug report in [GitHub Issues](https://github.com/custom-components/pyscript/issues) so the new subsystem can be fixed. * Add files with a suffix of `.py` in the folder `/pyscript`. * Restart HASS. * Whenever you change a script file, make a `reload` service call to `pyscript`. diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index 396b26cd..34a6488c 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -33,6 +33,7 @@ from .const import ( CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, + CONF_LEGACY_DECORATORS, CONFIG_ENTRY, CONFIG_ENTRY_OLD, DOMAIN, @@ -44,6 +45,7 @@ UNSUB_LISTENERS, WATCHDOG_TASK, ) +from .decorator import DecoratorRegistry from .eval import AstEval from .event import Event from .function import Function @@ -62,6 +64,7 @@ { vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): cv.boolean, vol.Optional(CONF_HASS_IS_GLOBAL, default=False): cv.boolean, + vol.Optional(CONF_LEGACY_DECORATORS, default=False): cv.boolean, }, extra=vol.ALLOW_EXTRA, ) @@ -114,14 +117,15 @@ async def update_yaml_config(hass: HomeAssistant, config_entry: ConfigEntry) -> # since they affect all scripts # config_save = { - param: config_entry.data.get(param, False) for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS] + param: config_entry.data.get(param, False) + for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS, CONF_LEGACY_DECORATORS] } if DOMAIN not in hass.data: hass.data.setdefault(DOMAIN, {}) if CONFIG_ENTRY_OLD in hass.data[DOMAIN]: old_entry = hass.data[DOMAIN][CONFIG_ENTRY_OLD] hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save - for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS]: + for param in [CONF_HASS_IS_GLOBAL, CONF_ALLOW_ALL_IMPORTS, CONF_LEGACY_DECORATORS]: if old_entry.get(param, False) != config_entry.data.get(param, False): return True hass.data[DOMAIN][CONFIG_ENTRY_OLD] = config_save @@ -272,6 +276,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b Webhook.init(hass) State.register_functions() GlobalContextMgr.init() + DecoratorRegistry.init(hass, config_entry) pyscript_folder = hass.config.path(FOLDER) if not await hass.async_add_executor_job(os.path.isdir, pyscript_folder): diff --git a/custom_components/pyscript/config_flow.py b/custom_components/pyscript/config_flow.py index 1b8bc754..1330662a 100644 --- a/custom_components/pyscript/config_flow.py +++ b/custom_components/pyscript/config_flow.py @@ -9,14 +9,21 @@ from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.core import callback -from .const import CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, CONF_INSTALLED_PACKAGES, DOMAIN +from .const import ( + CONF_ALLOW_ALL_IMPORTS, + CONF_HASS_IS_GLOBAL, + CONF_INSTALLED_PACKAGES, + CONF_LEGACY_DECORATORS, + DOMAIN, +) -CONF_BOOL_ALL = {CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL} +CONF_BOOL_ALL = (CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, CONF_LEGACY_DECORATORS) PYSCRIPT_SCHEMA = vol.Schema( { vol.Optional(CONF_ALLOW_ALL_IMPORTS, default=False): bool, vol.Optional(CONF_HASS_IS_GLOBAL, default=False): bool, + vol.Optional(CONF_LEGACY_DECORATORS, default=False): bool, }, extra=vol.ALLOW_EXTRA, ) diff --git a/custom_components/pyscript/const.py b/custom_components/pyscript/const.py index d4d89814..4cefdfd1 100644 --- a/custom_components/pyscript/const.py +++ b/custom_components/pyscript/const.py @@ -17,6 +17,7 @@ CONF_ALLOW_ALL_IMPORTS = "allow_all_imports" CONF_HASS_IS_GLOBAL = "hass_is_global" CONF_INSTALLED_PACKAGES = "_installed_packages" +CONF_LEGACY_DECORATORS = "legacy_decorators" SERVICE_JUPYTER_KERNEL_START = "jupyter_kernel_start" SERVICE_GENERATE_STUBS = "generate_stubs" diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py new file mode 100644 index 00000000..44518895 --- /dev/null +++ b/custom_components/pyscript/decorator.py @@ -0,0 +1,283 @@ +"""Decorator registry and manager logic for pyscript decorators.""" + +from __future__ import annotations + +import ast +import asyncio +import logging +import os +from typing import Any, ClassVar +import weakref + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import Context, HomeAssistant + +from .const import CONF_LEGACY_DECORATORS +from .decorator_abc import ( + CallHandlerDecorator, + CallResultHandlerDecorator, + Decorator, + DecoratorManager, + DecoratorManagerStatus, + DispatchData, + TriggerDecorator, + TriggerHandlerDecorator, +) +from .eval import AstEval, EvalFunc, EvalFuncVar +from .function import Function +from .state import State + +_LOGGER = logging.getLogger(__name__) + + +class DecoratorRegistry: + """Decorator registry.""" + + _decorators: dict[str, type[Decorator]] # decorator name to class + hass: ClassVar[HomeAssistant] + + @classmethod + def init(cls, hass: HomeAssistant, config_entry: ConfigEntry = None) -> None: + """Initialize the decorator registry.""" + cls.hass = hass + cls._decorators = {} + disabled = False + if config_entry is not None and config_entry.data.get(CONF_LEGACY_DECORATORS, False): + disabled = True + elif "PYTEST_CURRENT_TEST" in os.environ and "NODM" in os.environ: + disabled = True + + if disabled: + _LOGGER.warning("Using legacy decorators") + return + + DecoratorManager.hass = hass + + Function.register_ast({"task.wait_until": DecoratorRegistry.wait_until_factory}) + + from .decorators import DECORATORS # pylint: disable=import-outside-toplevel + + for dec_type in DECORATORS: + cls.register(dec_type) + + @classmethod + def register(cls, dec_type: type[Decorator]) -> None: + """Register a decorator.""" + if not dec_type.name: + raise TypeError(f"Decorator name is required {dec_type}") + + _LOGGER.debug("Registering decorator @%s %s", dec_type.name, dec_type) + if dec_type.name in cls._decorators: + _LOGGER.warning( + "Overriding decorator: %s %s with %s", + dec_type.name, + cls._decorators[dec_type.name], + dec_type, + ) + cls._decorators[dec_type.name] = dec_type + + @classmethod + async def get_decorator_by_expr(cls, ast_ctx: AstEval, dec_expr: ast.expr) -> Decorator | None: + """Return decorator instance from an AST decorator expression.""" + dec_name = None + has_args = False + + if isinstance(dec_expr, ast.Name): # decorator without () + dec_name = dec_expr.id + elif isinstance(dec_expr, ast.Call) and isinstance(dec_expr.func, ast.Name): + dec_name = dec_expr.func.id + has_args = True + + if know_decorator := cls._decorators.get(dec_name): + if has_args: + args = await ast_ctx.eval_elt_list(dec_expr.args) + kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec_expr.keywords} + else: + args = [] + kwargs = {} + + decorator = know_decorator(args, kwargs) + return decorator + + return None + + @classmethod + async def wait_until(cls, ast_ctx: AstEval, *_arg: Any, **kwargs: Any) -> Any: + """Build a temporary decorator manager that waits until one of trigger decorators fires.""" + func_args = set(kwargs.keys()) + if len(func_args) == 0: + return {"trigger_type": "none"} + + found_args = set() + dm = WaitUntilDecoratorManager(ast_ctx, **kwargs) + + found_args.add("timeout") + found_args.add("__test_handshake__") + + for dec_name, dec_class in cls._decorators.items(): + if not issubclass(dec_class, TriggerDecorator): + continue + if dec_name not in func_args: + continue + + dec_args = kwargs[dec_name] + if not isinstance(dec_args, list): + dec_args = [dec_args] + found_args.add(dec_name) + + dec_kwargs = {} + func_args.remove(dec_name) + kwargs_schema_keys = dec_class.kwargs_schema.schema.keys() + for key in kwargs_schema_keys: + if key in kwargs: + dec_kwargs[key] = kwargs[key] + found_args.add(key) + dec = dec_class(dec_args, dec_kwargs) + dm.add(dec) + + unknown_args = set(kwargs.keys()).difference(found_args) + if unknown_args: + raise ValueError(f"Unknown arguments: {unknown_args}") + await dm.validate() + + # state_trigger sets __test_handshake__ after the initial checks. + # In some cases, it returns a value before __test_handshake__ is set. + if "state_trigger" not in kwargs: + if test_handshake := kwargs.get("__test_handshake__"): + # + # used for testing to avoid race conditions + # we use this as a handshake that we are about to + # listen to the queue + # + State.set(test_handshake[0], test_handshake[1]) + await dm.start() + + ret = await dm.wait_until() + + return ret + + @classmethod + def wait_until_factory(cls, ast_ctx): + """Return wrapper to call to astFunction with the ast context.""" + + async def wait_until_call(*arg, **kw): + return await cls.wait_until(ast_ctx, *arg, **kw) + + return wait_until_call + + +class WaitUntilDecoratorManager(DecoratorManager): + """Decorator manager for task.wait_until.""" + + def __init__(self, ast_ctx: AstEval, **kwargs: dict[str, Any]) -> None: + """Initialize the task.wait_until decorator manager.""" + super().__init__(ast_ctx, ast_ctx.name) + self.kwargs = kwargs + self._future: asyncio.Future[DispatchData] = self.hass.loop.create_future() + self.timeout_decorator = None + if timeout := kwargs.get("timeout"): + to_dec = DecoratorRegistry._decorators.get("time_trigger") + self.timeout_decorator = to_dec([f"once(now + {timeout}s)"], {}) + self.add(self.timeout_decorator) + + async def dispatch(self, data: DispatchData) -> None: + """Resolve the waiting future on the first incoming dispatch.""" + _LOGGER.debug("task.wait_until dispatch: %s", data) + if self._future.done(): + _LOGGER.debug("task.wait_until future already completed: %s", self._future.exception()) + # ignore another calls + return + await self.stop() + self._future.set_result(data) + + async def handle_exception(self, exc: Exception) -> None: + """Propagate an evaluation exception to the waiting caller.""" + if self._future.done(): + _LOGGER.debug("task.wait_until future already completed: %s", self._future.exception()) + return + await self.stop() + self._future.set_exception(exc) + + async def wait_until(self) -> dict[str, Any]: + """Wait for dispatch and normalize the return payload.""" + data = await self._future + if data.trigger == self.timeout_decorator: + ret = {"trigger_type": "timeout"} + else: + ret = data.func_args + _LOGGER.debug("task.wait_until finish: %s", ret) + return ret + + +class FunctionDecoratorManager(DecoratorManager): + """Maintain and validate a set of decorators applied to a function.""" + + def __init__(self, ast_ctx: AstEval, eval_func_var: EvalFuncVar) -> None: + """Initialize the function decorator manager.""" + super().__init__(ast_ctx, f"{ast_ctx.get_global_ctx_name()}.{eval_func_var.get_name()}") + self.eval_func: EvalFunc = eval_func_var.func + + self.logger = self.eval_func.logger + + def on_func_var_deleted(): + if self.status is DecoratorManagerStatus.RUNNING: + self.hass.async_create_task(self.stop()) + + weakref.finalize(eval_func_var, on_func_var_deleted) + + async def _call(self, data: DispatchData) -> None: + handlers = self.get_decorators(CallHandlerDecorator) + result_handlers = self.get_decorators(CallResultHandlerDecorator) + + for handler_dec in handlers: + if await handler_dec.handle_call(data) is False: + self.logger.debug("Calling canceled by %s", handler_dec) + # notify handlers with "None" + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, None) + return + # Fire an event indicating that pyscript is running + # Note: the event must have an entity_id for logbook to work correctly. + ev_name = self.name.replace(".", "_") + ev_entity_id = f"pyscript.{ev_name}" + + event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args} + self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context) + # Store HASS Context for this Task + Function.store_hass_context(data.hass_context) + + result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, result) + + async def dispatch(self, data: DispatchData) -> None: + """Handle a trigger dispatch: run guards, create a context, and invoke the function.""" + _LOGGER.debug("Dispatching for %s: %s", self.name, data) + + decorators = self.get_decorators(TriggerHandlerDecorator) + for dec in decorators: + if await dec.handle_dispatch(data) is False: + self.logger.debug("Trigger not active due to %s", dec) + return + + action_ast_ctx = AstEval( + f"{self.eval_func.global_ctx_name}.{self.eval_func.name}", self.eval_func.global_ctx + ) + Function.install_ast_funcs(action_ast_ctx) + data.call_ast_ctx = action_ast_ctx + + # Create new HASS Context with incoming as parent + if "context" in data.func_args and isinstance(data.func_args["context"], Context): + data.hass_context = Context(parent_id=data.func_args["context"].id) + else: + data.hass_context = Context() + + self.logger.debug( + "trigger %s got %s trigger, running action (kwargs = %s)", + self.name, + data.trigger, + data.func_args, + ) + + task = Function.create_task(self._call(data), ast_ctx=action_ast_ctx) + Function.task_done_callback_ctx(task, action_ast_ctx) diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py new file mode 100644 index 00000000..e978e2c7 --- /dev/null +++ b/custom_components/pyscript/decorator_abc.py @@ -0,0 +1,283 @@ +"""Base abstractions for pyscript decorators and decorator managers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import StrEnum +import logging +from typing import Any, ClassVar, final + +import voluptuous as vol + +from homeassistant.core import Context, HomeAssistant + +from . import trigger +from .eval import AstEval + +_LOGGER = logging.getLogger(__name__) + + +def dt_now(): + """Return current time.""" + # For test compatibility. The tests patch this function + return trigger.dt_now() + + +class DecoratorManagerStatus(StrEnum): + """Status of a decorator manager.""" + + INIT = "init" # initial status when created + NO_DECORATORS = "no_decorators" # no decorators found + VALIDATED = "validated" + INVALID = "invalid" + RUNNING = "running" + STOPPED = "stopped" + + +@dataclass() +class DispatchData: + """Data for a dispatch event.""" + + func_args: dict[str, Any] + trigger: TriggerDecorator | None = field(default=None, kw_only=True) + trigger_context: dict[str, Any] = field(default_factory=dict, kw_only=True) + + call_ast_ctx: AstEval | None = field(default=None, kw_only=True) + hass_context: Context | None = field(default=None, kw_only=True) + + +class Decorator(ABC): + """Generic decorator abstraction.""" + + # Subclasses should override. + name: ClassVar[str] = "" + # without args by default + args_schema: ClassVar[vol.Schema] = vol.Schema([], extra=vol.PREVENT_EXTRA) + # without kwargs by default + kwargs_schema: ClassVar[vol.Schema] = vol.Schema({}, extra=vol.PREVENT_EXTRA) + + # instance attributes + dm: DecoratorManager + raw_args: list[Any] + raw_kwargs: dict[str, Any] + + args: list[Any] + kwargs: dict[str, Any] + + @final + def __init__(self, raw_args: list[Any], raw_kwargs: dict[str, Any]) -> None: + """Initialize the decorator definition.""" + self.raw_args = raw_args + self.raw_kwargs = raw_kwargs + + async def validate(self) -> None: + """Validate the arguments.""" + + _LOGGER.debug("Validating %s", self.name) + + try: + self.args = self.args_schema(self.raw_args) + self.kwargs = self.kwargs_schema(self.raw_kwargs) + + except vol.Invalid as err: + # Keep this wording for transition compatibility. Once the legacy + # subsystem is removed, update the message and related tests. + if len(err.path) == 1: + if "extra keys not allowed" in err.msg: + message = f"invalid keyword argument '{err.path[0]}'" + else: + message = f"keyword '{err.path[0]}' {err}" + else: + message = str(err) + + type_error = TypeError( + f"function '{self.dm.func_name}' defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"decorator @{self.name} {message}" + ) + raise type_error from err + + async def start(self): # noqa: B027 + """Start the decorator.""" + + async def stop(self): # noqa: B027 + """Stop the decorator.""" + + def __repr__(self): + """Represent the decorator as a string with the decorator name and arguments.""" + parts = [] + if self.raw_args is not None: + parts.append(",".join(map(str, self.raw_args))) + if self.raw_kwargs is not None: + parts += [f"{k}={v!r}" for k, v in self.raw_kwargs.items()] + return f"@{self.name}({', '.join(parts)})" + + +class DecoratorManager(ABC): + """Maintain and validate a set of decorators.""" + + hass: ClassVar[HomeAssistant] + + def __init__(self, ast_ctx: AstEval, name: str) -> None: + """Initialize the manager.""" + self.ast_ctx = ast_ctx + self.name = name + self.func_name = name.split(".")[-1] + self.logger = ast_ctx.get_logger() + + self.status: DecoratorManagerStatus = DecoratorManagerStatus.INIT + self.startup_time = None + self._decorators: list[Decorator] = [] + + def update_status(self, new_status: DecoratorManagerStatus) -> None: + """Update the manager status.""" + if self.status is new_status: + return + _LOGGER.debug("DM %s status: %s -> %s", self.name, self.status.value, new_status.value) + self.status = new_status + + if new_status in (DecoratorManagerStatus.STOPPED, DecoratorManagerStatus.INVALID): + del self._decorators[:] + + def add(self, decorator: Decorator) -> None: + """Add a decorator to the manager.""" + _LOGGER.debug("Add %s to %s", decorator, self) + self._decorators.append(decorator) + decorator.dm = self + + def get_decorators[DT](self, decorator_type: type[DT] | None = None) -> list[DT]: + """Get decorators of a specific type.""" + if decorator_type is None: + return self._decorators.copy() + return [dec for dec in self._decorators if isinstance(dec, decorator_type)] + + async def validate(self) -> None: + """Validate all decorators.""" + try: + for decorator in self._decorators: + _LOGGER.debug("Validating decorator: %s", decorator) + await decorator.validate() + except Exception: + self.update_status(DecoratorManagerStatus.INVALID) + raise + + if len(self._decorators) == 0: + self.update_status(DecoratorManagerStatus.NO_DECORATORS) + else: + self.update_status(DecoratorManagerStatus.VALIDATED) + + async def start(self): + """Start all decorators.""" + if self.status is not DecoratorManagerStatus.VALIDATED: + raise RuntimeError(f"Starting not valid {self}") + + self.startup_time = dt_now() + self.update_status(DecoratorManagerStatus.RUNNING) + started = [] + for decorator in self._decorators: + _LOGGER.debug("Starting decorator: %s", decorator) + try: + await decorator.start() + started.append(decorator) + except Exception as err: + self.logger.exception("%s start failed: %s", self, err) + for started_dec in started: + await self._stop_decorator(started_dec) + self.startup_time = None + self.update_status(DecoratorManagerStatus.INVALID) + raise + + async def _stop_decorator(self, decorator: Decorator) -> None: + try: + await decorator.stop() + except Exception as err: + _LOGGER.exception("%s stop failed: %s", self, err) + + async def stop(self): + """Stop all decorators.""" + if self.status is not DecoratorManagerStatus.RUNNING: + _LOGGER.warning("Stopping before starting for %s (status=%s)", self.name, self.status.value) + return + + _LOGGER.debug("Stopping all decorators %s", self) + for decorator in self._decorators: + await self._stop_decorator(decorator) + + self.update_status(DecoratorManagerStatus.STOPPED) + + async def handle_exception(self, exc: Exception) -> None: + """Handle a decorator exception.""" + self.ast_ctx.log_exception(exc) + + @abstractmethod + async def dispatch(self, data: DispatchData) -> None: + """Dispatch a trigger call.""" + + def __repr__(self): + """Return a string representation of the manager with status and decorators.""" + return f"{self.__class__.__name__}({self.status}) {self._decorators} for {self.name}()>" + + +class TriggerDecorator(Decorator, ABC): + """Base class for trigger-based decorators.""" + + def __init_subclass__(cls, **kwargs): + """Initialize the decorator class.""" + super().__init_subclass__(**kwargs) + # kwargs for all triggers + if "kwargs" not in cls.kwargs_schema.schema.keys(): + cls.kwargs_schema = cls.kwargs_schema.extend( + {vol.Optional("kwargs"): vol.Coerce(dict[str, Any], msg="should be type dict")} + ) + + async def dispatch(self, data: DispatchData) -> None: + """Dispatch a trigger call to the function.""" + if not data.trigger: + data.trigger = self + + data.func_args.update(self.kwargs.get("kwargs", {})) + + await self.dm.dispatch(data) + + +class TriggerHandlerDecorator(Decorator, ABC): + """Base class for trigger handler decorators.""" + + async def validate(self) -> None: + """Validate the decorated function.""" + await super().validate() + decorators = self.dm.get_decorators(TriggerDecorator) + if len(decorators) == 0: + # Keep this wording for transition compatibility. Once the legacy + # subsystem is removed, update the message and related tests. + trig_decorators_reqd = { + "event_trigger", + "mqtt_trigger", + "state_trigger", + "time_trigger", + "webhook_trigger", + } + raise ValueError( + f"{self.dm.func_name} defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"needs at least one trigger decorator (ie: {', '.join(sorted(trig_decorators_reqd))})" + ) + + @abstractmethod + async def handle_dispatch(self, data: DispatchData) -> bool | None: + """Handle a trigger dispatch call. Return False for stop dispatching.""" + + +class CallHandlerDecorator(Decorator, ABC): + """Base class for call-based handlers.""" + + @abstractmethod + async def handle_call(self, data: DispatchData) -> bool | None: + """Handle an action call. Return False for stop calling.""" + + +class CallResultHandlerDecorator(Decorator, ABC): + """Base class for call-based result handlers.""" + + @abstractmethod + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Handle an action call result.""" diff --git a/custom_components/pyscript/decorators/__init__.py b/custom_components/pyscript/decorators/__init__.py new file mode 100644 index 00000000..f21b9a9f --- /dev/null +++ b/custom_components/pyscript/decorators/__init__.py @@ -0,0 +1,21 @@ +"""Pyscript decorators.""" + +from .event import EventTriggerDecorator +from .mqtt import MQTTTriggerDecorator +from .service import ServiceDecorator +from .state import StateActiveDecorator, StateTriggerDecorator +from .task import TaskUniqueDecorator +from .timing import TimeActiveDecorator, TimeTriggerDecorator +from .webhook import WebhookTriggerDecorator + +DECORATORS = [ + StateTriggerDecorator, + StateActiveDecorator, + TimeTriggerDecorator, + TimeActiveDecorator, + TaskUniqueDecorator, + EventTriggerDecorator, + MQTTTriggerDecorator, + WebhookTriggerDecorator, + ServiceDecorator, +] diff --git a/custom_components/pyscript/decorators/base.py b/custom_components/pyscript/decorators/base.py new file mode 100644 index 00000000..213a67e5 --- /dev/null +++ b/custom_components/pyscript/decorators/base.py @@ -0,0 +1,59 @@ +"""Base mixins for pyscript decorators.""" + +from abc import ABC +import logging +from typing import Any + +import voluptuous as vol + +from ..decorator import FunctionDecoratorManager +from ..decorator_abc import Decorator +from ..eval import AstEval, Function + +_LOGGER = logging.getLogger(__name__) + + +class AutoKwargsDecorator(Decorator, ABC): + """Mixin that copies validated kwargs into instance attributes based on annotations.""" + + async def validate(self) -> None: + """Run base validation and materialize annotated kwargs as attributes.""" + await super().validate() + for k in self.__class__.kwargs_schema.schema: + if isinstance(k, vol.Marker): + k = k.schema + if k in self.__class__.__annotations__: + setattr(self, k, self.kwargs.get(k, None)) + + +class ExpressionDecorator(Decorator, ABC): + """Base for AstEval-based decorators.""" + + _ast_expression: AstEval = None + + def create_expression(self, expression: str) -> None: + """Create AstEval expression.""" + _LOGGER.debug("Create expression: %s, %s", expression, self) + dec_name = self.name + if isinstance(self.dm, FunctionDecoratorManager): + dec_name = "@" + dec_name + "()" + + self._ast_expression = AstEval( + self.dm.name + " " + dec_name, self.dm.ast_ctx.global_ctx, self.dm.name + ) + Function.install_ast_funcs(self._ast_expression) + self._ast_expression.parse(expression, mode="eval") + + def has_expression(self) -> bool: + """Return True if expression was created.""" + return self._ast_expression is not None + + async def check_expression_vars(self, state_vars: dict[str, Any]) -> bool: + """Evaluate expression and dispatch an exception event via manager on failure.""" + if not self.has_expression(): + raise AttributeError(f"{self} has no expression defined") + try: + return await self._ast_expression.eval(state_vars) + except Exception as exc: + await self.dm.handle_exception(exc) + return False diff --git a/custom_components/pyscript/decorators/event.py b/custom_components/pyscript/decorators/event.py new file mode 100644 index 00000000..82e422c6 --- /dev/null +++ b/custom_components/pyscript/decorators/event.py @@ -0,0 +1,59 @@ +"""Event decorator.""" + +import logging + +import voluptuous as vol + +from homeassistant.core import CALLBACK_TYPE, Event + +from ..decorator_abc import DispatchData, TriggerDecorator +from .base import ExpressionDecorator + +_LOGGER = logging.getLogger(__name__) + + +class EventTriggerDecorator(TriggerDecorator, ExpressionDecorator): + """Implementation for @event_trigger.""" + + name = "event_trigger" + args_schema = vol.Schema( + vol.All( + [vol.Coerce(str)], + vol.Length(min=1, max=2, msg="needs at least one argument"), + ) + ) + + remove_listener_callback: CALLBACK_TYPE | None = None + + async def validate(self) -> None: + """Validate the event trigger.""" + await super().validate() + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _event_callback(self, event: Event) -> None: + _LOGGER.debug("Event trigger received: %s %s", type(event), event) + func_args = { + "trigger_type": "event", + "event_type": event.event_type, + "context": event.context, + } + func_args.update(event.data) + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + + await self.dispatch(DispatchData(func_args)) + + async def start(self) -> None: + """Start the event trigger.""" + await super().start() + self.remove_listener_callback = self.dm.hass.bus.async_listen(self.args[0], self._event_callback) + _LOGGER.debug("Event trigger started for event: %s", self.args[0]) + _LOGGER.debug("Remove listener: %s", self.remove_listener_callback) + + async def stop(self) -> None: + """Stop the event trigger.""" + await super().stop() + if self.remove_listener_callback: + self.remove_listener_callback() diff --git a/custom_components/pyscript/decorators/mqtt.py b/custom_components/pyscript/decorators/mqtt.py new file mode 100644 index 00000000..d6019202 --- /dev/null +++ b/custom_components/pyscript/decorators/mqtt.py @@ -0,0 +1,69 @@ +"""Mqtt decorator.""" + +from __future__ import annotations + +import json +import logging + +import voluptuous as vol + +from homeassistant.components import mqtt +from homeassistant.core import CALLBACK_TYPE + +from ..decorator_abc import DispatchData, TriggerDecorator +from .base import AutoKwargsDecorator, ExpressionDecorator + +_LOGGER = logging.getLogger(__name__) + + +class MQTTTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @mqtt_trigger.""" + + name = "mqtt_trigger" + args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=1, max=2))) + kwargs_schema = vol.Schema({vol.Optional("encoding", default="utf-8"): str}) + + encoding: str + + remove_listener_callback: CALLBACK_TYPE | None = None + + async def validate(self) -> None: + """Validate the MQTT trigger.""" + await super().validate() + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _mqtt_message_handler(self, mqttmsg: mqtt.ReceiveMessage) -> None: + func_args = { + "trigger_type": "mqtt", + "topic": mqttmsg.topic, + "payload": mqttmsg.payload, + "qos": mqttmsg.qos, + "retain": mqttmsg.retain, + } + try: + func_args["payload_obj"] = json.loads(mqttmsg.payload) + except ValueError: + pass + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + await self.dispatch(DispatchData(func_args)) + + async def start(self) -> None: + """Start the MQTT trigger.""" + await super().start() + topic = self.args[0] + self.remove_listener_callback = await mqtt.async_subscribe( + self.dm.hass, + topic, + self._mqtt_message_handler, + encoding=self.encoding, + qos=0, + ) + + async def stop(self) -> None: + """Stop the MQTT trigger.""" + await super().stop() + if self.remove_listener_callback: + self.remove_listener_callback() diff --git a/custom_components/pyscript/decorators/service.py b/custom_components/pyscript/decorators/service.py new file mode 100644 index 00000000..564192a5 --- /dev/null +++ b/custom_components/pyscript/decorators/service.py @@ -0,0 +1,140 @@ +"""Service decorator implementation.""" + +from __future__ import annotations + +import ast +from collections import OrderedDict +import io +import logging +import typing + +import voluptuous as vol +import yaml + +from homeassistant.const import SERVICE_RELOAD +from homeassistant.core import ServiceCall, SupportsResponse +from homeassistant.helpers.service import async_set_service_schema + +from .. import DOMAIN, SERVICE_JUPYTER_KERNEL_START, AstEval, Function, State +from ..decorator import FunctionDecoratorManager +from ..decorator_abc import Decorator + +_LOGGER = logging.getLogger(__name__) + + +def service_validator(args: list[str]) -> list[str]: + """Validate and normalize service name.""" + if len(args) == 0: + return [] + s = str(args[0]).strip() + + if not isinstance(s, str): + raise vol.Invalid("must be string") + s = s.strip() + if s.count(".") != 1: + raise vol.Invalid("argument 1 should be a string with one period") + domain, name = s.split(".", 1) + return [domain, name] + + +class ServiceDecorator(Decorator): + """Implementation for @service.""" + + name = "service" + args_schema = vol.Schema(vol.All(vol.Length(max=1), service_validator)) + kwargs_schema = vol.Schema( + {vol.Optional("supports_response", default=SupportsResponse.NONE): vol.Coerce(SupportsResponse)} + ) + + description: dict + + async def validate(self) -> None: + """Validate the arguments.""" + await super().validate() + + if len(self.args) != 2: + self.args = [DOMAIN, self.dm.func_name] + # This condition still does not verify the domain. Keep the behavior + # for transition compatibility and revisit it after the legacy + # subsystem is removed. + if self.args[1] in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START): + # Keep this wording for transition compatibility. Once the legacy + # subsystem is removed, update the message and related tests. + raise SyntaxError( + f"function '{self.dm.func_name}' defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"@service conflicts with builtin service" + ) + + ast_funcdef = typing.cast(FunctionDecoratorManager, self.dm).eval_func.func_def + desc = ast.get_docstring(ast_funcdef) + if desc is None or desc == "": + desc = f"pyscript function {ast_funcdef.name}()" + desc = desc.lstrip(" \n\r") + if desc.startswith("yaml"): + try: + desc = desc[4:].lstrip(" \n\r") + file_desc = io.StringIO(desc) + self.description = yaml.load(file_desc, Loader=yaml.BaseLoader) or OrderedDict() # noqa: S506 + file_desc.close() + except Exception as exc: + self.dm.logger.error( + "Unable to decode yaml doc_string for %s(): %s", + ast_funcdef.name, + str(exc), + ) + raise exc + else: + fields = OrderedDict() + for arg in ast_funcdef.args.posonlyargs + ast_funcdef.args.args: + fields[arg.arg] = OrderedDict(description=f"argument {arg.arg}") + self.description = {"description": desc, "fields": fields} + + async def _service_callback(self, call: ServiceCall) -> None: + _LOGGER.info("Service callback: %s", call.service) + + # use a new AstEval context so it can run fully independently + # of other instances (except for global_ctx which is common) + global_ctx = self.dm.eval_func.global_ctx + ast_ctx = AstEval(self.dm.name, global_ctx) + Function.install_ast_funcs(ast_ctx) + func_args = { + "trigger_type": "service", + "context": call.context, + } + func_args.update(call.data) + + async def do_service_call(func, ast_ctx, data): + try: + _LOGGER.debug("Service call start: %s", func.name) + return await func.call(ast_ctx, **data) + except Exception as exc: + await self.dm.handle_exception(exc) + return None + + task = Function.create_task(do_service_call(self.dm.eval_func, ast_ctx, func_args)) + await task + return task.result() + + async def start(self) -> None: + """Register the service.""" + domain = self.args[0] + name = self.args[1] + _LOGGER.debug("Registering service: %s.%s", domain, name) + Function.service_register( + self.dm.ast_ctx.name, + domain, + name, + self._service_callback, + self.kwargs.get("supports_response"), + ) + async_set_service_schema(Function.hass, domain, name, self.description) + + # update service params. In the legacy implementation, Pyscript services were registered + # right after the function definition, then decorators were executed, and finally the + # service cache was updated. + await State.get_service_params() + + async def stop(self) -> None: + """Unregister the service.""" + _LOGGER.debug("Unregistering service: %s.%s", self.args[0], self.args[1]) + Function.service_remove(self.dm.ast_ctx.global_ctx.get_name(), self.args[0], self.args[1]) diff --git a/custom_components/pyscript/decorators/state.py b/custom_components/pyscript/decorators/state.py new file mode 100644 index 00000000..7b588b9b --- /dev/null +++ b/custom_components/pyscript/decorators/state.py @@ -0,0 +1,321 @@ +"""State decorators.""" + +import asyncio +import logging +import re +from typing import Any + +import voluptuous as vol + +from homeassistant.helpers import config_validation as cv + +from ..decorator import WaitUntilDecoratorManager +from ..decorator_abc import DecoratorManagerStatus, DispatchData, TriggerDecorator, TriggerHandlerDecorator +from ..state import State +from ..trigger import ident_any_values_changed, ident_values_changed +from .base import AutoKwargsDecorator, ExpressionDecorator + +STATE_RE = re.compile(r"\w+\.\w+(\.((\w+)|\*))?$") + +_LOGGER = logging.getLogger(__name__) + + +class StateActiveDecorator(TriggerHandlerDecorator, ExpressionDecorator): + """Implementation for @state_active.""" + + name = "state_active" + args_schema = vol.Schema( + vol.All( + vol.Length( + min=1, max=1, msg="got 2 arguments, expected 1" + ), # Keep this wording for transition compatibility. + vol.All([str]), + ) + ) + + var_names: set[str] + + async def validate(self) -> None: + """Validate the decorator arguments.""" + await super().validate() + self.create_expression(self.args[0]) + self.var_names = await self._ast_expression.get_names() + + async def handle_dispatch(self, data: DispatchData) -> bool: + """Handle dispatch events.""" + new_vars = data.trigger_context.get("new_vars", {}) + active_vars = State.notify_var_get(self.var_names, new_vars) + return await self.check_expression_vars(active_vars) + + +def _validate_state_trigger_args(args: list[Any]) -> list[str]: + """Validate and normalize @state_trigger positional arguments.""" + if not isinstance(args, list): + raise vol.Invalid("arguments must be a list") + if len(args) == 0: + raise vol.Invalid("needs at least one argument") + + normalized: list[str] = [] + for idx, arg in enumerate(args, start=1): + if isinstance(arg, str): + normalized.append(arg) + continue + if isinstance(arg, (list, set)): + if not all(isinstance(expr, str) for expr in arg): + raise vol.Invalid(f"argument {idx} should be a string, or list, or set") + normalized.extend(list(arg)) + continue + raise vol.Invalid(f"argument {idx} should be a string, or list, or set") + return normalized + + +class StateTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @state_trigger.""" + + name = "state_trigger" + args_schema = vol.Schema(vol.All(_validate_state_trigger_args)) + kwargs_schema = vol.Schema( + { + vol.Optional("state_hold"): vol.Any(None, cv.positive_float), + vol.Optional("state_hold_false"): vol.Any(None, cv.positive_float), + vol.Optional("state_check_now"): cv.boolean, + vol.Optional("watch"): vol.Coerce(set[str], msg="should be type list or set"), + vol.Optional("__test_handshake__"): vol.Coerce(list), + } + ) + # kwargs + state_hold: float | None + state_hold_false: float | None + state_check_now: bool | None + __test_handshake__: list[str] | None + + notify_q: asyncio.Queue + in_wait_until_function: bool + cycle_task: asyncio.Task = None + + state_trig_ident: set[str] + state_trig_ident_any: set[str] + true_entered_at: float | None + false_entered_at: float | None + + last_func_args: dict[str, Any] + last_new_vars: dict[str, Any] + + async def validate(self) -> None: + """Validate and normalize arguments.""" + await super().validate() + self.state_trig_ident = set() + self.state_trig_ident_any = set() + + self.in_wait_until_function = isinstance(self.dm, WaitUntilDecoratorManager) + + if self.state_check_now is None and self.in_wait_until_function: + # check by default for task.wait_until + self.state_check_now = True + + state_trig = [] + + for trig in self.args: + if STATE_RE.match(trig): + self.state_trig_ident_any.add(trig) + else: + state_trig.append(trig) + + if len(state_trig) > 0: + if len(state_trig) == 1: + state_trig_expr = state_trig[0] + else: + state_trig_expr = f"any([{', '.join(state_trig)}])" + + self.create_expression(state_trig_expr) + + if self.kwargs.get("watch") is not None: + self.state_trig_ident = set(self.kwargs.get("watch", [])) + else: + if self.has_expression(): + self.state_trig_ident = await self._ast_expression.get_names() + self.state_trig_ident.update(self.state_trig_ident_any) + + _LOGGER.debug("trigger %s: watching vars %s", self.name, self.state_trig_ident) + _LOGGER.debug("trigger %s: any %s", self.name, self.state_trig_ident_any) + if len(self.state_trig_ident) == 0: + self.dm.logger.error( + "trigger %s: @state_trigger is not watching any variables; will never trigger", + self.dm.name, + ) + + def _diff(self, dt: float, now: float) -> str: + if dt is None: + return "None" + return f"{(now - dt):g} ago" + + async def _check_new_state(self, trig_ok: bool) -> None: + now = asyncio.get_running_loop().time() + if _LOGGER.isEnabledFor(logging.DEBUG): + msg = f"check_new_state: {self}" + msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {self.last_func_args} new_vars: {self.last_new_vars}" + if self.true_entered_at: + msg += f"\ntrue_entered_at: {self.true_entered_at}({(now - self.true_entered_at):g} ago)\n" + if self.false_entered_at: + msg += ( + f"\nfalse_entered_at: {self.false_entered_at}({(now - self.false_entered_at):g} ago)\n" + ) + _LOGGER.debug(msg) + + state_hold_false_passed = False + state_hold_true_passed = False + if trig_ok: + if self.state_hold_false is None or not self.has_expression(): + state_hold_false_passed = True + else: + if self.false_entered_at: + false_duration = now - self.false_entered_at + if false_duration >= self.state_hold_false: + state_hold_false_passed = True + _LOGGER.debug( + "state_hold_false passed (%g), reset false_entered_at, %s", false_duration, self + ) + self.false_entered_at = None + + if state_hold_false_passed: + if self.state_hold is None: + state_hold_true_passed = True + else: + if self.true_entered_at: + true_duration = now - self.true_entered_at + if true_duration >= self.state_hold: + state_hold_true_passed = True + self.true_entered_at = None + _LOGGER.debug( + "state_hold passed (%g), reset true_entered_at, %s", true_duration, self + ) + else: + _LOGGER.debug("state_hold started, %s", self) + self.true_entered_at = now + + if state_hold_true_passed: + self.true_entered_at = None + await self.dispatch( + DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars}) + ) + self.__test_handshake__ = None + else: + self.true_entered_at = None + if self.state_hold_false is not None: + if not self.false_entered_at: + _LOGGER.debug("state_hold_false started, %s", self) + self.false_entered_at = now + + async def _check_state_hold(self) -> None: + if self.true_entered_at is None: + raise RuntimeError(f"state_hold not started for {self}") + + now = asyncio.get_running_loop().time() + true_duration = now - self.true_entered_at + if true_duration >= self.state_hold: + self.true_entered_at = None + await self.dispatch( + DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars}) + ) + + async def _cycle(self) -> None: + """Run the trigger cycle with state_hold and state_hold_false logic.""" + loop = asyncio.get_running_loop() + + self.true_entered_at = None + self.false_entered_at = None + + self.last_func_args = {"trigger_type": "state"} + self.last_new_vars = {} + + check_state_expr_on_start = self.state_check_now or self.state_hold_false is not None + + if check_state_expr_on_start: + self.last_new_vars = State.notify_var_get(self.state_trig_ident, {}) + trig_ok = await self._is_trig_ok() + + if self.in_wait_until_function and trig_ok and self.state_check_now is True: + self.state_hold_false = None + + if self.state_check_now and self.has_expression(): + await self._check_new_state(trig_ok) + else: + if not trig_ok and self.state_hold_false is not None: + self.false_entered_at = loop.time() + + if self.__test_handshake__ is not None: + # + # used for testing to avoid race conditions + # we use this as a handshake that we are about to + # listen to the queue + # + _LOGGER.debug("__test_handshake__ handshake: %s", self.__test_handshake__) + State.set(self.__test_handshake__[0], self.__test_handshake__[1]) + self.__test_handshake__ = None + + while self.dm.status is DecoratorManagerStatus.RUNNING: + if self.true_entered_at is None: + effective_timeout = None + else: + effective_timeout = self.state_hold + if self.true_entered_at is not None: + effective_timeout -= loop.time() - self.true_entered_at + + if effective_timeout <= 1e-6: + # ignore deltas smaller than 1us. + await self._check_state_hold() + continue + + try: + if effective_timeout is None: + notify_type, notify_info = await self.notify_q.get() + else: + notify_type, notify_info = await asyncio.wait_for(self.notify_q.get(), effective_timeout) + if notify_type != "state": + raise RuntimeError(f"Invalid notify_type {notify_type}, {self}") + self.last_new_vars = notify_info[0] + self.last_func_args = notify_info[1] + + if ident_any_values_changed(self.last_func_args, self.state_trig_ident_any): + trig_ok = True + elif ident_values_changed(self.last_func_args, self.state_trig_ident): + trig_ok = await self._is_trig_ok() + else: + trig_ok = False + await self._check_new_state(trig_ok) + except TimeoutError: + await self._check_state_hold() + + async def _is_trig_ok(self) -> bool: + if self.has_expression(): + return await self.check_expression_vars(self.last_new_vars) + return True + + def _on_task_done(self, task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + self.dm.logger.error("%s failed", self, exc_info=exc) + + async def start(self) -> None: + """Start the trigger.""" + await super().start() + self.notify_q = asyncio.Queue(0) + if not await State.notify_add(self.state_trig_ident, self.notify_q): + self.dm.logger.error( + "trigger %s: @state_trigger is not watching any variables; will never trigger", + self.dm.name, + ) + return + _LOGGER.debug("trigger %s: starting", self.name) + + self.cycle_task = self.dm.hass.async_create_background_task(self._cycle(), repr(self)) + self.cycle_task.add_done_callback(self._on_task_done) + + async def stop(self): + """Stop the trigger.""" + await super().stop() + if self.cycle_task is not None: + self.cycle_task.cancel() + State.notify_del(self.state_trig_ident, self.notify_q) diff --git a/custom_components/pyscript/decorators/task.py b/custom_components/pyscript/decorators/task.py new file mode 100644 index 00000000..1ed18d01 --- /dev/null +++ b/custom_components/pyscript/decorators/task.py @@ -0,0 +1,40 @@ +"""Task decorators.""" + +from __future__ import annotations + +import logging + +import voluptuous as vol + +from homeassistant.helpers import config_validation as cv + +from ..decorator_abc import CallHandlerDecorator, DispatchData +from ..function import Function +from .base import AutoKwargsDecorator + +_LOGGER = logging.getLogger(__name__) + + +class TaskUniqueDecorator(CallHandlerDecorator, AutoKwargsDecorator): + """Implementation for @task_unique.""" + + name = "task_unique" + args_schema = vol.Schema(vol.All([str], vol.Length(min=1, max=1))) + kwargs_schema = vol.Schema({vol.Optional("kill_me", default=False): cv.boolean}) + + kill_me: bool + + async def handle_call(self, data: DispatchData) -> bool: + """Handle call.""" + if self.kill_me: + if Function.unique_name_used(data.call_ast_ctx, self.args[0]): + _LOGGER.debug( + "trigger %s got %s trigger, @task_unique kill_me=True prevented new action", + "notify_type", + self.name, + ) + return False + + task_unique_func = Function.task_unique_factory(data.call_ast_ctx) + await task_unique_func(self.args[0]) + return True diff --git a/custom_components/pyscript/decorators/timing.py b/custom_components/pyscript/decorators/timing.py new file mode 100644 index 00000000..6c2be062 --- /dev/null +++ b/custom_components/pyscript/decorators/timing.py @@ -0,0 +1,163 @@ +"""Time decorators.""" + +from __future__ import annotations + +import asyncio +import datetime as dt +import logging +import time + +import voluptuous as vol + +from homeassistant.helpers import config_validation as cv + +from .. import trigger +from ..decorator import WaitUntilDecoratorManager +from ..decorator_abc import DecoratorManagerStatus, DispatchData, TriggerDecorator, TriggerHandlerDecorator +from .base import AutoKwargsDecorator + +_LOGGER = logging.getLogger(__name__) + + +def dt_now(): + """Return current time.""" + # Keep this wrapper so tests can patch it during the transition. + return trigger.dt_now() + + +class TimeActiveDecorator(TriggerHandlerDecorator, AutoKwargsDecorator): + """Implementation for @time_active.""" + + name = "time_active" + args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=0))) + kwargs_schema = vol.Schema({vol.Optional("hold_off", default=0.0): cv.positive_float}) + + hold_off: float + + last_trig_time: float = 0.0 + + async def handle_dispatch(self, data: DispatchData) -> bool: + """Handle dispatch.""" + if self.last_trig_time > 0.0 and self.hold_off > 0.0: + if time.monotonic() - self.last_trig_time < self.hold_off: + return False + + if len(self.args) > 0: + if "trigger_time" in data.func_args and isinstance(data.func_args["trigger_time"], dt.datetime): + now = data.func_args["trigger_time"] + else: + now = dt_now() + + for time_spec in self.args: + _LOGGER.debug("time_spec %s, %s", time_spec, self) + _LOGGER.debug("time_active now %s, %s", now, self) + if await trigger.TrigTime.timer_active_check(time_spec, now, self.dm.startup_time): + self.last_trig_time = time.monotonic() + return True + return False + + self.last_trig_time = time.monotonic() + return True + + +class TimeTriggerDecorator(TriggerDecorator): + """Implementation for @time_trigger.""" + + name = "time_trigger" + # args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=0))) + args_schema = vol.Schema( + vol.All( + vol.Length(min=0), + vol.All( + [str], msg="argument 2 should be a string" + ), # Keep this wording for transition compatibility. + ) + ) + + run_on_startup: bool = False + run_on_shutdown: bool = False + timespec: list[str] + _cycle_task: asyncio.Task + + async def validate(self) -> None: + """Validate the decorator arguments.""" + await super().validate() + self.timespec = self.args + + if len(self.timespec) == 0: + self.run_on_startup = True + return + + while "startup" in self.timespec: + self.run_on_startup = True + self.timespec.remove("startup") + while "shutdown" in self.timespec: + self.run_on_shutdown = True + self.timespec.remove("shutdown") + + async def _cycle(self): + if self.run_on_startup: + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": "startup"})) + + first_run = True + try: + while self.dm.status is DecoratorManagerStatus.RUNNING: + if first_run: + now = self.dm.startup_time + first_run = False + else: + now = dt_now() + + _LOGGER.debug("time_trigger now %s", now) + time_next, time_next_adj = await trigger.TrigTime.timer_trigger_next( + self.timespec, now, self.dm.startup_time + ) + _LOGGER.debug( + "trigger %s time_next = %s, time_next_adj = %s, now = %s", + self.dm.name, + time_next, + time_next_adj, + now, + ) + if time_next is None: + _LOGGER.debug("trigger %s finished", self.name) + if isinstance(self.dm, WaitUntilDecoratorManager): + await self.dispatch(DispatchData({"trigger_type": "none"})) + break + + # replace with homeassistant.helpers.event.async_track_point_in_utc_time? + timeout = (time_next_adj - now).total_seconds() + _LOGGER.debug("%s sleeping for %s seconds", self, timeout) + await asyncio.sleep(timeout) + _LOGGER.debug("%s finish sleeping for %s seconds", self, timeout) + while True: + now = dt_now() + timeout = (time_next_adj - now).total_seconds() + if timeout <= 1e-6: + break + _LOGGER.debug("%s additional sleep for %s seconds", self, timeout) + await asyncio.sleep(timeout) + + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": time_next})) + except asyncio.CancelledError: + raise + + async def stop(self): + """Stop the trigger.""" + if self._cycle_task is not None: + self._cycle_task.cancel() + if self.run_on_shutdown: + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": "shutdown"})) + + def _on_task_done(self, task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + self.dm.logger.error("%s failed", self, exc_info=exc) + + async def start(self) -> None: + """Start the decorator.""" + await super().start() + self._cycle_task = self.dm.hass.async_create_background_task(self._cycle(), repr(self)) + self._cycle_task.add_done_callback(self._on_task_done) diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py new file mode 100644 index 00000000..d0d449cc --- /dev/null +++ b/custom_components/pyscript/decorators/webhook.py @@ -0,0 +1,84 @@ +"""Webhook decorator.""" + +import logging + +from aiohttp import hdrs +import voluptuous as vol + +from homeassistant.components import webhook +from homeassistant.components.webhook import SUPPORTED_METHODS +from homeassistant.helpers import config_validation as cv + +from ..decorator_abc import DispatchData, TriggerDecorator +from .base import AutoKwargsDecorator, ExpressionDecorator + +_LOGGER = logging.getLogger(__name__) + + +class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @webhook_trigger.""" + + name = "webhook_trigger" + args_schema = vol.Schema( + vol.All( + [vol.Coerce(str)], + vol.Length(min=1, max=2, msg="needs at least one argument"), + ) + ) + kwargs_schema = vol.Schema( + { + vol.Optional("local_only", default=True): cv.boolean, + vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + } + ) + + webhook_id: str + local_only: bool + methods: set[str] + + async def validate(self): + """Validate the webhook trigger configuration.""" + await super().validate() + self.webhook_id = self.args[0] + + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _handler(self, hass, webhook_id, request): + func_args = { + "trigger_type": "webhook", + "webhook_id": webhook_id, + } + + if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): + func_args["payload"] = await request.json() + else: + # Could potentially return multiples of a key - only take the first + payload_multidict = await request.post() + func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + + await self.dispatch(DispatchData(func_args)) + + async def start(self): + """Start the webhook trigger.""" + await super().start() + webhook.async_register( + self.dm.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + self.webhook_id, + self._handler, + local_only=self.local_only, + allowed_methods=self.methods, + ) + + _LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id) + + async def stop(self): + """Stop the webhook trigger.""" + await super().stop() + webhook.async_unregister(self.dm.hass, self.webhook_id) diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 762be4bf..00632b8c 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -327,6 +327,7 @@ def __init__( self.defaults = [] self.kw_defaults = [] self.decorators = [] + self.dm_decorators = [] self.global_names = set() self.nonlocal_names = set() self.local_names = None @@ -613,8 +614,11 @@ async def eval_decorators(self, ast_ctx): dec_other = [] dec_trig = [] + dec_dm = [] for dec in self.func_def.decorator_list: - if ( + if known_dec := await ast_ctx.global_ctx.get_decorator_by_expr(ast_ctx, dec): + dec_dm.append(known_dec) + elif ( isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in TRIG_SERV_DECORATORS @@ -628,7 +632,7 @@ async def eval_decorators(self, ast_ctx): dec_other.append(await ast_ctx.aeval(dec)) ast_ctx.code_str, ast_ctx.code_list = code_str, code_list - return dec_trig, reversed(dec_other) + return dec_trig, reversed(dec_other), dec_dm async def resolve_nonlocals(self, ast_ctx): """Tag local variables and resolve nonlocals.""" @@ -1150,7 +1154,7 @@ async def executor_wrap(*args, **kwargs): await func.eval_defaults(self) await func.resolve_nonlocals(self) name = func.get_name() - dec_trig, dec_other = await func.eval_decorators(self) + dec_trig, dec_other, dec_dm = await func.eval_decorators(self) self.dec_eval_depth += 1 for dec_func in dec_other: func = await self.call_func(dec_func, None, func) @@ -1159,11 +1163,13 @@ async def executor_wrap(*args, **kwargs): func.set_name(name) func = func.remove_func() dec_trig += func.decorators + dec_dm += func.dm_decorators elif isinstance(func, EvalFunc): func.set_name(name) self.dec_eval_depth -= 1 if isinstance(func, EvalFunc): func.decorators = dec_trig + func.dm_decorators = dec_dm if self.dec_eval_depth == 0: func.trigger_stop() try: @@ -1172,6 +1178,9 @@ async def executor_wrap(*args, **kwargs): self.log_exception(e) func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) + + if len(dec_dm) > 0: + await self.get_global_ctx().create_decorator_manager(dec_dm, self, func_var) else: func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) diff --git a/custom_components/pyscript/global_ctx.py b/custom_components/pyscript/global_ctx.py index f3780656..2b942044 100644 --- a/custom_components/pyscript/global_ctx.py +++ b/custom_components/pyscript/global_ctx.py @@ -1,5 +1,6 @@ """Global context handling.""" +import ast from collections.abc import Awaitable, Callable import logging import os @@ -9,7 +10,9 @@ from homeassistant.config_entries import ConfigEntry from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH -from .eval import AstEval, EvalFunc, SymTable +from .decorator import DecoratorRegistry, FunctionDecoratorManager +from .decorator_abc import Decorator, DecoratorManagerStatus +from .eval import AstEval, EvalFunc, EvalFuncVar, SymTable from .function import Function from .trigger import TrigInfo @@ -34,6 +37,8 @@ def __init__( self.global_sym_table: SymTable = global_sym_table if global_sym_table else {} self.triggers: set[EvalFunc] = set() self.triggers_delay_start: set[EvalFunc] = set() + self.dms: set[FunctionDecoratorManager] = set() + self.dms_delay_start: set[FunctionDecoratorManager] = set() self.logger: logging.Logger = logging.getLogger(LOGGER_PATH + "." + name) self.manager = manager self.auto_start: bool = False @@ -61,6 +66,30 @@ def trigger_register(self, func: EvalFunc) -> bool: self.triggers_delay_start.add(func) return False + async def get_decorator_by_expr(self, ast_ctx: AstEval, dec: ast.expr) -> Decorator | None: + """Return decorator instance from an AST decorator expression.""" + return await DecoratorRegistry.get_decorator_by_expr(ast_ctx, dec) + + async def create_decorator_manager( + self, decs: list[Decorator], ast_ctx: AstEval, func_var: EvalFuncVar + ) -> None: + """Create decorator manager from an AST decorator expression.""" + dm = FunctionDecoratorManager(ast_ctx, func_var) + for dec in decs: + dm.add(dec) + + try: + await dm.validate() + if dm.status is DecoratorManagerStatus.VALIDATED: + self.dms.add(dm) + + if self.auto_start: + await dm.start() + else: + self.dms_delay_start.add(dm) + except Exception as exc: + ast_ctx.log_exception(exc) + def trigger_unregister(self, func: EvalFunc) -> None: """Unregister a trigger function.""" self.triggers.discard(func) @@ -76,12 +105,20 @@ def start(self) -> None: func.trigger_start() self.triggers_delay_start = set() + for dm in self.dms_delay_start: + Function.hass.async_create_task(dm.start()) + self.dms_delay_start = set() + def stop(self) -> None: """Stop all triggers and auto_start.""" for func in self.triggers: func.trigger_stop() self.triggers = set() self.triggers_delay_start = set() + for dm in self.dms: + Function.hass.async_create_task(dm.stop()) + self.dms = set() + self.dms_delay_start = set() self.set_auto_start(False) def get_name(self) -> str: diff --git a/custom_components/pyscript/strings.json b/custom_components/pyscript/strings.json index 12c54e83..6f8d38d5 100644 --- a/custom_components/pyscript/strings.json +++ b/custom_components/pyscript/strings.json @@ -6,7 +6,8 @@ "description": "Once you have created an entry, refer to the pyscript docs to learn how to create scripts and functions.", "data": { "allow_all_imports": "Allow All Imports?", - "hass_is_global": "Access hass as a global variable?" + "hass_is_global": "Access hass as a global variable?", + "legacy_decorators": "Use legacy decorators?" } } }, @@ -22,7 +23,8 @@ "title": "Update pyscript configuration", "data": { "allow_all_imports": "Allow All Imports?", - "hass_is_global": "Access hass as a global variable?" + "hass_is_global": "Access hass as a global variable?", + "legacy_decorators": "Use legacy decorators?" } }, "no_ui_configuration_allowed": { diff --git a/custom_components/pyscript/translations/de.json b/custom_components/pyscript/translations/de.json index ff310e52..7301ca09 100644 --- a/custom_components/pyscript/translations/de.json +++ b/custom_components/pyscript/translations/de.json @@ -6,7 +6,8 @@ "description": "Wenn Sie einen Eintrag angelegt haben, können Sie sich die Doku ansehen, um zu lernen wie Sie Scripts und Funktionen erstellen können.", "data": { "allow_all_imports": "Alle Importe erlauben?", - "hass_is_global": "Home Assistant als globale Variable verwenden?" + "hass_is_global": "Home Assistant als globale Variable verwenden?", + "legacy_decorators": "Legacy-Decorators verwenden?" } } }, @@ -22,7 +23,8 @@ "title": "Pyscript configuration aktualisieren", "data": { "allow_all_imports": "Alle Importe erlauben??", - "hass_is_global": "Home Assistant als globale Variable verwenden?" + "hass_is_global": "Home Assistant als globale Variable verwenden?", + "legacy_decorators": "Legacy-Decorators verwenden?" } }, "no_ui_configuration_allowed": { diff --git a/custom_components/pyscript/translations/en.json b/custom_components/pyscript/translations/en.json index 12c54e83..6f8d38d5 100644 --- a/custom_components/pyscript/translations/en.json +++ b/custom_components/pyscript/translations/en.json @@ -6,7 +6,8 @@ "description": "Once you have created an entry, refer to the pyscript docs to learn how to create scripts and functions.", "data": { "allow_all_imports": "Allow All Imports?", - "hass_is_global": "Access hass as a global variable?" + "hass_is_global": "Access hass as a global variable?", + "legacy_decorators": "Use legacy decorators?" } } }, @@ -22,7 +23,8 @@ "title": "Update pyscript configuration", "data": { "allow_all_imports": "Allow All Imports?", - "hass_is_global": "Access hass as a global variable?" + "hass_is_global": "Access hass as a global variable?", + "legacy_decorators": "Use legacy decorators?" } }, "no_ui_configuration_allowed": { diff --git a/custom_components/pyscript/translations/sk.json b/custom_components/pyscript/translations/sk.json index d7630657..3b7766ba 100644 --- a/custom_components/pyscript/translations/sk.json +++ b/custom_components/pyscript/translations/sk.json @@ -6,7 +6,8 @@ "description": "Akonáhle ste vytvorili položku, pozrite si docs naučiť sa, ako vytvárať skripty a funkcie.", "data": { "allow_all_imports": "Povoliť všetky importy?", - "hass_is_global": "Prístup k globálnej premennej?" + "hass_is_global": "Prístup k globálnej premennej?", + "legacy_decorators": "Použiť legacy dekorátory?" } } }, @@ -22,7 +23,8 @@ "title": "Aktualizovať pyscript konfiguráciu", "data": { "allow_all_imports": "povoliť všetky importy?", - "hass_is_global": "Prístup k globálnej premennej?" + "hass_is_global": "Prístup k globálnej premennej?", + "legacy_decorators": "Použiť legacy dekorátory?" } }, "no_ui_configuration_allowed": { diff --git a/custom_components/pyscript/translations/tr.json b/custom_components/pyscript/translations/tr.json index 06cf8320..0b977768 100644 --- a/custom_components/pyscript/translations/tr.json +++ b/custom_components/pyscript/translations/tr.json @@ -6,7 +6,8 @@ "description": "Bir girdi oluşturduktan sonra, betik ve fonksiyon oluşturmayı öğrenmek için dokümantasyona bakabilirsiniz.", "data": { "allow_all_imports": "Tüm içe aktarmalara izin verilsin mi?", - "hass_is_global": "hass'a global değişken olarak erişilsin mi?" + "hass_is_global": "hass'a global değişken olarak erişilsin mi?", + "legacy_decorators": "Legacy dekoratörler kullanılsın mı?" } } }, @@ -22,7 +23,8 @@ "title": "pyscript yapılandırmasını güncelle", "data": { "allow_all_imports": "Tüm içe aktarmalara izin verilsin mi?", - "hass_is_global": "hass'a global değişken olarak erişilsin mi?" + "hass_is_global": "hass'a global değişken olarak erişilsin mi?", + "legacy_decorators": "Legacy dekoratörler kullanılsın mı?" } }, "no_ui_configuration_allowed": { diff --git a/docs/configuration.rst b/docs/configuration.rst index a3d7eb34..e46fcd70 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -7,14 +7,22 @@ Configuration in the Configuration page. Alternatively, for yaml configuration, add ``pyscript:`` to ``/configuration.yaml``. - Pyscript has two optional configuration parameters that allow any python package to be - imported and exposes the ``hass`` variable as a global (both options default to ``false``): + Pyscript has three optional configuration parameters that allow any python package to be + imported, expose the ``hass`` variable as a global, and temporarily switch back to the + legacy decorator subsystem (all three options default to ``false``): .. code:: yaml pyscript: allow_all_imports: true hass_is_global: true + legacy_decorators: true + + Starting with version ``2.0.0``, pyscript uses the new decorator subsystem by default. + If you run into a problem in the new implementation, you can temporarily set + ``legacy_decorators: true`` to switch back to the legacy subsystem. If you do that, + please also file a bug report in the `GitHub issue tracker `__ + so the problem can be fixed. - Add files with a suffix of ``.py`` in the folder ``/pyscript``. - Restart HASS after installing pyscript. diff --git a/docs/reference.rst b/docs/reference.rst index 93a9ed56..8616cee5 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -20,8 +20,9 @@ You can't mix these two methods - your initial choice determines how you should these settings later. If you want to switch configuration methods you will need to uninstall and reinstall pyscript. -Pyscript has two optional configuration parameters that allow any Python package to be -imported and exposes the ``hass`` variable as a global (both options default to ``false``). +Pyscript has three optional configuration parameters that allow any Python package to be +imported, expose the ``hass`` variable as a global, and temporarily switch back to the +legacy decorator subsystem (all three options default to ``false``). Assuming you didn't use the UI to configure pyscript, these can be set in ``/configuration.yaml``: @@ -30,6 +31,13 @@ in ``/configuration.yaml``: pyscript: allow_all_imports: true hass_is_global: true + legacy_decorators: true + +Starting with version ``2.0.0``, pyscript uses the new decorator subsystem by default. +If you find a problem in the new implementation, you can temporarily set +``legacy_decorators: true`` to switch back to the legacy subsystem. If you do, +please also file a bug report in the `GitHub issue tracker `__ +so the problem can be fixed. It is recommended you put your pyscript configuration its own ``yaml`` file in the ``pyscript`` folder. That way changes to the file will be automatically detected and will trigger a reload, @@ -85,8 +93,9 @@ all the application configuration below the ``apps`` key. However, in a future r for more information. Note that ``pyscript.app_config`` is not defined in regular scripts, only in each application's main file. -Note that if you used the UI flow to configure pyscript, the ``allow_all_imports`` and -``hass_is_global`` configuration settings will be ignored in the yaml file. In that case, +Note that if you used the UI flow to configure pyscript, the ``allow_all_imports``, +``hass_is_global`` and ``legacy_decorators`` configuration settings will be ignored in +the yaml file. In that case, you should omit them from the yaml, and just use yaml for pyscript app configuration. At startup, pyscript loads the following files. It also automatically unloads and reloads diff --git a/tests/test_apps_modules.py b/tests/test_apps_modules.py index 5259002f..f02761a2 100644 --- a/tests/test_apps_modules.py +++ b/tests/test_apps_modules.py @@ -187,6 +187,9 @@ async def state_changed(event): hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed) + await hass.async_start() + await hass.async_block_till_done() + assert not hass.services.has_service("pyscript", "func10") assert not hass.services.has_service("pyscript", "func11") assert hass.services.has_service("pyscript", "func1") diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index 269bd87e..a59ed835 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -6,7 +6,12 @@ import pytest from custom_components.pyscript import PYSCRIPT_SCHEMA -from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONF_HASS_IS_GLOBAL, DOMAIN +from custom_components.pyscript.const import ( + CONF_ALLOW_ALL_IMPORTS, + CONF_HASS_IS_GLOBAL, + CONF_LEGACY_DECORATORS, + DOMAIN, +) from homeassistant import data_entry_flow from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER @@ -34,8 +39,10 @@ async def test_user_flow_minimum_fields(hass, pyscript_bypass_setup): assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert CONF_ALLOW_ALL_IMPORTS in result["data"] assert CONF_HASS_IS_GLOBAL in result["data"] + assert CONF_LEGACY_DECORATORS in result["data"] assert not result["data"][CONF_ALLOW_ALL_IMPORTS] assert not result["data"][CONF_HASS_IS_GLOBAL] + assert not result["data"][CONF_LEGACY_DECORATORS] @pytest.mark.asyncio @@ -48,13 +55,19 @@ async def test_user_flow_all_fields(hass, pyscript_bypass_setup): assert result["step_id"] == "user" result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True} + result["flow_id"], + user_input={ + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + }, ) assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert CONF_ALLOW_ALL_IMPORTS in result["data"] assert result["data"][CONF_ALLOW_ALL_IMPORTS] assert result["data"][CONF_HASS_IS_GLOBAL] + assert result["data"][CONF_LEGACY_DECORATORS] @pytest.mark.asyncio @@ -63,7 +76,11 @@ async def test_user_already_configured(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data={CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}, + data={ + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: False, + }, ) assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -71,7 +88,11 @@ async def test_user_already_configured(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data={CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}, + data={ + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: False, + }, ) assert result["type"] == data_entry_flow.FlowResultType.ABORT @@ -100,7 +121,11 @@ async def test_import_flow_update_allow_all_imports(hass, pyscript_bypass_setup) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, - data={CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}, + data={ + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + }, ) assert result["type"] == data_entry_flow.FlowResultType.ABORT @@ -162,7 +187,13 @@ async def test_import_flow_update_user(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -177,6 +208,7 @@ async def test_import_flow_update_user(hass, pyscript_bypass_setup): assert hass.config_entries.async_entries(DOMAIN)[0].data == { CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, "apps": {"test_app": {"param": 1}}, } @@ -187,7 +219,13 @@ async def test_import_flow_update_import(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -208,7 +246,13 @@ async def test_options_flow_import(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) await hass.async_block_till_done() assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -231,7 +275,13 @@ async def test_options_flow_user_change(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) await hass.async_block_till_done() assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -243,7 +293,12 @@ async def test_options_flow_user_change(hass, pyscript_bypass_setup): assert result["step_id"] == "init" result = await hass.config_entries.options.async_configure( - result["flow_id"], user_input={CONF_ALLOW_ALL_IMPORTS: False, CONF_HASS_IS_GLOBAL: False} + result["flow_id"], + user_input={ + CONF_ALLOW_ALL_IMPORTS: False, + CONF_HASS_IS_GLOBAL: False, + CONF_LEGACY_DECORATORS: False, + }, ) await hass.async_block_till_done() @@ -252,6 +307,7 @@ async def test_options_flow_user_change(hass, pyscript_bypass_setup): assert entry.data[CONF_ALLOW_ALL_IMPORTS] is False assert entry.data[CONF_HASS_IS_GLOBAL] is False + assert entry.data[CONF_LEGACY_DECORATORS] is False @pytest.mark.asyncio @@ -260,7 +316,13 @@ async def test_options_flow_user_no_change(hass, pyscript_bypass_setup): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) await hass.async_block_till_done() assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY @@ -272,7 +334,12 @@ async def test_options_flow_user_no_change(hass, pyscript_bypass_setup): assert result["step_id"] == "init" result = await hass.config_entries.options.async_configure( - result["flow_id"], user_input={CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True} + result["flow_id"], + user_input={ + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + }, ) assert result["type"] == data_entry_flow.FlowResultType.FORM @@ -294,7 +361,13 @@ async def test_config_entry_reload(hass): result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, - data=PYSCRIPT_SCHEMA({CONF_ALLOW_ALL_IMPORTS: True, CONF_HASS_IS_GLOBAL: True}), + data=PYSCRIPT_SCHEMA( + { + CONF_ALLOW_ALL_IMPORTS: True, + CONF_HASS_IS_GLOBAL: True, + CONF_LEGACY_DECORATORS: True, + } + ), ) await hass.async_block_till_done() assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index 84fcf173..7fe5c6c8 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -467,7 +467,4 @@ def func8(): pass """, ) - assert ( - "TypeError: function 'func8' defined in file.hello: {'bad'} aren't valid webhook_trigger methods" - in caplog.text - ) + assert "TypeError: function 'func8' defined in file.hello:" in caplog.text diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py new file mode 100644 index 00000000..483d3a1d --- /dev/null +++ b/tests/test_decorator_manager.py @@ -0,0 +1,750 @@ +"""Unit tests for decorator managers.""" + +from __future__ import annotations + +import logging +from typing import ClassVar +from unittest.mock import patch + +import pytest +import voluptuous as vol + +from custom_components.pyscript.const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN +from custom_components.pyscript.decorator import ( + DecoratorRegistry, + FunctionDecoratorManager, + WaitUntilDecoratorManager, +) +from custom_components.pyscript.decorator_abc import ( + CallHandlerDecorator, + CallResultHandlerDecorator, + Decorator, + DecoratorManager, + DecoratorManagerStatus, + DispatchData, +) +import custom_components.pyscript.decorators.base as decorators_base_module +from custom_components.pyscript.decorators.base import AutoKwargsDecorator, ExpressionDecorator +from custom_components.pyscript.function import Function +import custom_components.pyscript.global_ctx as global_ctx_module +from custom_components.pyscript.global_ctx import GlobalContext +from homeassistant.core import Context, HomeAssistant + +_MISSING = object() +_REGISTRY_ATTR = "_decorators" +_EXPRESSION_ATTR = "_ast_expression" +_CALL_METHOD = "_call" + + +@pytest.fixture(autouse=True) +def restore_manager_globals(): + """Restore shared class state touched by these unit tests.""" + old_decorators = get_registry_decorators(default=_MISSING) + old_hass = getattr(DecoratorManager, "hass", _MISSING) + + yield + + if old_decorators is _MISSING: + if hasattr(DecoratorRegistry, _REGISTRY_ATTR): + delattr(DecoratorRegistry, _REGISTRY_ATTR) + else: + set_registry_decorators(old_decorators) + + if old_hass is _MISSING: + if hasattr(DecoratorManager, "hass"): + delattr(DecoratorManager, "hass") + else: + DecoratorManager.hass = old_hass + + +class DummyAstCtx: + """Minimal AstEval stub for manager unit tests.""" + + def __init__(self, name: str = "file.hello.func") -> None: + """Initialize a dummy AST context.""" + self.name = name + self.global_ctx = object() + self.logged_exceptions: list[Exception] = [] + self._logger = logging.getLogger(__name__) + + def get_logger(self): + """Return test logger.""" + return self._logger + + def get_global_ctx_name(self) -> str: + """Return global context name.""" + return "file.hello" + + def log_exception(self, exc: Exception) -> None: + """Record exceptions passed to the AST context.""" + self.logged_exceptions.append(exc) + + +class DummyManager(DecoratorManager): + """Concrete manager used to unit-test the abstract base logic.""" + + def __init__(self, ast_ctx: DummyAstCtx, name: str = "file.hello.func") -> None: + """Initialize the dummy manager.""" + super().__init__(ast_ctx, name) + self.dispatched: list[DispatchData] = [] + + async def dispatch(self, data: DispatchData) -> None: + """Store dispatched payloads.""" + self.dispatched.append(data) + + +class RecordingDecorator(Decorator): + """Decorator that records lifecycle calls.""" + + name = "recording" + + label: str + events: list[tuple[str, str]] + validate_exc: Exception | None = None + start_exc: Exception | None = None + stop_exc: Exception | None = None + + async def validate(self) -> None: + """Record validation and optionally fail.""" + self.events.append(("validate", self.label)) + if self.validate_exc is not None: + raise self.validate_exc + + async def start(self) -> None: + """Record startup and optionally fail.""" + self.events.append(("start", self.label)) + if self.start_exc is not None: + raise self.start_exc + + async def stop(self) -> None: + """Record shutdown and optionally fail.""" + self.events.append(("stop", self.label)) + if self.stop_exc is not None: + raise self.stop_exc + + +class CancelCallHandler(CallHandlerDecorator): + """Call handler that cancels the action call.""" + + name = "cancel_call" + seen: list[dict] + + async def handle_call(self, data: DispatchData) -> bool: + """Cancel the action call.""" + self.seen.append(data.func_args.copy()) + return False + + +class RecordingResultHandler(CallResultHandlerDecorator): + """Result handler that stores received results.""" + + name = "record_result" + results: list[object] + + async def handle_call_result(self, data: DispatchData, result: object) -> None: + """Record action result.""" + self.results.append(result) + + +class AutoKwargsTestDecorator(AutoKwargsDecorator): + """Decorator used to test AutoKwargsDecorator behavior.""" + + name = "auto_kwargs_test" + kwargs_schema = vol.Schema( + { + vol.Optional("enabled"): bool, + vol.Optional("count"): int, + vol.Optional("ignored"): str, + } + ) + + enabled: bool | None + count: int | None + + +class ExpressionTestDecorator(ExpressionDecorator): + """Decorator used to test ExpressionDecorator behavior.""" + + name = "expression_test" + + +class FailingExpression: + """Expression stub that always raises during evaluation.""" + + async def eval(self, state_vars: dict[str, object]) -> bool: + """Raise an evaluation error.""" + raise RuntimeError(f"eval failed: {state_vars['value']}") + + +def make_recording_decorator( + label: str, + events: list[tuple[str, str]], + *, + validate_exc: Exception | None = None, + start_exc: Exception | None = None, + stop_exc: Exception | None = None, +) -> RecordingDecorator: + """Create a RecordingDecorator without overriding Decorator.__init__.""" + decorator = RecordingDecorator([], {}) + decorator.label = label + decorator.events = events + decorator.validate_exc = validate_exc + decorator.start_exc = start_exc + decorator.stop_exc = stop_exc + return decorator + + +def make_cancel_call_handler() -> CancelCallHandler: + """Create a canceling call handler.""" + handler = CancelCallHandler([], {}) + handler.seen = [] + return handler + + +def make_recording_result_handler() -> RecordingResultHandler: + """Create a recording result handler.""" + handler = RecordingResultHandler([], {}) + handler.results = [] + return handler + + +def get_registry_decorators(default: object | None = None) -> object | None: + """Return the decorator registry mapping.""" + return getattr(DecoratorRegistry, _REGISTRY_ATTR, default) + + +def set_registry_decorators(decorators: object) -> None: + """Replace the decorator registry mapping.""" + setattr(DecoratorRegistry, _REGISTRY_ATTR, decorators) + + +def set_decorator_ast_expression(decorator: ExpressionDecorator, expression: object) -> None: + """Set the internal AstEval expression for a test decorator.""" + setattr(decorator, _EXPRESSION_ATTR, expression) + + +async def call_function_manager(manager: FunctionDecoratorManager, data: DispatchData) -> None: + """Invoke the protected function-manager call path in tests.""" + await getattr(manager, _CALL_METHOD)(data) + + +class FakeAstEvalForExpression: + """AstEval stub that records create_expression inputs.""" + + instances: ClassVar[list["FakeAstEvalForExpression"]] = [] + + def __init__(self, name: str, global_ctx: object, local_name: str) -> None: + """Initialize the fake AstEval stub.""" + self.name = name + self.global_ctx = global_ctx + self.local_name = local_name + self.parse_calls: list[tuple[str, str]] = [] + self.__class__.instances.append(self) + + def parse(self, expression: str, mode: str) -> None: + """Record parse invocations.""" + self.parse_calls.append((expression, mode)) + + +class DummyEvalFunc: + """Minimal EvalFunc stub for FunctionDecoratorManager tests.""" + + def __init__(self, name: str = "func") -> None: + """Initialize the dummy eval function.""" + self.name = name + self.global_ctx_name = "file.hello" + self.logger = logging.getLogger(__name__) + + +class DummyEvalFuncVar: + """Minimal EvalFuncVar stub for FunctionDecoratorManager tests.""" + + def __init__(self, name: str = "func") -> None: + """Initialize the dummy eval function wrapper.""" + self.func = DummyEvalFunc(name) + + def get_name(self) -> str: + """Return function name.""" + return self.func.name + + +class DummyCallAstCtx: + """Minimal action AstEval stub for manager call tests.""" + + def __init__(self, result: object) -> None: + """Initialize the dummy action context.""" + self.result = result + self.calls: list[tuple[object, object, dict]] = [] + + async def call_func(self, func: object, func_name: object, **kwargs: object) -> object: + """Record the function call and return the configured result.""" + self.calls.append((func, func_name, kwargs)) + return self.result + + +class DummyConfigEntry: + """Minimal config entry stub for GlobalContext tests.""" + + def __init__(self, data: dict) -> None: + """Initialize the dummy config entry.""" + self.data = data + + +class DummyAsyncManager: + """Minimal async manager stub for GlobalContext start/stop tests.""" + + def __init__(self) -> None: + """Initialize the dummy async manager.""" + self.start_calls = 0 + self.stop_calls = 0 + + async def start(self) -> None: + """Record manager start.""" + self.start_calls += 1 + + async def stop(self) -> None: + """Record manager stop.""" + self.stop_calls += 1 + + +class FakeFunctionDecoratorManager: + """Patchable manager stub for GlobalContext.create_decorator_manager tests.""" + + instances: ClassVar[list["FakeFunctionDecoratorManager"]] = [] + status_after_validate: ClassVar[DecoratorManagerStatus] = DecoratorManagerStatus.VALIDATED + validate_exception: ClassVar[Exception | None] = None + + def __init__(self, ast_ctx: DummyAstCtx, func_var: DummyEvalFuncVar) -> None: + """Initialize the fake function decorator manager.""" + self.ast_ctx = ast_ctx + self.func_var = func_var + self.status = DecoratorManagerStatus.INIT + self.added = [] + self.validate_calls = 0 + self.start_calls = 0 + self.stop_calls = 0 + self.__class__.instances.append(self) + + async def validate(self) -> None: + """Record validation and apply the configured result.""" + self.validate_calls += 1 + if self.__class__.validate_exception is not None: + raise self.__class__.validate_exception + self.status = self.__class__.status_after_validate + + def add(self, decorator: Decorator) -> None: + """Record added decorators.""" + self.added.append(decorator) + + async def start(self) -> None: + """Record manager start.""" + self.start_calls += 1 + + async def stop(self) -> None: + """Record manager stop.""" + self.stop_calls += 1 + + +def make_dispatch_data( + func_args: dict[str, object], + *, + call_ast_ctx: DummyCallAstCtx | None = None, + hass_context: Context | None = None, +) -> DispatchData: + """Build DispatchData from test doubles.""" + return DispatchData(func_args, call_ast_ctx=call_ast_ctx, hass_context=hass_context) + + +def setup_global_context_function_hass(hass: HomeAssistant, config_data: dict | None = None) -> None: + """Configure Function.hass prerequisites needed by GlobalContext.""" + hass.data[DOMAIN] = {CONFIG_ENTRY: DummyConfigEntry(config_data or {})} + + +@pytest.mark.asyncio +async def test_decorator_manager_no_decorators_and_accessors(): + """Validate empty-manager lifecycle behavior.""" + dm = DummyManager(DummyAstCtx()) + await dm.validate() + assert dm.status is DecoratorManagerStatus.NO_DECORATORS + + decorators = dm.get_decorators() + decorators.append("sentinel") + assert dm.get_decorators() == [] + + dm.update_status(DecoratorManagerStatus.NO_DECORATORS) + assert dm.status is DecoratorManagerStatus.NO_DECORATORS + + with pytest.raises(RuntimeError, match="Starting not valid"): + await dm.start() + + +@pytest.mark.asyncio +async def test_decorator_manager_start_rolls_back_started_decorators(): + """A later start failure should stop already-started decorators.""" + events: list[tuple[str, str]] = [] + dm = DummyManager(DummyAstCtx()) + first = make_recording_decorator("first", events) + second = make_recording_decorator("second", events, start_exc=RuntimeError("start failed")) + dm.add(first) + dm.add(second) + + await dm.validate() + + with pytest.raises(RuntimeError, match="start failed"): + await dm.start() + + assert ("start", "first") in events + assert ("start", "second") in events + assert ("stop", "first") in events + assert ("stop", "second") not in events + assert dm.status is DecoratorManagerStatus.INVALID + assert dm.startup_time is None + assert dm.get_decorators() == [] + + +@pytest.mark.asyncio +async def test_auto_kwargs_decorator_validate_sets_only_annotated_attrs(): + """AutoKwargsDecorator should materialize only annotated kwargs.""" + dm = DummyManager(DummyAstCtx()) + decorator = AutoKwargsTestDecorator([], {"enabled": True, "ignored": "x"}) + dm.add(decorator) + await decorator.validate() + + assert decorator.enabled is True + assert decorator.count is None + assert not hasattr(decorator, "ignored") + + +@pytest.mark.asyncio +async def test_expression_decorator_requires_expression_before_eval(): + """ExpressionDecorator should raise if no expression was created.""" + dm = DummyManager(DummyAstCtx()) + decorator = ExpressionTestDecorator([], {}) + dm.add(decorator) + + with pytest.raises(AttributeError, match="has no expression defined"): + await decorator.check_expression_vars({}) + + +@pytest.mark.asyncio +async def test_expression_decorator_logs_eval_exceptions_via_manager(): + """ExpressionDecorator should route eval exceptions through the manager.""" + ast_ctx = DummyAstCtx() + dm = DummyManager(ast_ctx) + decorator = ExpressionTestDecorator([], {}) + dm.add(decorator) + set_decorator_ast_expression(decorator, FailingExpression()) + + assert await decorator.check_expression_vars({"value": 7}) is False + assert len(ast_ctx.logged_exceptions) == 1 + assert str(ast_ctx.logged_exceptions[0]) == "eval failed: 7" + + +def test_expression_decorator_create_expression_uses_manager_context(): + """create_expression() should build AstEval with the manager context.""" + FakeAstEvalForExpression.instances = [] + dm = DummyManager(DummyAstCtx()) + decorator = ExpressionTestDecorator([], {}) + dm.add(decorator) + + with ( + patch.object(decorators_base_module, "AstEval", FakeAstEvalForExpression), + patch.object(Function, "install_ast_funcs") as install_ast_funcs, + ): + decorator.create_expression("value > 1") + + assert decorator.has_expression() is True + assert len(FakeAstEvalForExpression.instances) == 1 + ast_eval = FakeAstEvalForExpression.instances[0] + assert ast_eval.name == "file.hello.func expression_test" + assert ast_eval.global_ctx is dm.ast_ctx.global_ctx + assert ast_eval.local_name == dm.name + assert ast_eval.parse_calls == [("value > 1", "eval")] + install_ast_funcs.assert_called_once_with(ast_eval) + + +def test_expression_decorator_create_expression_formats_function_manager_name(): + """create_expression() should use @name() form for function decorator managers.""" + FakeAstEvalForExpression.instances = [] + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + decorator = ExpressionTestDecorator([], {}) + manager.add(decorator) + + with ( + patch.object(decorators_base_module, "AstEval", FakeAstEvalForExpression), + patch.object(Function, "install_ast_funcs") as install_ast_funcs, + ): + decorator.create_expression("value > 1") + + assert len(FakeAstEvalForExpression.instances) == 1 + ast_eval = FakeAstEvalForExpression.instances[0] + assert ast_eval.name == "file.hello.func @expression_test()" + assert ast_eval.global_ctx is manager.ast_ctx.global_ctx + assert ast_eval.local_name == manager.name + assert ast_eval.parse_calls == [("value > 1", "eval")] + install_ast_funcs.assert_called_once_with(ast_eval) + + +@pytest.mark.asyncio +async def test_wait_until_rejects_unknown_arguments(hass): + """task.wait_until should reject kwargs that do not map to decorators.""" + DecoratorManager.hass = hass + set_registry_decorators({}) + + with pytest.raises(ValueError, match="Unknown arguments"): + await DecoratorRegistry.wait_until(DummyAstCtx(), unexpected=1) + + +@pytest.mark.asyncio +async def test_wait_until_ignores_dispatch_after_completion(hass): + """Repeated dispatches after completion should be ignored.""" + DecoratorManager.hass = hass + dm = WaitUntilDecoratorManager(DummyAstCtx()) + dm.update_status(DecoratorManagerStatus.RUNNING) + trigger = object() + + await dm.dispatch(DispatchData({"value": 1}, trigger=trigger)) + await dm.dispatch(DispatchData({"value": 2}, trigger=trigger)) + + assert await dm.wait_until() == {"value": 1} + assert dm.status is DecoratorManagerStatus.STOPPED + + +@pytest.mark.asyncio +async def test_wait_until_ignores_exception_after_completion(hass): + """Late exceptions should not override an already completed result.""" + DecoratorManager.hass = hass + dm = WaitUntilDecoratorManager(DummyAstCtx()) + dm.update_status(DecoratorManagerStatus.RUNNING) + trigger = object() + + await dm.dispatch(DispatchData({"value": 1}, trigger=trigger)) + await dm.handle_exception(RuntimeError("late")) + + assert await dm.wait_until() == {"value": 1} + + +@pytest.mark.asyncio +async def test_function_decorator_manager_cancel_calls_result_handlers(hass): + """Canceled calls should still notify result handlers with None.""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + call_handler = make_cancel_call_handler() + result_handler = make_recording_result_handler() + call_ast_ctx = DummyCallAstCtx(result="unused") + manager.add(call_handler) + manager.add(result_handler) + + await call_function_manager( + manager, + make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent")), + ) + + assert call_handler.seen == [{"arg1": 1}] + assert result_handler.results == [None] + assert not call_ast_ctx.calls + + +@pytest.mark.asyncio +async def test_function_decorator_manager_success_calls_result_handlers(hass): + """Successful calls should pass the function result to result handlers.""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + result_handler = make_recording_result_handler() + call_ast_ctx = DummyCallAstCtx(result="ok") + manager.add(result_handler) + hass_context = Context(id="call-parent") + fired_events = [] + + def event_listener(event): + fired_events.append(event) + + hass.bus.async_listen("pyscript_running", event_listener) + + with patch.object(Function, "store_hass_context") as store_hass_context: + await call_function_manager( + manager, make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=hass_context) + ) + await hass.async_block_till_done() + + assert call_ast_ctx.calls == [(manager.eval_func, None, {"arg1": 1})] + assert result_handler.results == ["ok"] + assert len(fired_events) == 1 + assert fired_events[0].data == { + "name": "file_hello_func", + "entity_id": "pyscript.file_hello_func", + "func_args": {"arg1": 1}, + } + store_hass_context.assert_called_once_with(hass_context) + + +def test_decorator_registry_register_requires_name(): + """Registry should reject decorators without a declared name.""" + + class NamelessDecorator(Decorator): + pass + + set_registry_decorators({}) + + with pytest.raises(TypeError, match="Decorator name is required"): + DecoratorRegistry.register(NamelessDecorator) + + +def test_decorator_registry_warns_on_override(caplog): + """Registering the same decorator name twice should warn.""" + + class FirstDecorator(Decorator): + name = "duplicate" + + class SecondDecorator(Decorator): + name = "duplicate" + + set_registry_decorators({}) + + DecoratorRegistry.register(FirstDecorator) + with caplog.at_level(logging.WARNING): + DecoratorRegistry.register(SecondDecorator) + + assert "Overriding decorator: duplicate" in caplog.text + assert get_registry_decorators()["duplicate"] is SecondDecorator + + +def test_global_context_initializes_hass_and_app_config(hass): + """GlobalContext should expose hass and copy app_config when configured.""" + setup_global_context_function_hass(hass, {CONF_HASS_IS_GLOBAL: True}) + app_config = {"name": "demo"} + + with patch.object(Function, "hass", hass): + global_ctx = GlobalContext("file.hello", app_config=app_config) + + assert global_ctx.global_sym_table["hass"] is hass + assert global_ctx.global_sym_table["pyscript.app_config"] == {"name": "demo"} + assert global_ctx.global_sym_table["pyscript.app_config"] is not app_config + + +@pytest.mark.asyncio +async def test_global_context_start_and_stop_schedule_decorator_managers(hass): + """start() and stop() should fan out to delayed decorator managers.""" + setup_global_context_function_hass(hass) + + with patch.object(Function, "hass", hass): + global_ctx = GlobalContext("file.hello") + manager = DummyAsyncManager() + + global_ctx.dms.add(manager) + global_ctx.dms_delay_start.add(manager) + + global_ctx.start() + await hass.async_block_till_done() + + assert manager.start_calls == 1 + assert global_ctx.dms_delay_start == set() + + global_ctx.stop() + await hass.async_block_till_done() + + assert manager.stop_calls == 1 + assert global_ctx.dms == set() + assert global_ctx.dms_delay_start == set() + assert global_ctx.auto_start is False + + +@pytest.mark.asyncio +async def test_global_context_create_decorator_manager_delays_or_autostarts(hass): + """Validated decorator managers should be delayed or started based on auto_start.""" + setup_global_context_function_hass(hass) + FakeFunctionDecoratorManager.instances = [] + FakeFunctionDecoratorManager.status_after_validate = DecoratorManagerStatus.VALIDATED + FakeFunctionDecoratorManager.validate_exception = None + delayed_ast_ctx = DummyAstCtx("file.hello.func_delayed") + immediate_ast_ctx = DummyAstCtx("file.hello.func_immediate") + func_var = DummyEvalFuncVar() + decorators = [make_recording_decorator("one", [])] + + with ( + patch.object(Function, "hass", hass), + patch.object(global_ctx_module, "FunctionDecoratorManager", FakeFunctionDecoratorManager), + ): + delayed_ctx = GlobalContext("file.hello") + await delayed_ctx.create_decorator_manager(decorators, delayed_ast_ctx, func_var) + + immediate_ctx = GlobalContext("file.hello2") + immediate_ctx.set_auto_start(True) + await immediate_ctx.create_decorator_manager(decorators, immediate_ast_ctx, func_var) + + assert len(FakeFunctionDecoratorManager.instances) == 2 + delayed_dm = FakeFunctionDecoratorManager.instances[0] + immediate_dm = FakeFunctionDecoratorManager.instances[1] + + assert delayed_dm.added == decorators + assert delayed_dm.validate_calls == 1 + assert delayed_dm.start_calls == 0 + assert delayed_dm in delayed_ctx.dms + assert delayed_dm in delayed_ctx.dms_delay_start + + assert immediate_dm.added == decorators + assert immediate_dm.validate_calls == 1 + assert immediate_dm.start_calls == 1 + assert immediate_dm in immediate_ctx.dms + assert immediate_dm not in immediate_ctx.dms_delay_start + + +@pytest.mark.asyncio +async def test_global_context_create_decorator_manager_ignores_non_validated_status(hass): + """Managers that do not validate successfully should not be registered.""" + setup_global_context_function_hass(hass) + FakeFunctionDecoratorManager.instances = [] + FakeFunctionDecoratorManager.status_after_validate = DecoratorManagerStatus.NO_DECORATORS + FakeFunctionDecoratorManager.validate_exception = None + ast_ctx = DummyAstCtx() + + with ( + patch.object(Function, "hass", hass), + patch.object(global_ctx_module, "FunctionDecoratorManager", FakeFunctionDecoratorManager), + ): + global_ctx = GlobalContext("file.hello") + await global_ctx.create_decorator_manager( + [make_recording_decorator("one", [])], ast_ctx, DummyEvalFuncVar() + ) + + assert FakeFunctionDecoratorManager.instances[0].validate_calls == 1 + assert global_ctx.dms == set() + assert global_ctx.dms_delay_start == set() + assert not ast_ctx.logged_exceptions + + +@pytest.mark.asyncio +async def test_global_context_create_decorator_manager_logs_validation_exception(hass): + """Validation exceptions should be logged on the AST context.""" + setup_global_context_function_hass(hass) + FakeFunctionDecoratorManager.instances = [] + FakeFunctionDecoratorManager.status_after_validate = DecoratorManagerStatus.VALIDATED + FakeFunctionDecoratorManager.validate_exception = RuntimeError("validation failed") + ast_ctx = DummyAstCtx() + + with ( + patch.object(Function, "hass", hass), + patch.object(global_ctx_module, "FunctionDecoratorManager", FakeFunctionDecoratorManager), + ): + global_ctx = GlobalContext("file.hello") + await global_ctx.create_decorator_manager( + [make_recording_decorator("one", [])], ast_ctx, DummyEvalFuncVar() + ) + + assert FakeFunctionDecoratorManager.instances[0].validate_calls == 1 + assert len(ast_ctx.logged_exceptions) == 1 + assert str(ast_ctx.logged_exceptions[0]) == "validation failed" + assert global_ctx.dms == set() + assert global_ctx.dms_delay_start == set() + + +def test_decorator_registry_init_legacy_mode_skips_new_registry(hass, caplog, monkeypatch): + """Legacy-mode env should disable the new decorator registry.""" + monkeypatch.setenv("NODM", "1") + + with patch.object(Function, "register_ast") as register_ast: + DecoratorRegistry.init(hass) + + assert "Using legacy decorators" in caplog.text + register_ast.assert_not_called() + assert not get_registry_decorators() diff --git a/tests/test_init.py b/tests/test_init.py index ad5ab630..8c46232e 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -15,7 +15,7 @@ from custom_components.pyscript.global_ctx import GlobalContextMgr from custom_components.pyscript.state import State from homeassistant import loader -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED +from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.core import Context from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.setup import async_setup_component @@ -61,6 +61,8 @@ async def state_changed(event): await notify_q.put(value) hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed) + await hass.async_start() + await hass.async_block_till_done() async def wait_until_done(notify_q): @@ -402,7 +404,6 @@ def func5(var_name=None, value=None): # # first time: fire event to startup triggers and run func_startup_sync # - hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) for i in range(6): if i & 1: seq_num = 10 diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index 7dad4de7..8638e181 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -6,6 +6,7 @@ import pytest from pytest_homeassistant_custom_component.common import MockConfigEntry +from custom_components.pyscript import DecoratorRegistry from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONFIG_ENTRY, DOMAIN from custom_components.pyscript.eval import AstEval, EvalExceptionFormatter from custom_components.pyscript.function import Function @@ -1674,6 +1675,7 @@ async def test_eval(hass): State.init(hass) State.register_functions() TrigTime.init(hass) + DecoratorRegistry.init(hass) for test_data in evalTests: await run_one_test(test_data)