diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 33af0d3..3cf50ba 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -250,7 +250,7 @@ def __init__(self, ast_ctx: AstEval, eval_func_var: EvalFuncVar) -> None: def on_func_var_deleted(): if self.status is DecoratorManagerStatus.RUNNING: - self.hass.async_create_task(self.stop()) + self.hass.async_create_task(self.safe_await(self.stop())) weakref.finalize(eval_func_var, on_func_var_deleted) @@ -261,9 +261,8 @@ async def _call(self, data: DispatchData) -> None: for handler_dec in handlers: if await handler_dec.handle_call(data) is False: self.logger.debug("Calling canceled by %s", handler_dec) - # notify handlers with "None" for result_handler_dec in result_handlers: - await result_handler_dec.handle_call_result(data, None) + await self.safe_await(result_handler_dec.handle_call_canceled(data)) return # Fire an event indicating that pyscript is running # Note: the event must have an entity_id for logbook to work correctly. @@ -277,10 +276,14 @@ async def _call(self, data: DispatchData) -> None: try: result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) - for result_handler_dec in result_handlers: - await result_handler_dec.handle_call_result(data, result) except Exception as e: + for result_handler_dec in result_handlers: + await self.safe_await(result_handler_dec.handle_call_exception(data, e)) await self.handle_exception(e) + return + + for result_handler_dec in result_handlers: + await self.safe_await(result_handler_dec.handle_call_result(data, result)) async def dispatch(self, data: DispatchData) -> None: """Handle a trigger dispatch: run guards, create a context, and invoke the function.""" @@ -290,6 +293,8 @@ async def dispatch(self, data: DispatchData) -> None: for dec in decorators: if await dec.handle_dispatch(data) is False: self.logger.debug("Trigger not active due to %s", dec) + for result_handler_dec in self.get_decorators(CallResultHandlerDecorator): + await self.safe_await(result_handler_dec.handle_call_canceled(data)) return action_ast_ctx = AstEval( diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py index 4775317..c762dba 100644 --- a/custom_components/pyscript/decorator_abc.py +++ b/custom_components/pyscript/decorator_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Awaitable from dataclasses import dataclass, field from enum import StrEnum import logging @@ -179,8 +180,7 @@ async def start(self): try: await decorator.start() started.append(decorator) - except Exception as err: - self.logger.exception("%s start failed: %s", self, err) + except Exception: for started_dec in started: await self._stop_decorator(started_dec) self.startup_time = None @@ -209,6 +209,19 @@ async def handle_exception(self, exc: Exception) -> None: """Handle a decorator exception.""" self.ast_ctx.log_exception(exc) + async def safe_await(self, coro: Awaitable[Any]) -> None: + """ + Await a coroutine, routing (but not propagating) bugs through ``handle_exception``. + + Intended for extension points where a defective subclass shouldn't break + sibling work: the exception surfaces in the same place as user-code errors, + and the caller carries on. + """ + try: + await coro + except Exception as err: + await self.handle_exception(err) + @abstractmethod async def dispatch(self, data: DispatchData) -> None: """Dispatch a trigger call.""" @@ -281,3 +294,11 @@ class CallResultHandlerDecorator(Decorator, ABC): @abstractmethod async def handle_call_result(self, data: DispatchData, result: Any) -> None: """Handle an action call result.""" + + async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None: + """Handle an exception raised by the action call. Default: forward as None result.""" + await self.handle_call_result(data, None) + + async def handle_call_canceled(self, data: DispatchData) -> None: + """Handle a canceled action call (skipped by a handler or trigger). Default: forward as None result.""" + await self.handle_call_result(data, None) diff --git a/custom_components/pyscript/decorators/__init__.py b/custom_components/pyscript/decorators/__init__.py index f21b9a9..c5f0ad3 100644 --- a/custom_components/pyscript/decorators/__init__.py +++ b/custom_components/pyscript/decorators/__init__.py @@ -6,7 +6,7 @@ from .state import StateActiveDecorator, StateTriggerDecorator from .task import TaskUniqueDecorator from .timing import TimeActiveDecorator, TimeTriggerDecorator -from .webhook import WebhookTriggerDecorator +from .webhook import WebhookHandlerDecorator, WebhookTriggerDecorator DECORATORS = [ StateTriggerDecorator, @@ -17,5 +17,6 @@ EventTriggerDecorator, MQTTTriggerDecorator, WebhookTriggerDecorator, + WebhookHandlerDecorator, ServiceDecorator, ] diff --git a/custom_components/pyscript/decorators/base.py b/custom_components/pyscript/decorators/base.py index 213a67e..3babc24 100644 --- a/custom_components/pyscript/decorators/base.py +++ b/custom_components/pyscript/decorators/base.py @@ -2,7 +2,7 @@ from abc import ABC import logging -from typing import Any +from typing import Any, ClassVar import voluptuous as vol @@ -16,13 +16,25 @@ class AutoKwargsDecorator(Decorator, ABC): """Mixin that copies validated kwargs into instance attributes based on annotations.""" + _auto_kw_attrs: ClassVar[frozenset[str]] = frozenset() + + def __init_subclass__(cls, **kwargs): + """Collect names of typed attributes declared on subclasses up to (but not including) Decorator.""" + super().__init_subclass__(**kwargs) + attrs: set[str] = set() + for klass in cls.mro(): + if klass is Decorator: + break + attrs.update(getattr(klass, "__annotations__", {})) + cls._auto_kw_attrs = frozenset(attrs) + async def validate(self) -> None: """Run base validation and materialize annotated kwargs as attributes.""" await super().validate() for k in self.__class__.kwargs_schema.schema: if isinstance(k, vol.Marker): k = k.schema - if k in self.__class__.__annotations__: + if k in self._auto_kw_attrs: setattr(self, k, self.kwargs.get(k, None)) diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index 3db0a09..97b4e86 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -2,26 +2,34 @@ from __future__ import annotations +from abc import ABC +import asyncio +from collections.abc import Awaitable, Callable +from http import HTTPStatus import logging -from typing import ClassVar +from typing import Any, ClassVar from aiohttp import hdrs +from aiohttp.web import Request, Response import voluptuous as vol from homeassistant.components import webhook from homeassistant.components.webhook import SUPPORTED_METHODS +from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.json import json_dumps -from ..decorator_abc import DispatchData, TriggerDecorator +from ..decorator_abc import CallResultHandlerDecorator, DispatchData, TriggerDecorator from .base import AutoKwargsDecorator, ExpressionDecorator _LOGGER = logging.getLogger(__name__) +_WEBHOOK_RESULT_FUTURE = "webhook_result_future" -class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): - """Implementation for @webhook_trigger.""" - name = "webhook_trigger" +class _WebhookDecoratorBase(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator, ABC): + """Base class for @webhook_trigger and @webhook_handler.""" + args_schema = vol.Schema( vol.All( [vol.Coerce(str)], @@ -31,7 +39,12 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD kwargs_schema = vol.Schema( { vol.Optional("local_only", default=True): cv.boolean, - vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + vol.Optional("methods", default={hdrs.METH_POST, hdrs.METH_PUT}): vol.All( + vol.Coerce(list), + [vol.In(SUPPORTED_METHODS)], + vol.Length(min=1, msg="needs at least one HTTP method"), + vol.Coerce(set), + ), } ) @@ -39,10 +52,8 @@ class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsD local_only: bool methods: set[str] - webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} - async def validate(self): - """Validate the webhook trigger configuration.""" + """Validate the webhook configuration.""" await super().validate() self.webhook_id = self.args[0] @@ -50,19 +61,46 @@ async def validate(self): self.create_expression(self.args[1]) @staticmethod - async def _handler(_hass, webhook_id, request): - func_args = { + async def _build_func_args(webhook_id: str, request: Request) -> dict[str, Any]: + """Build the standard webhook function kwargs from an incoming request.""" + if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): + payload = await request.json() + else: + payload_multidict = await request.post() + payload = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + return { "trigger_type": "webhook", "webhook_id": webhook_id, "request": request, + "payload": payload, } - if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): - func_args["payload"] = await request.json() - else: - # Could potentially return multiples of a key - only take the first - payload_multidict = await request.post() - func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + def _register_webhook( + self, + handler: Callable[[HomeAssistant, str, Request], Awaitable[Response | None]], + ) -> None: + """Register self.webhook_id with Home Assistant, dispatching to ``handler``.""" + webhook.async_register( + self.dm.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + self.webhook_id, + handler, + local_only=self.local_only, + allowed_methods=self.methods, + ) + + +class WebhookTriggerDecorator(_WebhookDecoratorBase): + """Implementation for @webhook_trigger.""" + + name = "webhook_trigger" + + webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} + + @staticmethod + async def _handler(_hass: HomeAssistant, webhook_id: str, request: Request) -> None: + func_args = await WebhookTriggerDecorator._build_func_args(webhook_id, request) for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy(): trigger_args = func_args.copy() @@ -71,22 +109,21 @@ async def _handler(_hass, webhook_id, request): continue await trigger.dispatch(DispatchData(trigger_args)) - @staticmethod - def _add_trigger(trigger: WebhookTriggerDecorator) -> None: - webhook_id = trigger.webhook_id - if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers: - webhook.async_register( - trigger.dm.hass, - "pyscript", # DOMAIN - "pyscript", # NAME - webhook_id, - WebhookTriggerDecorator._handler, - local_only=trigger.local_only, - allowed_methods=trigger.methods, - ) - WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set() + def _add_trigger(self) -> None: + triggers = WebhookTriggerDecorator.webhook_id2triggers.get(self.webhook_id) + if triggers is None: + self._register_webhook(WebhookTriggerDecorator._handler) + WebhookTriggerDecorator.webhook_id2triggers[self.webhook_id] = {self} + return - WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger) + existing = next(iter(triggers)) + if existing.local_only != self.local_only or existing.methods != self.methods: + raise ValueError( + f"'{self.dm.func_name}' @webhook_trigger for '{self.webhook_id}' conflicts with existing " + f"'{existing.dm.ast_ctx.get_global_ctx_name()}.{existing.dm.func_name}' " + f"(local_only={existing.local_only}, methods={existing.methods})" + ) + triggers.add(self) @staticmethod def _remove_trigger(trigger: WebhookTriggerDecorator) -> None: @@ -103,7 +140,7 @@ def _remove_trigger(trigger: WebhookTriggerDecorator) -> None: async def start(self): """Start the webhook trigger.""" await super().start() - self._add_trigger(self) + self._add_trigger() _LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id) @@ -111,3 +148,113 @@ async def stop(self): """Stop the webhook trigger.""" await super().stop() self._remove_trigger(self) + + +class WebhookHandlerDecorator(_WebhookDecoratorBase, CallResultHandlerDecorator): + """ + Implementation for @webhook_handler. + + Like @webhook_trigger, but the function's return value becomes the HTTP + response. Only one handler can be registered per webhook_id. + """ + + name = "webhook_handler" + kwargs_schema = _WebhookDecoratorBase.kwargs_schema.extend( + {vol.Optional("timeout", default=10.0): vol.All(vol.Coerce(float), vol.Range(min=0))} + ) + + timeout: float + + async def _to_response(self, result: Any) -> Response: + """Convert a user-returned value into an aiohttp Response.""" + if result is None: + return Response(status=HTTPStatus.OK) + if isinstance(result, Response): + return result + if isinstance(result, str): + return Response(text=result) + if isinstance(result, bytes): + return Response(body=result) + if isinstance(result, tuple) and len(result) == 2 and isinstance(result[0], int): + status, body = result + response = await self._to_response(body) + response.set_status(status) + return response + if isinstance(result, (dict, list)): + try: + body = json_dumps(result) + except (TypeError, ValueError) as exc: + await self.dm.handle_exception(exc) + return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR) + return Response(text=body, content_type="application/json") + _LOGGER.warning("@webhook_handler returned unsupported type %s", type(result).__name__) + return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR) + + async def _handler(self, _hass: HomeAssistant, webhook_id: str, request: Request) -> Response: + try: + func_args = await self._build_func_args(webhook_id, request) + except ValueError: + # Body could not be parsed (e.g. malformed JSON). Tell the caller their + # request was bad rather than silently returning 200 like webhook_trigger. + _LOGGER.warning("webhook %s received an unparsable request body", webhook_id) + return Response(status=HTTPStatus.BAD_REQUEST) + + if self.has_expression() and not await self.check_expression_vars(func_args): + return Response(status=HTTPStatus.FORBIDDEN) + + future = self.dm.hass.loop.create_future() + data = DispatchData(func_args, trigger_context={_WEBHOOK_RESULT_FUTURE: future}) + await self.dispatch(data) + + try: + result = await asyncio.wait_for(future, timeout=self.timeout) + except TimeoutError: + _LOGGER.warning( + "webhook_handler %s on %s timed out after %ss", + self.dm.name, + webhook_id, + self.timeout, + ) + return Response(status=HTTPStatus.GATEWAY_TIMEOUT) + + try: + return await self._to_response(result) + except Exception as exc: + await self.dm.handle_exception(exc) + return Response(status=HTTPStatus.INTERNAL_SERVER_ERROR) + + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Forward the function result to the awaiting webhook request.""" + if data.trigger is not self: + return + self._resolve_future(data, result) + + async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None: + """Return a 500 response when the user function raised.""" + if data.trigger is not self: + return + self._resolve_future(data, Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)) + + async def handle_call_canceled(self, data: DispatchData) -> None: + """Return a 503 response when the call was canceled by a guard (e.g. @task_unique, @state_active).""" + if data.trigger is not self: + return + self._resolve_future(data, Response(status=HTTPStatus.SERVICE_UNAVAILABLE)) + + @staticmethod + def _resolve_future(data: DispatchData, result: Any) -> None: + future = data.trigger_context.get(_WEBHOOK_RESULT_FUTURE) + if future is not None and not future.done(): + future.set_result(result) + + async def start(self): + """Start the webhook handler.""" + await super().start() + self._register_webhook(self._handler) + + _LOGGER.debug("webhook handler %s listening on id %s", self.dm.name, self.webhook_id) + + async def stop(self): + """Stop the webhook handler.""" + await super().stop() + webhook.async_unregister(self.dm.hass, self.webhook_id) diff --git a/custom_components/pyscript/global_ctx.py b/custom_components/pyscript/global_ctx.py index 6470903..0eb5c47 100644 --- a/custom_components/pyscript/global_ctx.py +++ b/custom_components/pyscript/global_ctx.py @@ -79,17 +79,13 @@ async def create_decorator_manager( for dec in decs: dm.add(dec) - try: - await dm.validate() - if dm.status is DecoratorManagerStatus.VALIDATED: - self.dms.add(dm) - - if self.auto_start: - await dm.start() - else: - self.dms_delay_start.add(dm) - except Exception as exc: - ast_ctx.log_exception(exc) + await dm.safe_await(dm.validate()) + if dm.status is DecoratorManagerStatus.VALIDATED: + self.dms.add(dm) + if self.auto_start: + await dm.safe_await(dm.start()) + else: + self.dms_delay_start.add(dm) def trigger_unregister(self, func: EvalFunc) -> None: """Unregister a trigger function.""" @@ -107,7 +103,7 @@ def start(self) -> None: self.triggers_delay_start = set() for dm in self.dms_delay_start: - Function.hass.async_create_task(dm.start()) + Function.hass.async_create_task(dm.safe_await(dm.start())) self.dms_delay_start = set() def stop(self) -> None: @@ -117,7 +113,7 @@ def stop(self) -> None: self.triggers = set() self.triggers_delay_start = set() for dm in self.dms: - Function.hass.async_create_task(dm.stop()) + Function.hass.async_create_task(dm.safe_await(dm.stop())) self.dms = set() self.dms_delay_start = set() self.set_auto_start(False) diff --git a/custom_components/pyscript/stubs/pyscript_builtins.py b/custom_components/pyscript/stubs/pyscript_builtins.py index ea75580..deb241e 100644 --- a/custom_components/pyscript/stubs/pyscript_builtins.py +++ b/custom_components/pyscript/stubs/pyscript_builtins.py @@ -143,6 +143,46 @@ def webhook_trigger( ... +def webhook_handler( + webhook_id: str, + str_expr: str | None = None, + local_only: bool = True, + methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"}, + timeout: int | float = 10.0, + kwargs: dict | None = None, +) -> Callable[..., Any]: + """Handle a webhook request and return the HTTP response from the function. + + Like ``@webhook_trigger`` but the decorated function's return value becomes the HTTP + response. Only one handler can be registered per ``webhook_id``. + + Args: + webhook_id: Webhook id to listen to; must not be registered by another handler or trigger. + str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``. + local_only: If False, allow requests from anywhere on the internet. + methods: HTTP methods to allow. + timeout: Seconds to wait for the function before returning ``504 Gateway Timeout``. + kwargs: Extra keyword arguments merged into each invocation. + + Trigger kwargs are identical to ``@webhook_trigger``. + + Return value mapping (function -> HTTP response): + - ``None`` or no return -> ``200 OK`` + - ``str`` -> ``200`` with text body + - ``bytes`` -> ``200`` with raw body + - ``dict`` / ``list`` -> ``200`` with JSON body + - ``(status, body)`` tuple -> ``status`` from tuple, ``body`` mapped recursively + - ``aiohttp.web.Response`` -> returned as-is + - any other type -> ``500`` with a warning (use a tuple or a Response instead) + + A malformed request body yields ``400 Bad Request``, a falsy ``str_expr`` guard + yields ``403 Forbidden``, an uncaught exception in the function yields + ``500 Internal Server Error``, and a call canceled by another decorator's guard + (e.g. ``@task_unique`` or ``@state_active``) yields ``503 Service Unavailable``. + """ + ... + + def pyscript_compile() -> Callable[..., Any]: """Compile the wrapped function into native (synchronous) Python. diff --git a/docs/reference.rst b/docs/reference.rst index 3b7c587..1c384a7 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -917,6 +917,66 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error. +@webhook_handler +^^^^^^^^^^^^^^^^ + +.. code:: python + + @webhook_handler(webhook_id, str_expr=None, local_only=True, methods={"POST", "PUT"}, timeout=10, kwargs=None) + +``@webhook_handler`` is like ``@webhook_trigger`` but the decorated function's **return value** becomes the HTTP response. Use it when the caller of the webhook needs more than the default ``200 OK`` — for example to return a status code, a JSON body, or a custom response with headers. + +Only one handler can be registered per ``webhook_id``. Attempting to register a second ``@webhook_handler`` (or a ``@webhook_trigger``) for the same id raises an error at script load time. + +Arguments ``webhook_id``, ``str_expr``, ``local_only``, ``methods``, and ``kwargs`` behave exactly like ``@webhook_trigger``. The decorated function receives the same kwargs (``trigger_type``, ``webhook_id``, ``payload``, ``request``). The additional ``timeout`` argument (seconds, default ``10``) caps how long the webhook will wait for the function to return before responding with ``504 Gateway Timeout``. + +The function's return value is mapped to an HTTP response as follows: + +- ``None`` (or no explicit return) — ``200 OK``, empty body. +- ``str`` — ``200 OK`` with the string as the text body. +- ``bytes`` — ``200 OK`` with the raw bytes as the body. +- ``dict`` or ``list`` — ``200 OK`` with the value serialized as JSON (``Content-Type: application/json``). If the value isn't JSON-serializable, the handler returns ``500`` and logs a warning. +- ``(status, body)`` tuple — ``status`` becomes the HTTP status, ``body`` is mapped recursively using the same rules (so e.g. ``(201, {"id": 7})`` returns ``201`` with a JSON body, and ``(204, None)`` returns ``204`` with an empty body). +- ``aiohttp.web.Response`` — returned as-is, so you can fully customize status, headers, and body. +- Any other type (a bare ``int``, custom object, …) is unsupported: the handler returns ``500`` and logs a warning. Use a ``(status, body)`` tuple or an ``aiohttp.web.Response`` instead. + +Error responses generated by the handler itself: + +- A malformed request body (e.g. invalid JSON for a ``Content-Type: application/json`` request) returns ``400 Bad Request`` without invoking the decorated function. +- A ``str_expr`` guard that evaluates to falsy returns ``403 Forbidden`` without invoking the decorated function. +- An uncaught exception inside the decorated function returns ``500 Internal Server Error``. The exception is still logged through pyscript's normal error reporting. +- A call canceled by another decorator's guard (for example a ``@task_unique`` collision or a falsy ``@state_active`` expression) returns ``503 Service Unavailable``. +- A function that doesn't finish within ``timeout`` seconds returns ``504 Gateway Timeout``. The pending function task continues running in the background. + +A simple JSON-echo handler: + +.. code:: python + + @webhook_handler("echo") + def webhook_echo(payload, request): + return {"received": payload, "method": request.method} + +A handler that returns a status code together with a JSON body: + +.. code:: python + + @webhook_handler("create_thing") + def webhook_create(payload): + if "name" not in payload: + return 400, {"error": "name is required"} + thing_id = make_thing(payload["name"]) + return 201, {"id": thing_id, "status": "created"} + +For full control over the response, return an ``aiohttp.web.Response``: + +.. code:: python + + from aiohttp.web import Response + + @webhook_handler("redirect") + def webhook_redirect(): + return Response(status=302, headers={"Location": "https://example.com/"}) + @state_active ^^^^^^^^^^^^^ diff --git a/tests/conftest.py b/tests/conftest.py index 8a3a4ff..bf56114 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,227 @@ """Test: Global configuration for pytest.""" +from ast import literal_eval +import asyncio +from collections.abc import Generator +from datetime import datetime +import re +from typing import Any +from unittest.mock import patch + +from mock_open import MockOpen import pytest +from custom_components.pyscript import trigger +from custom_components.pyscript.const import DOMAIN, FOLDER +from custom_components.pyscript.eval import AstEval +from custom_components.pyscript.function import Function +from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component + @pytest.fixture(autouse=True) def auto_enable_custom_integrations(enable_custom_integrations): """Enable custom integrations in all tests.""" yield + + +_UNSET = object() + + +async def _wait(queue: asyncio.Queue, expected: Any, timeout: float) -> Any: + raw = await asyncio.wait_for(queue.get(), timeout=timeout) + if expected is _UNSET: + return raw + actual = raw if isinstance(expected, str) else literal_eval(raw) + assert actual == expected + return actual + + +class PyscriptFixture: + """ + Configurable pyscript bootstrap for tests. + + Attributes: + files: Full path -> source content. Multiple entries enable + pattern-based glob mocking for apps/modules/scripts layouts. + Use `add_file` / `add_files` to populate. + config: Config passed to `async_setup_component` (defaults to + `{DOMAIN: {}}`). + yaml_config: Value returned by mocked `load_yaml_config_file` + (defaults to `{}`). + now: A single datetime, or a list of datetimes returned successively + as pyscript reads `dt_now`. + done, done2: Queues fed by `pyscript.done` / `pyscript.done2` state + writes from scripts. Consume via `wait_done` / `wait_done2`. + exceptions: Queue of exceptions captured from `AstEval.log_exception` + (i.e. any exception pyscript would otherwise only surface via the + log). Consume via `wait_exception`. + + """ + + DEFAULT_NOW = datetime(2020, 7, 1, 11, 59, 59, 999999) + + def __init__(self, hass: HomeAssistant, monkeypatch: pytest.MonkeyPatch | None = None) -> None: + """Initialize with defaults; mutate attributes before calling `start()`.""" + self.hass = hass + self.monkeypatch = monkeypatch + self.files: dict[str, str] = {} + self.config: dict[str, Any] | None = None + self.yaml_config: dict[str, Any] | None = None + self.now: datetime | list[datetime] = self.DEFAULT_NOW + self.done: asyncio.Queue = asyncio.Queue() + self.done2: asyncio.Queue = asyncio.Queue() + self.exceptions: asyncio.Queue = asyncio.Queue() + + def add_file(self, name: str, content: str) -> None: + """Register a single script file under the pyscript folder (name is relative to FOLDER).""" + self.files[f"{self.hass.config.path(FOLDER)}/{name}"] = content + + def add_files(self, files: dict[str, str]) -> None: + """Register multiple script files under the pyscript folder (keys are relative to FOLDER).""" + for name, content in files.items(): + self.add_file(name, content) + + async def start(self, source: str | None = None) -> None: + """ + Load pyscript using the current attributes. + + If `source` is given, it's added as `hello.py` before loading — a + shortcut for the common single-file test pattern. + """ + if source is not None: + self.add_file("hello.py", source) + Function.hass = None + + if self.monkeypatch is not None: + original_log_exception = AstEval.log_exception + exceptions_queue = self.exceptions + + def capturing_log_exception(ast_self, exc): + exceptions_queue.put_nowait(exc) + original_log_exception(ast_self, exc) + + self.monkeypatch.setattr(AstEval, "log_exception", capturing_log_exception) + + config = self.config if self.config is not None else {DOMAIN: {}} + yaml_config = self.yaml_config if self.yaml_config is not None else {} + files_map = self.files + now = self.now + first_value = now[0] if isinstance(now, list) else now + + mock_open = MockOpen() + for path, content in files_map.items(): + mock_open[path].read_data = content + + def isfile_side_effect(arg): + return arg in files_map + + def glob_side_effect(path, recursive=None, root_dir=None, dir_fd=None, include_hidden=False): + result = [] + path_re = path.replace("*", "[^/]*").replace(".", "\\.") + path_re = path_re.replace("[^/]*[^/]*/", ".*") + for this_path in files_map: + if re.match(path_re, this_path): + result.append(this_path) + return result + + with ( + patch("custom_components.pyscript.os.path.isdir", return_value=True), + patch("custom_components.pyscript.os_path_isdir", return_value=True), + patch("custom_components.pyscript.glob.iglob") as mock_glob, + patch("custom_components.pyscript.global_ctx.open", mock_open), + patch("custom_components.pyscript.open", mock_open), + patch("custom_components.pyscript.trigger.dt_now", return_value=first_value), + patch("homeassistant.config.load_yaml_config_file", return_value=yaml_config), + patch("custom_components.pyscript.watchdog_start", return_value=None), + patch("custom_components.pyscript.os.path.getmtime", return_value=1000), + patch("custom_components.pyscript.global_ctx.os.path.getmtime", return_value=1000), + patch("custom_components.pyscript.install_requirements", return_value=None), + patch("custom_components.pyscript.global_ctx.os_path_isfile") as mock_isfile, + ): + mock_isfile.side_effect = isfile_side_effect + mock_glob.side_effect = glob_side_effect + assert await async_setup_component(self.hass, "pyscript", config) + + def return_next_time(): + if isinstance(now, list): + return now.pop(0) if len(now) > 1 else now[0] + return now + + trigger.__dict__["dt_now"] = return_next_time + + async def state_changed(event): + entity_id = event.data["entity_id"] + if entity_id == "pyscript.done": + await self.done.put(event.data["new_state"].state) + elif entity_id == "pyscript.done2": + await self.done2.put(event.data["new_state"].state) + + self.hass.bus.async_listen(EVENT_STATE_CHANGED, state_changed) + + self.hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await self.hass.async_block_till_done() + + async def wait_done(self, expected: Any = _UNSET, *, timeout: float = 4) -> Any: + """ + Await the next pyscript.done value. + + If `expected` is given, assert the received value matches and return it. + Non-string expected are compared after `ast.literal_eval` (pyscript + stores list/dict/number values as string representations in HA state). + """ + return await _wait(self.done, expected, timeout) + + async def wait_done2(self, expected: Any = _UNSET, *, timeout: float = 4) -> Any: + """Await the next pyscript.done2 value (see wait_done).""" + return await _wait(self.done2, expected, timeout) + + async def wait_exception( + self, + expected_type: type[BaseException] | None = None, + *, + match: str | None = None, + timeout: float = 4, + ) -> BaseException: + """ + Await the next exception captured from `AstEval.log_exception`. + + If `expected_type` is given, assert the captured exception is an + instance of it. If `match` is given, assert it appears in `str(exc)`. + """ + exc = await asyncio.wait_for(self.exceptions.get(), timeout=timeout) + if expected_type is not None: + assert isinstance(exc, expected_type), ( + f"expected {expected_type.__name__}, got {type(exc).__name__}: {exc}" + ) + if match is not None: + assert match in str(exc), f"expected {match!r} in {str(exc)!r}" + return exc + + +def _drain(queue: asyncio.Queue) -> list[Any]: + leftover = [] + while not queue.empty(): + leftover.append(queue.get_nowait()) + return leftover + + +@pytest.fixture +def pyscript(hass: HomeAssistant, monkeypatch: pytest.MonkeyPatch) -> Generator[PyscriptFixture]: + """ + Per-test pyscript fixture: configure attributes, then `await pyscript.start()`. + + On teardown, asserts that no ``pyscript.done`` / ``pyscript.done2`` value and + no exception captured from ``AstEval.log_exception`` was left unconsumed. + Tests that intentionally produce values or exceptions must drain them via + ``wait_done`` / ``wait_done2`` / ``wait_exception``. + """ + fixture = PyscriptFixture(hass, monkeypatch) + yield fixture + leftover_done = _drain(fixture.done) + leftover_done2 = _drain(fixture.done2) + leftover_exc = _drain(fixture.exceptions) + assert not leftover_done, f"unconsumed pyscript.done values: {leftover_done}" + assert not leftover_done2, f"unconsumed pyscript.done2 values: {leftover_done2}" + assert not leftover_exc, f"unconsumed pyscript exceptions: {leftover_exc}" diff --git a/tests/decorators/test_webhook.py b/tests/decorators/test_webhook.py new file mode 100644 index 0000000..cab7f79 --- /dev/null +++ b/tests/decorators/test_webhook.py @@ -0,0 +1,353 @@ +"""Test pyscript @webhook_trigger and @webhook_handler decorators.""" + +import asyncio +import json + +import pytest + +from homeassistant.components import webhook +from homeassistant.util.aiohttp import MockRequest + + +def _request( + *, + body: bytes = b"", + method: str = "POST", + headers: dict[str, str] | None = None, +) -> MockRequest: + return MockRequest( + content=body, + mock_source="test", + method=method, + headers=headers or {}, + remote="127.0.0.1", + ) + + +@pytest.mark.asyncio +async def test_webhook_request_kwarg(pyscript, hass): + """The aiohttp request is passed to the user function as the `request` kwarg.""" + await pyscript.start(""" +@webhook_trigger("test_req_hook") +def webhook_test(payload, request): + pyscript.done = [request.headers["X-My-Sig"], request.method, payload] +""") + + request = _request( + body=b'{"hello": "world"}', + headers={"Content-Type": "application/json", "X-My-Sig": "abc123"}, + ) + await webhook.async_handle_webhook(hass, "test_req_hook", request) + + await pyscript.wait_done(["abc123", "POST", {"hello": "world"}]) + + +@pytest.mark.asyncio +async def test_webhook_methods_order(pyscript): + """Same webhook_id with methods listed in a different order is not a conflict.""" + await pyscript.start(""" +@webhook_trigger("order_hook", methods=["GET", "POST"]) +def func_order_a(): + pass + +@webhook_trigger("order_hook", methods=["POST", "GET"]) +def func_order_b(): + pass +""") + assert pyscript.exceptions.empty() + + +@pytest.mark.asyncio +async def test_webhooks_method(pyscript): + """Test invalid keyword arguments type generates an error.""" + await pyscript.start(""" +@webhook_trigger("hook", methods=["bad"]) +def func8(): + pass +""") + await pyscript.wait_exception(TypeError, match="func8' defined in file.hello") + + +@pytest.mark.asyncio +async def test_webhook_local_only_conflict(pyscript): + """Test @webhook_trigger with same webhook_id but conflicting local_only raises.""" + await pyscript.start(""" +@webhook_trigger("conflict_local", local_only=True) +def func_local_a(): + pass + +@webhook_trigger("conflict_local", local_only=False) +def func_local_b(): + pass +""") + await pyscript.wait_exception(ValueError, match="'conflict_local' conflicts with existing") + + +@pytest.mark.asyncio +async def test_webhook_methods_conflict(pyscript): + """Test @webhook_trigger with same webhook_id but conflicting methods raises.""" + await pyscript.start(""" +@webhook_trigger("conflict_methods", methods=["GET"]) +def func_methods_a(): + pass + +@webhook_trigger("conflict_methods", methods=["POST"]) +def func_methods_b(): + pass +""") + await pyscript.wait_exception(ValueError, match="'conflict_methods' conflicts with existing") + + +@pytest.mark.asyncio +async def test_webhook_methods_missing_vs_set_conflict(pyscript): + """Test @webhook_trigger with same webhook_id but only one specifying methods raises.""" + await pyscript.start(""" +@webhook_trigger("conflict_unset") +def func_unset_a(): + pass + +@webhook_trigger("conflict_unset", methods=["POST"]) +def func_unset_b(): + pass +""") + await pyscript.wait_exception(ValueError, match="'conflict_unset' conflicts with existing") + + +# --- @webhook_handler --- + + +@pytest.mark.asyncio +async def test_webhook_handler_responses(pyscript, hass): + """Exercise the @webhook_handler return-value mapping and error paths end-to-end.""" + await pyscript.start(""" +@webhook_handler("empty") +def webhook_empty(**_): + pass + +@webhook_handler("status_with_body") +def webhook_status_with_body(**_): + return (201, {"id": 7}) + +@webhook_handler("status_with_none") +def webhook_status_with_none(**_): + return (204, None) + +@webhook_handler("text_hook") +def webhook_text(**_): + return "hello" + +@webhook_handler("json_hook") +def webhook_json(**_): + return {"hello": "world"} + +@webhook_handler("bad_json") +def webhook_bad_json(**_): + # complex is not JSON-serializable + return {"value": 1+2j} + +@webhook_handler("bytes_hook") +def webhook_bytes(**_): + return b"\\x00\\x01raw" + +@webhook_handler("unsupported") +def webhook_unsupported(**_): + return 42 + +@webhook_handler("bad_tuple") +def webhook_bad_tuple(**_): + # 2-tuple but the first element isn't a status int + return ("ok", "body") + +@webhook_handler("crash") +def webhook_crash(**_): + raise ValueError("boom") + +@webhook_handler("parse") +def webhook_parse(**_): + pyscript.done = "should not run" + return (200, None) + +@webhook_handler("req_handler") +def webhook_req(payload, request, **_): + return {"sig": request.headers["X-My-Sig"], "method": request.method, "payload": payload} +""") + + # no return -> 200 + response = await webhook.async_handle_webhook(hass, "empty", _request()) + assert response.status == 200 + + # (status, body) tuple -> status from tuple, body rendered as JSON + response = await webhook.async_handle_webhook(hass, "status_with_body", _request()) + assert response.status == 201 + assert response.content_type == "application/json" + assert json.loads(response.text) == {"id": 7} + + # (status, None) tuple -> status from tuple, empty body + response = await webhook.async_handle_webhook(hass, "status_with_none", _request()) + assert response.status == 204 + + # str -> 200 with text body + response = await webhook.async_handle_webhook(hass, "text_hook", _request()) + assert response.status == 200 + assert response.text == "hello" + + # dict -> 200 with JSON body + response = await webhook.async_handle_webhook(hass, "json_hook", _request()) + assert response.status == 200 + assert response.content_type == "application/json" + assert json.loads(response.text) == {"hello": "world"} + + # non-JSON-serializable dict -> 500, also surfaced via log_exception + response = await webhook.async_handle_webhook(hass, "bad_json", _request()) + assert response.status == 500 + await pyscript.wait_exception(TypeError, match="not JSON serializable") + + # bytes -> 200 with raw body + response = await webhook.async_handle_webhook(hass, "bytes_hook", _request()) + assert response.status == 200 + assert response.body == b"\x00\x01raw" + + # unsupported return type (e.g. bare int) -> 500 + response = await webhook.async_handle_webhook(hass, "unsupported", _request()) + assert response.status == 500 + + # 2-tuple with non-int status -> falls through to unsupported -> 500 + response = await webhook.async_handle_webhook(hass, "bad_tuple", _request()) + assert response.status == 500 + + # uncaught exception -> 500, also surfaced via log_exception + response = await webhook.async_handle_webhook(hass, "crash", _request()) + assert response.status == 500 + await pyscript.wait_exception(match="boom") + + # malformed JSON -> 400 without invoking the function + response = await webhook.async_handle_webhook( + hass, + "parse", + _request(body=b"{not json", headers={"Content-Type": "application/json"}), + ) + assert response.status == 400 + assert pyscript.done.empty() + + # payload / request forwarded; return value becomes the response + response = await webhook.async_handle_webhook( + hass, + "req_handler", + _request( + body=b'{"hello": "world"}', + headers={"Content-Type": "application/json", "X-My-Sig": "abc123"}, + ), + ) + assert response.status == 200 + assert json.loads(response.text) == { + "sig": "abc123", + "method": "POST", + "payload": {"hello": "world"}, + } + + +@pytest.mark.asyncio +async def test_webhook_handler_concurrent_requests(pyscript, hass): + """ + Two in-flight requests must each receive their own response with no crosstalk. + + The slow request is launched first but the fast one finishes first, so a + naive implementation that stored the response future on the decorator + instance instead of per-DispatchData would route the fast result to the + slow caller (or vice versa). + """ + await pyscript.start(""" +@webhook_handler("concurrent") +def webhook_concurrent(payload, **_): + task.sleep(payload["sleep"]) + return {"echo": payload["id"]} +""") + + slow, fast = await asyncio.gather( + webhook.async_handle_webhook( + hass, + "concurrent", + _request( + body=b'{"id": "slow", "sleep": 0.1}', + headers={"Content-Type": "application/json"}, + ), + ), + webhook.async_handle_webhook( + hass, + "concurrent", + _request( + body=b'{"id": "fast", "sleep": 0.02}', + headers={"Content-Type": "application/json"}, + ), + ), + ) + + assert slow.status == 200 + assert fast.status == 200 + assert json.loads(slow.text) == {"echo": "slow"} + assert json.loads(fast.text) == {"echo": "fast"} + + +@pytest.mark.parametrize("expected_lingering_tasks", [True]) +@pytest.mark.asyncio +async def test_webhook_handler_timeout(pyscript, hass): + """ + A function that doesn't finish in time -> 504 Gateway Timeout. + + The pyscript task running the function lingers past the test (it's still + sleeping when the webhook returns), so we opt in to HA's lingering-task + allowance for this test only. + """ + await pyscript.start(""" +@webhook_handler("slow", timeout=0.05) +def webhook_slow(**_): + task.sleep(0.15) + return 200 +""") + response = await webhook.async_handle_webhook(hass, "slow", _request()) + assert response.status == 504 + + +@pytest.mark.asyncio +async def test_webhook_handler_expression_rejects(pyscript, hass): + """A str_expr that evaluates falsy -> 403 Forbidden, function is not invoked.""" + await pyscript.start(""" +@webhook_handler("guarded", "payload.get('token') == 'secret'") +def webhook_guarded(payload, **_): + pyscript.done = payload["token"] + return {"ok": True} +""") + + # token missing -> guard rejects -> 403, function is not invoked + response = await webhook.async_handle_webhook( + hass, + "guarded", + _request(body=b'{"token": "wrong"}', headers={"Content-Type": "application/json"}), + ) + assert response.status == 403 + assert pyscript.done.empty() + + # token matches -> function runs + response = await webhook.async_handle_webhook( + hass, + "guarded", + _request(body=b'{"token": "secret"}', headers={"Content-Type": "application/json"}), + ) + assert response.status == 200 + assert json.loads(response.text) == {"ok": True} + await pyscript.wait_done("secret") + + +@pytest.mark.asyncio +async def test_webhook_handler_duplicate_id(pyscript): + """Two @webhook_handler with the same id conflict at registration.""" + await pyscript.start(""" +@webhook_handler("dup") +def webhook_dup_a(**_): + pass + +@webhook_handler("dup") +def webhook_dup_b(**_): + pass +""") + await pyscript.wait_exception(ValueError, match="Handler is already defined") diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index b03d8cc..94879ae 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -519,20 +519,3 @@ def func7(): "TypeError: function 'func7' defined in file.hello: decorator @state_trigger keyword 'watch' should be type list or set" in caplog.text ) - - -@pytest.mark.asyncio -async def test_webhooks_method(hass, caplog): - """Test invalid keyword arguments type generates an error.""" - - await setup_script( - hass, - None, - dt(2020, 7, 1, 11, 59, 59, 999999), - """ -@webhook_trigger("hook", methods=["bad"]) -def func8(): - pass -""", - ) - assert "TypeError: function 'func8' defined in file.hello:" in caplog.text diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index 45c6f89..bd633bf 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -2,8 +2,9 @@ from __future__ import annotations +from collections.abc import Awaitable import logging -from typing import ClassVar +from typing import Any, ClassVar from unittest.mock import patch import pytest @@ -22,6 +23,7 @@ DecoratorManager, DecoratorManagerStatus, DispatchData, + TriggerHandlerDecorator, ) import custom_components.pyscript.decorators.base as decorators_base_module from custom_components.pyscript.decorators.base import AutoKwargsDecorator, ExpressionDecorator @@ -146,6 +148,39 @@ async def handle_call_result(self, data: DispatchData, result: object) -> None: self.results.append(result) +class FullRecordingResultHandler(CallResultHandlerDecorator): + """Result handler that records all three notification methods separately.""" + + name = "full_record_result" + results: list[object] + exceptions: list[Exception] + canceled_calls: int + + async def handle_call_result(self, data: DispatchData, result: object) -> None: + """Record successful result.""" + self.results.append(result) + + async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None: + """Record exception (override the default that forwards as None).""" + self.exceptions.append(exc) + + async def handle_call_canceled(self, data: DispatchData) -> None: + """Record cancellation (override the default that forwards as None).""" + self.canceled_calls += 1 + + +class CancelDispatchHandler(TriggerHandlerDecorator): + """Trigger handler that cancels the dispatch.""" + + name = "cancel_dispatch" + seen: list[dict] + + async def handle_dispatch(self, data: DispatchData) -> bool: + """Cancel the dispatch.""" + self.seen.append(data.func_args.copy()) + return False + + class AutoKwargsTestDecorator(AutoKwargsDecorator): """Decorator used to test AutoKwargsDecorator behavior.""" @@ -208,6 +243,22 @@ def make_recording_result_handler() -> RecordingResultHandler: return handler +def make_full_recording_result_handler() -> FullRecordingResultHandler: + """Create a full recording result handler.""" + handler = FullRecordingResultHandler([], {}) + handler.results = [] + handler.exceptions = [] + handler.canceled_calls = 0 + return handler + + +def make_cancel_dispatch_handler() -> CancelDispatchHandler: + """Create a dispatch-canceling trigger handler.""" + handler = CancelDispatchHandler([], {}) + handler.seen = [] + return handler + + def get_registry_decorators(default: object | None = None) -> object | None: """Return the decorator registry mapping.""" return getattr(DecoratorRegistry, _REGISTRY_ATTR, default) @@ -309,6 +360,10 @@ async def stop(self) -> None: """Record manager stop.""" self.stop_calls += 1 + async def safe_await(self, coro: Awaitable[Any]) -> None: + """Await ``coro`` without error handling (sufficient for these tests).""" + await coro + class FakeFunctionDecoratorManager: """Patchable manager stub for GlobalContext.create_decorator_manager tests.""" @@ -347,6 +402,10 @@ async def stop(self) -> None: """Record manager stop.""" self.stop_calls += 1 + # Reuse the real implementations — they only need self.ast_ctx. + safe_await = DecoratorManager.safe_await + handle_exception = DecoratorManager.handle_exception + def make_dispatch_data( func_args: dict[str, object], @@ -381,6 +440,40 @@ async def test_decorator_manager_no_decorators_and_accessors(): await dm.start() +@pytest.mark.asyncio +async def test_decorator_manager_safe_await_returns_silently_on_success(): + """safe_await should await without surfacing anything when the coroutine succeeds.""" + ast_ctx = DummyAstCtx() + dm = DummyManager(ast_ctx) + awaited = [] + + async def work(): + awaited.append("ran") + return 42 # value is intentionally discarded by safe_await + + result = await dm.safe_await(work()) + + assert result is None + assert awaited == ["ran"] + assert not ast_ctx.logged_exceptions + + +@pytest.mark.asyncio +async def test_decorator_manager_safe_await_routes_exception_through_handle_exception(): + """safe_await should catch the coroutine's exception and forward it to handle_exception.""" + ast_ctx = DummyAstCtx() + dm = DummyManager(ast_ctx) + boom = RuntimeError("boom") + + async def work(): + raise boom + + # Must not raise. + await dm.safe_await(work()) + + assert ast_ctx.logged_exceptions == [boom] + + @pytest.mark.asyncio async def test_decorator_manager_start_rolls_back_started_decorators(): """A later start failure should stop already-started decorators.""" @@ -599,6 +692,116 @@ async def test_function_decorator_manager_logs_call_exception(hass): assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed" +@pytest.mark.asyncio +async def test_function_decorator_manager_exception_calls_handle_call_exception(hass): + """On exception path, result handlers should receive handle_call_exception (not handle_call_result).""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + result_handler = make_full_recording_result_handler() + manager.add(result_handler) + exc = RuntimeError("boom") + call_ast_ctx = DummyCallAstCtx(exc=exc) + + await call_function_manager( + manager, + make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent")), + ) + + assert result_handler.exceptions == [exc] + assert not result_handler.results + assert result_handler.canceled_calls == 0 + # exception is still logged via the manager + assert ast_ctx.logged_exceptions == [exc] + + +@pytest.mark.asyncio +async def test_function_decorator_manager_cancel_calls_handle_call_canceled(hass): + """On CallHandler veto, result handlers should receive handle_call_canceled (not handle_call_result).""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + call_handler = make_cancel_call_handler() + result_handler = make_full_recording_result_handler() + call_ast_ctx = DummyCallAstCtx(result="unused") + manager.add(call_handler) + manager.add(result_handler) + + await call_function_manager( + manager, + make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="call-parent")), + ) + + assert call_handler.seen == [{"arg1": 1}] + assert result_handler.canceled_calls == 1 + assert not result_handler.results + assert not result_handler.exceptions + assert not call_ast_ctx.calls + + +@pytest.mark.asyncio +async def test_function_decorator_manager_dispatch_veto_calls_handle_call_canceled(hass): + """On TriggerHandler veto, result handlers should receive handle_call_canceled.""" + DecoratorManager.hass = hass + manager = FunctionDecoratorManager(DummyAstCtx(), DummyEvalFuncVar()) + trigger_handler = make_cancel_dispatch_handler() + result_handler = make_full_recording_result_handler() + manager.add(trigger_handler) + manager.add(result_handler) + + await manager.dispatch(make_dispatch_data({"arg1": 1})) + + assert trigger_handler.seen == [{"arg1": 1}] + assert result_handler.canceled_calls == 1 + assert not result_handler.results + assert not result_handler.exceptions + + +@pytest.mark.asyncio +async def test_function_decorator_manager_result_handler_failure_does_not_block_siblings(hass): + """A buggy result handler should be routed via handle_exception and not stop siblings.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + + handler_bug = RuntimeError("broken handler boom") + + class _BrokenHandler(CallResultHandlerDecorator): + name = "broken_handler" + + async def handle_call_result(self, data: DispatchData, result: object) -> None: + raise handler_bug + + broken = _BrokenHandler([], {}) + sibling = make_full_recording_result_handler() + call_ast_ctx = DummyCallAstCtx(result="ok") + manager.add(broken) + manager.add(sibling) + + await call_function_manager( + manager, + make_dispatch_data({"arg1": 1}, call_ast_ctx=call_ast_ctx, hass_context=Context(id="cid")), + ) + + assert sibling.results == ["ok"] + assert ast_ctx.logged_exceptions == [handler_bug] + + +@pytest.mark.asyncio +async def test_call_result_handler_default_handle_call_exception_forwards_none(): + """Default handle_call_exception should forward the call to handle_call_result with None.""" + handler = make_recording_result_handler() + await handler.handle_call_exception(make_dispatch_data({}), RuntimeError("boom")) + assert handler.results == [None] + + +@pytest.mark.asyncio +async def test_call_result_handler_default_handle_call_canceled_forwards_none(): + """Default handle_call_canceled should forward the call to handle_call_result with None.""" + handler = make_recording_result_handler() + await handler.handle_call_canceled(make_dispatch_data({})) + assert handler.results == [None] + + def test_decorator_registry_register_requires_name(): """Registry should reject decorators without a declared name.""" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 12224d4..8350500 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -10,10 +10,8 @@ from custom_components.pyscript import trigger from custom_components.pyscript.const import DOMAIN from custom_components.pyscript.function import Function -from homeassistant.components import webhook from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED from homeassistant.setup import async_setup_component -from homeassistant.util.aiohttp import MockRequest async def setup_script(hass, notify_q, now, source): @@ -226,33 +224,3 @@ def func6(value): hass.states.async_set("pyscript.var1", 6 + 2 * i) seq_num += 1 assert literal_eval(await wait_until_done(notify_q)) == [seq_num, 6 + 2 * i] - - -@pytest.mark.asyncio -async def test_webhook_request_kwarg(hass): - """The aiohttp request is passed to the user function as the `request` kwarg.""" - notify_q = asyncio.Queue(0) - await setup_script( - hass, - notify_q, - [dt(2020, 7, 1, 11, 59, 59, 999999)], - """ -@webhook_trigger("test_req_hook") -def webhook_test(payload, request): - pyscript.done = [request.headers["X-My-Sig"], request.method, payload] -""", - ) - hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) - await hass.async_block_till_done() - - request = MockRequest( - content=b'{"hello": "world"}', - mock_source="test", - method="POST", - headers={"Content-Type": "application/json", "X-My-Sig": "abc123"}, - remote="127.0.0.1", - ) - - await webhook.async_handle_webhook(hass, "test_req_hook", request) - - assert literal_eval(await wait_until_done(notify_q)) == ["abc123", "POST", {"hello": "world"}]