diff --git a/packages/reflex-base/src/reflex_base/environment.py b/packages/reflex-base/src/reflex_base/environment.py index f2d4ce44cb9..956cd00d6cd 100644 --- a/packages/reflex-base/src/reflex_base/environment.py +++ b/packages/reflex-base/src/reflex_base/environment.py @@ -611,6 +611,10 @@ class EnvironmentVariables: # If this env var is set to "yes", App.compile will be a no-op REFLEX_SKIP_COMPILE: EnvVar[bool] = env_var(False, internal=True) + # Inherited by uvicorn/granian reload workers so the backend can distinguish + # dev reload-capable worker boots from other backend starts. Never set in prod. + REFLEX_DEV_BACKEND_RELOAD_ACTIVE: EnvVar[bool] = env_var(False, internal=True) + # Whether to run app harness tests in headless mode. APP_HARNESS_HEADLESS: EnvVar[bool] = env_var(False) diff --git a/pyi_hashes.json b/pyi_hashes.json index 7f5883643b7..29a69420957 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -120,5 +120,5 @@ "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "82d8699470071df80886a4a6ba8dccfe" + "reflex/experimental/memo.pyi": "d09629b81bf0df6153b131ac0ee10bd7" } diff --git a/reflex/app.py b/reflex/app.py index 65ae0a8235b..bd99278c64e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -70,6 +70,7 @@ from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.compiler import compiler from reflex.compiler.compiler import readable_name_from_component +from reflex.istate.data import RouterData from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES @@ -78,21 +79,24 @@ replace_brackets_with_keywords, verify_route_validity, ) -from reflex.state import ( - BaseState, - RouterData, - State, - StateUpdate, - all_base_state_classes, +from reflex.state import BaseState, State, StateUpdate, all_base_state_classes +from reflex.utils import ( + codespaces, + exceptions, + format, + js_runtimes, + prerequisites, + telemetry_accounting, ) -from reflex.utils import codespaces, exceptions, format, js_runtimes, prerequisites from reflex.utils.exec import ( + get_backend_compile_trigger, get_compile_context, is_prod_mode, is_testing_env, should_prerender_routes, ) from reflex.utils.misc import run_in_thread +from reflex.utils.telemetry_context import CompileTrigger, TelemetryContext from reflex.utils.token_manager import RedisTokenManager, TokenManager if sys.version_info < (3, 13): @@ -662,7 +666,10 @@ def __call__(self) -> ASGIApp: # rx.asset(shared=True) symlink re-creation doesn't trigger further reloads. remove_stale_external_asset_symlinks() - self._compile(prerender_routes=should_prerender_routes()) + self._compile( + prerender_routes=should_prerender_routes(), + trigger=get_backend_compile_trigger(), + ) config = get_config() @@ -1167,6 +1174,7 @@ def _compile( prerender_routes: bool = False, dry_run: bool = False, use_rich: bool = True, + trigger: CompileTrigger | None = None, ): """Compile the app and output it to the pages folder. @@ -1174,17 +1182,39 @@ def _compile( prerender_routes: Whether to prerender the routes. dry_run: Whether to compile the app without saving it. use_rich: Whether to use rich progress bars. + trigger: Label identifying what initiated this compile. Recorded + on the ``compile`` telemetry event. Raises: ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined. FileNotFoundError: When a plugin requires a file that does not exist. """ - compiler.compile_app( - self, - prerender_routes=prerender_routes, - dry_run=dry_run, - use_rich=use_rich, - ) + ctx = TelemetryContext.start(trigger=trigger) + if ctx is None: + compiler.compile_app( + self, + prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, + ) + return + + with ctx: + did_real_compile = False + try: + did_real_compile = compiler.compile_app( + self, + prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, + ) + except Exception as exc: + ctx.set_exception(exc) + did_real_compile = True + raise + finally: + if did_real_compile: + telemetry_accounting.record_compile(self, ctx) def _write_stateful_pages_marker(self): """Write list of routes that create dynamic states for the backend to use later.""" diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 6c986bd1903..a0a934c1b6d 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -993,8 +993,13 @@ def compile_app( prerender_routes: bool = False, dry_run: bool = False, use_rich: bool = True, -) -> None: - """Compile an app using the compiler plugin pipeline.""" +) -> bool: + """Compile an app using the compiler plugin pipeline. + + Returns: + ``True`` when a real frontend compile ran, ``False`` when the call + short-circuited (backend-only paths that only re-evaluate pages). + """ from reflex_base.components.dynamic import bundle_library, reset_bundled_libraries from reflex_base.utils.exceptions import ReflexRuntimeError @@ -1012,7 +1017,7 @@ def compile_app( console.debug(f"BE Evaluating stateful page: {route}") app._compile_page(route, save_page=False) app._add_optional_endpoints() - return + return False if constants.Page404.SLUG not in app._unevaluated_pages: app.add_page(route=constants.Page404.SLUG) @@ -1028,7 +1033,7 @@ def compile_app( app._write_stateful_pages_marker() app._add_optional_endpoints() - return + return False progress = ( Progress( @@ -1222,7 +1227,7 @@ def add_save_task( progress.stop() if dry_run: - return + return True with console.timing("Install Frontend Packages"): app._get_frontend_packages(all_imports) @@ -1277,3 +1282,5 @@ def add_save_task( with console.timing("Write to Disk"): for output_path, code in output_mapping.items(): utils.write_file(output_path, code) + + return True diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 776e056cc50..b5a18e3f26c 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -7,7 +7,7 @@ from collections.abc import Callable from copy import copy from functools import cache, update_wrapper -from typing import Any, get_args, get_origin, get_type_hints +from typing import Any, ClassVar, get_args, get_origin, get_type_hints from reflex_base import constants from reflex_base.components.component import Component @@ -94,6 +94,12 @@ class ExperimentalMemoComponent(Component): library = f"$/{constants.Dirs.COMPONENTS_PATH}" _memoization_mode = MemoizationMode(disposition=MemoizationDisposition.NEVER) + # The user-authored component class this wrapper stands in for. Populated + # on the dynamic subclass by ``_get_experimental_memo_component_class`` so + # introspection (e.g. compile telemetry) can recover the underlying type + # without parsing the wrapper's auto-generated class name. + _wrapped_component_type: ClassVar[type[Component] | None] = None + def _validate_component_children(self, children: list[Component]) -> None: """Skip direct parent/child validation for memo wrapper instances. @@ -176,6 +182,7 @@ def _get_experimental_memo_component_class( # Per-file import paths give Vite distinct module boundaries per # memo, enabling actual code-split by page. "library": f"$/{constants.Dirs.COMPONENTS_PATH}/{export_name}", + "_wrapped_component_type": wrapped_component_type, } if ( wrapped_component_type._get_app_wrap_components diff --git a/reflex/reflex.py b/reflex/reflex.py index 43099e74332..7fd181f8568 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -139,6 +139,7 @@ def _compile_app(*, avoid_dirty_check: bool = True): kwargs = { "check_if_schema_up_to_date": True, "prerender_routes": exec.should_prerender_routes(), + "trigger": "initial", } # Granian fails if the app is already imported. @@ -485,7 +486,7 @@ def compile(dry: bool, rich: bool): _init(name=get_config().app_name) get_config(reload=True) starting_time = time.monotonic() - prerequisites.get_compiled_app(dry_run=dry, use_rich=rich) + prerequisites.get_compiled_app(dry_run=dry, use_rich=rich, trigger="cli_compile") elapsed_time = time.monotonic() - starting_time console.success(f"App compiled successfully in {elapsed_time:.3f} seconds.") diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 9e4b326da56..a7cb721f542 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import hashlib import importlib.util import json @@ -25,10 +26,52 @@ from reflex.utils import path_ops from reflex.utils.misc import get_module_path from reflex.utils.prerequisites import get_web_dir +from reflex.utils.telemetry_context import CompileTrigger # For uvicorn windows bug fix (#2335) frontend_process = None +DEV_BACKEND_RELOAD_MARKER = ".reflex_dev_backend_started" + + +def get_dev_backend_reload_marker() -> Path: + """Get the marker path for dev backend reload-capable worker starts. + + Returns: + The path to the reload marker. + """ + return get_web_dir() / DEV_BACKEND_RELOAD_MARKER + + +def reset_dev_backend_reload_marker() -> None: + """Remove the reload marker at the start of a fresh dev backend session.""" + with contextlib.suppress(OSError): + get_dev_backend_reload_marker().unlink(missing_ok=True) + + +def get_backend_compile_trigger() -> CompileTrigger: + """Determine the compile trigger and claim the dev backend reload marker. + + Atomically creates the marker so a failed first compile is still treated + as the first worker boot: the next worker (after the user fixes the + error) will see the marker and report ``hot_reload``. If the marker + cannot be created (e.g. permission error, missing parent dir), falls + back to ``backend_startup``. + + Returns: + ``"backend_startup"`` for non-dev startups and the first dev + reload-capable worker boot, ``"hot_reload"`` for subsequent boots. + """ + if not environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.get(): + return "backend_startup" + try: + os.close(os.open(get_dev_backend_reload_marker(), os.O_CREAT | os.O_EXCL)) + except FileExistsError: + return "hot_reload" + except OSError: + pass + return "backend_startup" + def get_package_json_and_hash(package_json_path: Path) -> tuple[PackageJson, str]: """Get the content of package.json and its hash. @@ -537,6 +580,9 @@ def run_uvicorn_backend(host: str, port: int, loglevel: LogLevel): """ import uvicorn + reset_dev_backend_reload_marker() + environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.set(True) + uvicorn.run( app=f"{get_app_instance()}", factory=True, @@ -588,6 +634,9 @@ def run_granian_backend(host: str, port: int, loglevel: LogLevel): from granian.server import Server as Granian from reflex_base.environment import _load_dotenv_from_env + reset_dev_backend_reload_marker() + environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.set(True) + granian_app = Granian( target=get_app_instance_from_file(), factory=True, diff --git a/reflex/utils/export.py b/reflex/utils/export.py index 8b55fc42a68..bda7310f8b9 100644 --- a/reflex/utils/export.py +++ b/reflex/utils/export.py @@ -62,7 +62,9 @@ def export( if frontend: # Ensure module can be imported and app.compile() is called. - prerequisites.get_compiled_app(prerender_routes=prerender_routes) + prerequisites.get_compiled_app( + prerender_routes=prerender_routes, trigger="export" + ) # Set up .web directory and install frontend dependencies. build.setup_frontend(Path.cwd()) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index af0e2c14ef6..027920e29f4 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -33,6 +33,7 @@ from redis.asyncio import Redis from reflex.app import App + from reflex.utils.telemetry_context import CompileTrigger class AppInfo(NamedTuple): @@ -251,6 +252,7 @@ def get_compiled_app( dry_run: bool = False, check_if_schema_up_to_date: bool = False, use_rich: bool = True, + trigger: CompileTrigger | None = None, ) -> ModuleType: """Get the app module based on the default config after first compiling it. @@ -260,6 +262,7 @@ def get_compiled_app( dry_run: If True, do not write the compiled app to disk. check_if_schema_up_to_date: If True, check if the schema is up to date. use_rich: Whether to use rich progress bars. + trigger: Optional label forwarded to ``App._compile`` for telemetry. Returns: The compiled app based on the default config. @@ -267,7 +270,12 @@ def get_compiled_app( app, app_module = get_and_validate_app( reload=reload, check_if_schema_up_to_date=check_if_schema_up_to_date ) - app._compile(prerender_routes=prerender_routes, dry_run=dry_run, use_rich=use_rich) + app._compile( + prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, + trigger=trigger, + ) return app_module @@ -335,6 +343,7 @@ def compile_or_validate_app( compile: bool = False, check_if_schema_up_to_date: bool = False, prerender_routes: bool = False, + trigger: CompileTrigger | None = None, ) -> bool: """Compile or validate the app module based on the default config. @@ -342,6 +351,7 @@ def compile_or_validate_app( compile: Whether to compile the app. check_if_schema_up_to_date: If True, check if the schema is up to date. prerender_routes: Whether to prerender routes. + trigger: Optional label forwarded to ``App._compile`` for telemetry. Returns: True if the app was successfully compiled or validated, False otherwise. @@ -351,6 +361,7 @@ def compile_or_validate_app( get_compiled_app( check_if_schema_up_to_date=check_if_schema_up_to_date, prerender_routes=prerender_routes, + trigger=trigger, ) else: get_and_validate_app(check_if_schema_up_to_date=check_if_schema_up_to_date) diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index dcdecd5e0a9..902a75d7f3b 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -9,7 +9,7 @@ import warnings from contextlib import suppress from datetime import datetime, timezone -from typing import TypedDict +from typing import Any, TypedDict, cast from reflex_base import constants from reflex_base.environment import environment @@ -271,12 +271,20 @@ def get_event_defaults() -> _DefaultEvent | None: return _get_event_defaults() -def _prepare_event(event: str, **kwargs) -> _Event | None: +def _prepare_event( + event: str, + *, + properties: dict[str, Any] | None = None, + **kwargs, +) -> _Event | None: """Prepare the event to be sent to the PostHog server. Args: event: The event name. - kwargs: Additional data to send with the event. + properties: Arbitrary structured payload merged into the event + properties. Preferred over ``kwargs`` for new events. + kwargs: Additional data to send with the event. Allow-listed keys + kept for backward compatibility with existing call sites. Returns: The event data. @@ -287,22 +295,29 @@ def _prepare_event(event: str, **kwargs) -> _Event | None: additional_keys = ["template", "context", "detail", "user_uuid"] - properties = event_data["properties"] + # Shallow-copy so we don't mutate the cached default properties dict. + merged_properties = dict(event_data["properties"]) for key in additional_keys: - if key in properties or key not in kwargs: + if key in merged_properties or key not in kwargs: continue - properties[key] = kwargs[key] + merged_properties[key] = kwargs[key] + + if properties: + merged_properties.update(properties) stamp = datetime.now(UTC).isoformat() - return { - "api_key": event_data["api_key"], - "event": event, - "properties": properties, - "timestamp": stamp, - } + return cast( + "_Event", + { + "api_key": event_data["api_key"], + "event": event, + "properties": merged_properties, + "timestamp": stamp, + }, + ) def _send_event(event_data: _Event) -> bool: @@ -316,7 +331,13 @@ def _send_event(event_data: _Event) -> bool: return True -def _send(event: str, telemetry_enabled: bool | None, **kwargs) -> bool: +def _send( + event: str, + telemetry_enabled: bool | None, + *, + properties: dict[str, Any] | None = None, + **kwargs, +) -> bool: from reflex_base.config import get_config # Get the telemetry_enabled from the config if it is not specified. @@ -328,7 +349,7 @@ def _send(event: str, telemetry_enabled: bool | None, **kwargs) -> bool: return False with suppress(Exception): - event_data = _prepare_event(event, **kwargs) + event_data = _prepare_event(event, properties=properties, **kwargs) if not event_data: return False return _send_event(event_data) @@ -338,22 +359,35 @@ def _send(event: str, telemetry_enabled: bool | None, **kwargs) -> bool: background_tasks = set() -def send(event: str, telemetry_enabled: bool | None = None, **kwargs): +def send( + event: str, + telemetry_enabled: bool | None = None, + *, + properties: dict[str, Any] | None = None, + **kwargs, +): """Send anonymous telemetry for Reflex. Args: event: The event name. telemetry_enabled: Whether to send the telemetry (If None, get from config). + properties: Arbitrary structured payload merged into the event + properties. Preferred over ``kwargs`` for new events. kwargs: Additional data to send with the event. """ - async def async_send(event: str, telemetry_enabled: bool | None, **kwargs): # noqa: RUF029 - return _send(event, telemetry_enabled, **kwargs) + async def async_send( # noqa: RUF029 + event: str, + telemetry_enabled: bool | None, + properties: dict[str, Any] | None, + **kwargs, + ): + return _send(event, telemetry_enabled, properties=properties, **kwargs) try: # Within an event loop context, send the event asynchronously. task = asyncio.create_task( - async_send(event, telemetry_enabled, **kwargs), + async_send(event, telemetry_enabled, properties, **kwargs), name=f"reflex_send_telemetry_event|{event}", ) background_tasks.add(task) @@ -361,7 +395,7 @@ async def async_send(event: str, telemetry_enabled: bool | None, **kwargs): # n except RuntimeError: # If there is no event loop, send the event synchronously. warnings.filterwarnings("ignore", category=RuntimeWarning) - _send(event, telemetry_enabled, **kwargs) + _send(event, telemetry_enabled, properties=properties, **kwargs) def send_error(error: Exception, context: str): diff --git a/reflex/utils/telemetry_accounting.py b/reflex/utils/telemetry_accounting.py new file mode 100644 index 00000000000..29f1971f3b7 --- /dev/null +++ b/reflex/utils/telemetry_accounting.py @@ -0,0 +1,193 @@ +"""Post-compile accounting helpers for the ``compile`` telemetry event.""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any, TypedDict + +from reflex_base.config import get_config +from reflex_base.utils import console + +from reflex.utils import telemetry + +if TYPE_CHECKING: + from reflex_base.components.component import BaseComponent + + from reflex.app import App + from reflex.state import BaseState + from reflex.utils.telemetry_context import CompileTrigger, TelemetryContext + + +class _StateStats(TypedDict): + """Per-state structural statistics.""" + + event_handlers_count: int + vars_count: int + backend_vars_count: int + computed_vars_count: int + depth_from_root: int + + +class _ExceptionInfo(TypedDict): + """Sanitized exception descriptor (class name only, no message).""" + + type: str + + +class _CompileEventProperties(TypedDict): + """Properties payload of the ``compile`` telemetry event.""" + + plugins_enabled: list[str] + plugins_disabled: list[str] + pages_count: int + component_counts: dict[str, int] + states: list[_StateStats] + features_used: dict[str, Any] + duration_ms: int + trigger: CompileTrigger | None + exception: _ExceptionInfo | None + + +def record_compile(app: App, ctx: TelemetryContext) -> None: + """Build the compile-event payload and send it to PostHog. + + Any exception from payload assembly or sending is swallowed and logged + so telemetry can never break a real compile. + + Args: + app: The compiled application. + ctx: The active telemetry context. + """ + try: + payload = _collect_compile_event_payload(app, ctx) + telemetry.send("compile", properties=dict(payload)) + except Exception as exc: + console.debug(f"compile telemetry event failed: {exc!r}") + + +def _collect_compile_event_payload( + app: App, ctx: TelemetryContext +) -> _CompileEventProperties: + """Build the properties dict sent with the ``compile`` PostHog event. + + Args: + app: The compiled application. + ctx: The active telemetry context. + + Returns: + The properties dict to send to PostHog. + """ + config = get_config() + return { + "plugins_enabled": [p.__class__.__name__ for p in config.plugins], + "plugins_disabled": [p.__name__ for p in config.disable_plugins], + "pages_count": len(app._pages), + "component_counts": _count_components(app._pages.values()), + "states": _collect_all_state_stats(app), + "features_used": dict(ctx.features_used), + "duration_ms": ctx.elapsed_ms(), + "trigger": ctx.trigger, + "exception": _sanitize_exception(ctx.exception), + } + + +def _count_components(pages: Iterable[BaseComponent]) -> dict[str, int]: + """Count component types across one or more component trees. + + Auto-memoized components live in the tree as dynamic + ``ExperimentalMemoComponent___`` subclasses. Bucketing by + the raw class name would explode telemetry cardinality (each handler hash + produces a new key), so wrappers are counted under the user-authored + component they stand in for, exposed via ``_wrapped_component_type``. + + Args: + pages: Component-tree roots to walk. + + Returns: + Mapping of component class name to occurrence count. + """ + counts: dict[str, int] = {} + stack: list[BaseComponent] = list(pages) + while stack: + node = stack.pop() + node_cls = type(node) + wrapped = getattr(node_cls, "_wrapped_component_type", None) + name = wrapped.__name__ if wrapped is not None else node_cls.__name__ + counts[name] = counts.get(name, 0) + 1 + if node.children: + stack.extend(node.children) + return counts + + +def _walk_states(root: type[BaseState] | None) -> Iterator[type[BaseState]]: + """Yield user-authored state classes reachable from ``root``. + + Framework-internal states (those whose module lives under ``reflex.``) are + skipped, but their user-defined subclasses are still yielded — ``SharedState`` + user classes hang off ``SharedStateBaseInternal``, so we descend through the + internal node rather than pruning the subtree. + + Args: + root: The root state class, or ``None`` when the app has no state. + + Yields: + Every user-authored state class reachable through ``get_substates()``. + """ + if root is None: + return + if root.is_user_defined(): + yield root + for sub in root.get_substates(): + yield from _walk_states(sub) + + +def _collect_all_state_stats(app: App) -> list[_StateStats]: + """Collect per-state statistics for every state attached to the app. + + Args: + app: The compiled application. + + Returns: + A list of per-state stat dicts. + """ + return [_collect_state_stats(state_cls) for state_cls in _walk_states(app._state)] + + +def _collect_state_stats(state_cls: type[BaseState]) -> _StateStats: + """Collect structural statistics for a single state class. + + Args: + state_cls: The state class to inspect. + + Returns: + A dict with field counts and depth-from-root. + """ + depth = 0 + parent = state_cls.get_parent_state() + while parent is not None: + depth += 1 + parent = parent.get_parent_state() + return { + "event_handlers_count": len(state_cls.event_handlers), + "vars_count": len(state_cls.vars), + "backend_vars_count": len(state_cls.backend_vars), + "computed_vars_count": len(state_cls.computed_vars), + "depth_from_root": depth, + } + + +def _sanitize_exception(exc: BaseException | None) -> _ExceptionInfo | None: + """Return a sanitized dict describing an exception, or ``None``. + + Only the exception class name is included. No message, no traceback, + no cause/context chain, no file paths. + + Args: + exc: The exception to sanitize, or ``None``. + + Returns: + ``None`` when ``exc`` is ``None``, else ``{"type": }``. + """ + if exc is None: + return None + return {"type": type(exc).__name__} diff --git a/reflex/utils/telemetry_context.py b/reflex/utils/telemetry_context.py new file mode 100644 index 00000000000..e16231957bf --- /dev/null +++ b/reflex/utils/telemetry_context.py @@ -0,0 +1,101 @@ +"""Compile-time telemetry context.""" + +from __future__ import annotations + +import dataclasses +import time +from typing import Any, Literal + +from reflex_base.config import get_config +from reflex_base.context.base import BaseContext + +CompileTrigger = Literal[ + "initial", "cli_compile", "backend_startup", "hot_reload", "export" +] + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) +class TelemetryContext(BaseContext): + """Per-compile telemetry handle attached to the current contextvar.""" + + start_perf_counter: float = dataclasses.field(default_factory=time.perf_counter) + features_used: dict[str, Any] = dataclasses.field(default_factory=dict) + trigger: CompileTrigger | None = None + exception: BaseException | None = dataclasses.field(default=None, repr=False) + + # BaseContext is a fieldless frozen dataclass, so its generated __eq__/__hash__ + # treat any two same-class instances as equal. That collides in the + # _attached_context_token dict and breaks nested `with` use, so force identity. + def __eq__(self, other: object) -> bool: + """Identity equality. + + Args: + other: The object to compare against. + + Returns: + True iff ``other`` is the same instance. + """ + return self is other + + def __hash__(self) -> int: + """Identity-based hash. + + Returns: + A hash derived from the object's identity. + """ + return id(self) + + def set_exception(self, exc: BaseException | None) -> None: + """Attach an exception that occurred during this compile. + + Args: + exc: The exception to attach, or ``None`` to clear. + """ + object.__setattr__(self, "exception", exc) + + @classmethod + def get(cls) -> TelemetryContext | None: # pyright: ignore[reportIncompatibleMethodOverride] + """Return the active telemetry context, or None if none is attached. + + Returns: + The active ``TelemetryContext`` instance, or ``None`` when no + context is attached (i.e. telemetry is disabled this compile). + """ + try: + return cls._context_var.get() + except LookupError: + return None + + @classmethod + def start( + cls, + *, + telemetry_enabled: bool | None = None, + trigger: CompileTrigger | None = None, + ) -> TelemetryContext | None: + """Create a new context iff telemetry is enabled. + + Args: + telemetry_enabled: Whether telemetry is enabled. Read from the + config when ``None``. + trigger: Label identifying what initiated this compile. + + Returns: + A new ``TelemetryContext`` (not yet entered) or ``None`` when + telemetry is disabled. + """ + if telemetry_enabled is None: + telemetry_enabled = get_config().telemetry_enabled + if not telemetry_enabled: + return None + return cls(trigger=trigger) + + def elapsed_ms(self) -> int: + """Return the elapsed time since context construction in milliseconds. + + Returns: + The elapsed time in whole milliseconds. + """ + return int( + (time.perf_counter() - self.start_perf_counter) * 1000 + ) # seconds → milliseconds diff --git a/tests/units/test_app.py b/tests/units/test_app.py index bcd90c9db1e..10b275e1f53 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -39,18 +39,14 @@ from reflex import AdminDash, constants from reflex.app import App, ComponentCallable, upload from reflex.environment import environment +from reflex.istate.data import RouterData from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis from reflex.istate.manager.token import BaseStateToken from reflex.model import Model -from reflex.state import ( - BaseState, - OnLoadInternalState, - RouterData, - State, - reload_state_module, -) +from reflex.state import BaseState, OnLoadInternalState, State, reload_state_module +from reflex.utils import exec as exec_utils from .conftest import chdir from .states import GenState @@ -2901,3 +2897,158 @@ def _test(): EventContext.get() isolated_context.run(_test) + + +def test_compile_sends_telemetry_when_enabled( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """When telemetry is enabled, ``_compile`` emits one ``compile`` PostHog event.""" + conf = rx.Config(app_name="testing", telemetry_enabled=True) + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + + mocker.patch("reflex.compiler.compiler.compile_app", return_value=True) + send_mock = mocker.patch("reflex.utils.telemetry.send") + + app._compile(trigger="initial") + + compile_calls = [c for c in send_mock.call_args_list if c.args[0] == "compile"] + assert len(compile_calls) == 1 + payload = compile_calls[0].kwargs["properties"] + for key in ( + "plugins_enabled", + "plugins_disabled", + "pages_count", + "component_counts", + "states", + "features_used", + "duration_ms", + "trigger", + "exception", + ): + assert key in payload, f"missing key {key} in compile event payload" + assert payload["exception"] is None + assert payload["trigger"] == "initial" + + +def test_compile_skips_telemetry_when_disabled( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """When telemetry is disabled, ``_compile`` does not emit a ``compile`` event.""" + conf = rx.Config(app_name="testing", telemetry_enabled=False) + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + + mocker.patch("reflex.compiler.compiler.compile_app", return_value=True) + send_mock = mocker.patch("reflex.utils.telemetry.send") + + app._compile() + + assert all(c.args[0] != "compile" for c in send_mock.call_args_list) + + +def test_compile_reports_exception_and_reraises( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """A compile exception is sanitized into the event and then re-raised.""" + conf = rx.Config(app_name="testing", telemetry_enabled=True) + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + + class _BoomError(RuntimeError): + pass + + mocker.patch( + "reflex.compiler.compiler.compile_app", + side_effect=_BoomError("/etc/passwd: secret token foo"), + ) + send_mock = mocker.patch("reflex.utils.telemetry.send") + + with pytest.raises(_BoomError): + app._compile() + + compile_calls = [c for c in send_mock.call_args_list if c.args[0] == "compile"] + assert len(compile_calls) == 1 + payload = compile_calls[0].kwargs["properties"] + assert payload["exception"] == {"type": "_BoomError"} + + +def test_compile_skips_telemetry_when_compile_app_short_circuits( + compilable_app: tuple[App, Path], + mocker: MockerFixture, +): + """No ``compile`` event when ``compile_app()`` skipped the real compile.""" + conf = rx.Config(app_name="testing", telemetry_enabled=True) + mocker.patch("reflex_base.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + + mocker.patch("reflex.compiler.compiler.compile_app", return_value=False) + send_mock = mocker.patch("reflex.utils.telemetry.send") + + app._compile(trigger="backend_startup") + + assert all(c.args[0] != "compile" for c in send_mock.call_args_list) + + +def test_call_marks_first_dev_backend_worker_as_startup( + compilable_app: tuple[App, Path], + mocker: MockerFixture, + monkeypatch: pytest.MonkeyPatch, +): + """The first reload-capable backend worker compile is backend startup.""" + monkeypatch.setenv(environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.name, "True") + + app, web_dir = compilable_app + marker = web_dir / exec_utils.DEV_BACKEND_RELOAD_MARKER + compile_mock = mocker.patch.object(app, "_compile") + + app() + + compile_mock.assert_called_once() + assert compile_mock.call_args.kwargs["trigger"] == "backend_startup" + assert marker.exists() + + +def test_call_marks_later_dev_backend_worker_as_hot_reload( + compilable_app: tuple[App, Path], + mocker: MockerFixture, + monkeypatch: pytest.MonkeyPatch, +): + """A later reload-capable backend worker compile is a hot reload.""" + monkeypatch.setenv(environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.name, "True") + + app, web_dir = compilable_app + marker = web_dir / exec_utils.DEV_BACKEND_RELOAD_MARKER + marker.touch() + compile_mock = mocker.patch.object(app, "_compile") + + app() + + compile_mock.assert_called_once() + assert compile_mock.call_args.kwargs["trigger"] == "hot_reload" + + +def test_call_ignores_stale_marker_without_dev_backend_reload( + compilable_app: tuple[App, Path], + mocker: MockerFixture, + monkeypatch: pytest.MonkeyPatch, +): + """A stale marker alone is not enough to label a compile as hot reload.""" + monkeypatch.delenv(environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.name, raising=False) + + app, web_dir = compilable_app + marker = web_dir / exec_utils.DEV_BACKEND_RELOAD_MARKER + marker.touch() + compile_mock = mocker.patch.object(app, "_compile") + + app() + + compile_mock.assert_called_once() + assert compile_mock.call_args.kwargs["trigger"] == "backend_startup" diff --git a/tests/units/test_environment.py b/tests/units/test_environment.py index 8e7796b3ebe..9164ffe63d9 100644 --- a/tests/units/test_environment.py +++ b/tests/units/test_environment.py @@ -548,6 +548,10 @@ def test_internal_environment_variables(self): """Test internal environment variables have correct names.""" assert environment.REFLEX_COMPILE_CONTEXT.name == "__REFLEX_COMPILE_CONTEXT" assert environment.REFLEX_SKIP_COMPILE.name == "__REFLEX_SKIP_COMPILE" + assert ( + environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.name + == "__REFLEX_DEV_BACKEND_RELOAD_ACTIVE" + ) def test_performance_mode_enum(self): """Test PerformanceMode enum.""" diff --git a/tests/units/test_telemetry.py b/tests/units/test_telemetry.py index dbe307300ac..ac477e688eb 100644 --- a/tests/units/test_telemetry.py +++ b/tests/units/test_telemetry.py @@ -57,3 +57,83 @@ def test_send(mocker: MockerFixture, event): telemetry._send(event, telemetry_enabled=True) httpx_post_mock.assert_called_once() + + +def _make_mock_defaults(): + return { + "api_key": "test_api_key", + "properties": { + "distinct_id": 12345, + "distinct_app_id": 78285505863498957834586115958872998605, + "user_os": "Test OS", + "user_os_detail": "Mocked Platform", + "reflex_version": "0.8.0", + "python_version": "3.8.0", + "node_version": None, + "bun_version": None, + "reflex_enterprise_version": None, + "cpu_count": 4, + "memory": 8192, + "cpu_info": {}, + }, + } + + +def test_prepare_event_merges_properties(mocker: MockerFixture): + mocker.patch( + "reflex.utils.telemetry._get_event_defaults", + return_value=_make_mock_defaults(), + ) + + event = telemetry._prepare_event( + "compile", + properties={"pages_count": 7, "trigger": "initial"}, + ) + + assert event is not None + assert event["event"] == "compile" + props: dict = event["properties"] # pyright: ignore[reportAssignmentType] + assert props["pages_count"] == 7 + assert props["trigger"] == "initial" + # Existing default keys are preserved. + assert props["user_os"] == "Test OS" + + +def test_prepare_event_does_not_mutate_cached_defaults(mocker: MockerFixture): + """``_prepare_event`` must not mutate the @once_unless_none cached defaults.""" + cached = _make_mock_defaults() + mocker.patch( + "reflex.utils.telemetry._get_event_defaults", + return_value=cached, + ) + + cached_props_snapshot = dict(cached["properties"]) + + telemetry._prepare_event("init", template="my-template") + telemetry._prepare_event( + "compile", + properties={"pages_count": 3, "duration_ms": 42}, + ) + + assert cached["properties"] == cached_props_snapshot + assert "template" not in cached["properties"] + assert "pages_count" not in cached["properties"] + assert "duration_ms" not in cached["properties"] + + +def test_prepare_event_properties_override_kwargs(mocker: MockerFixture): + """If both kwargs and properties supply the same key, properties wins.""" + mocker.patch( + "reflex.utils.telemetry._get_event_defaults", + return_value=_make_mock_defaults(), + ) + + event = telemetry._prepare_event( + "init", + template="from-kwarg", + properties={"template": "from-properties"}, + ) + + assert event is not None + props: dict = event["properties"] # pyright: ignore[reportAssignmentType] + assert props["template"] == "from-properties" diff --git a/tests/units/utils/test_exec.py b/tests/units/utils/test_exec.py new file mode 100644 index 00000000000..3b752fa9b0f --- /dev/null +++ b/tests/units/utils/test_exec.py @@ -0,0 +1,80 @@ +"""Tests for development backend launchers in ``reflex.utils.exec``.""" + +import os +from pathlib import Path + +import pytest +from pytest_mock import MockerFixture +from reflex_base.environment import environment + +from reflex.utils import exec as exec_utils + +DEV_BACKEND_RELOAD_ENV_NAME = environment.REFLEX_DEV_BACKEND_RELOAD_ACTIVE.name + + +def test_run_uvicorn_backend_sets_reload_env_var_and_clears_marker( + tmp_path: Path, mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch +): + """``run_uvicorn_backend`` initializes reload worker process context.""" + marker = tmp_path / exec_utils.DEV_BACKEND_RELOAD_MARKER + marker.touch() + monkeypatch.delenv(DEV_BACKEND_RELOAD_ENV_NAME, raising=False) + mocker.patch.object( + exec_utils, "get_dev_backend_reload_marker", return_value=marker + ) + mocker.patch.object(exec_utils, "get_app_instance", return_value="app:app") + mocker.patch.object(exec_utils, "get_reload_paths", return_value=[]) + + seen: dict[str, str | None] = {} + + def fake_run(*_args, **_kwargs): + seen["value"] = os.environ.get(DEV_BACKEND_RELOAD_ENV_NAME) + assert not marker.exists() + + uvicorn = pytest.importorskip("uvicorn") + mocker.patch.object(uvicorn, "run", side_effect=fake_run) + + exec_utils.run_uvicorn_backend( + host="0.0.0.0", port=8000, loglevel=exec_utils.LogLevel.INFO + ) + + assert seen["value"] == "True" + + +def test_run_granian_backend_sets_reload_env_var_and_clears_marker( + tmp_path: Path, mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch +): + """``run_granian_backend`` initializes reload worker process context.""" + marker = tmp_path / exec_utils.DEV_BACKEND_RELOAD_MARKER + marker.touch() + monkeypatch.delenv(DEV_BACKEND_RELOAD_ENV_NAME, raising=False) + mocker.patch.object( + exec_utils, "get_dev_backend_reload_marker", return_value=marker + ) + mocker.patch.object( + exec_utils, "get_app_instance_from_file", return_value="app:app" + ) + mocker.patch.object(exec_utils, "get_reload_paths", return_value=[]) + + seen: dict[str, str | None] = {} + + granian_server = pytest.importorskip("granian.server") + + class FakeGranian: + def __init__(self, *_args, **_kwargs): + seen["value"] = os.environ.get(DEV_BACKEND_RELOAD_ENV_NAME) + assert not marker.exists() + + def on_reload(self, _callback): + pass + + def serve(self): + pass + + mocker.patch.object(granian_server, "Server", FakeGranian) + + exec_utils.run_granian_backend( + host="0.0.0.0", port=8000, loglevel=exec_utils.LogLevel.INFO + ) + + assert seen["value"] == "True" diff --git a/tests/units/utils/test_telemetry_accounting.py b/tests/units/utils/test_telemetry_accounting.py new file mode 100644 index 00000000000..d36f6ee2891 --- /dev/null +++ b/tests/units/utils/test_telemetry_accounting.py @@ -0,0 +1,218 @@ +"""Tests for ``reflex.utils.telemetry_accounting``.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +from pytest_mock import MockerFixture +from reflex_base.plugins.sitemap import SitemapPlugin + +import reflex as rx +from reflex.state import ( + BaseState, + FrontendEventExceptionState, + OnLoadInternalState, + State, + UpdateVarsInternalState, +) +from reflex.utils import telemetry_accounting +from reflex.utils.telemetry_context import TelemetryContext + + +class _TelAcctRoot(BaseState): + """Root state for accounting tests.""" + + a: int = 0 + b: str = "" + _backend_one: int = 0 + + @rx.event + def set_a(self, value: int): + """Set a. + + Args: + value: New value. + """ + self.a = value + + @rx.var + def doubled(self) -> int: + """Doubled value of a. + + Returns: + ``a * 2``. + """ + return self.a * 2 + + +class _TelAcctChild(_TelAcctRoot): + """Child state for accounting tests.""" + + c: int = 0 + + +class _TelAcctGrandchild(_TelAcctChild): + """Grandchild state for accounting tests.""" + + d: int = 0 + + +def test_sanitize_exception_strips_message_and_path(): + """Sanitization keeps only the class name, dropping any sensitive message.""" + exc = ValueError("/secret/path: bad value 'topsecret'") + assert telemetry_accounting._sanitize_exception(exc) == {"type": "ValueError"} + + +def test_count_components_walks_nested_tree(): + """Component counts include every node in nested trees, keyed by class name.""" + tree = rx.box(rx.box(rx.box())) + counts = telemetry_accounting._count_components([tree]) + assert counts[type(tree).__name__] == 3 + + +def test_collect_state_stats_root_depth_zero(): + """A root state has depth 0 and reports the counts straight off the class.""" + stats = telemetry_accounting._collect_state_stats(_TelAcctRoot) + assert stats == { + "event_handlers_count": len(_TelAcctRoot.event_handlers), + "vars_count": len(_TelAcctRoot.vars), + "backend_vars_count": len(_TelAcctRoot.backend_vars), + "computed_vars_count": len(_TelAcctRoot.computed_vars), + "depth_from_root": 0, + } + + +def test_collect_state_stats_depth_hierarchy(): + """Depth increases with each parent-to-child step.""" + assert ( + telemetry_accounting._collect_state_stats(_TelAcctChild)["depth_from_root"] == 1 + ) + assert ( + telemetry_accounting._collect_state_stats(_TelAcctGrandchild)["depth_from_root"] + == 2 + ) + + +def test_walk_states_yields_root_and_descendants(): + """The walker reaches every descendant transitively.""" + walked = set(telemetry_accounting._walk_states(_TelAcctRoot)) + assert {_TelAcctRoot, _TelAcctChild, _TelAcctGrandchild} <= walked + + +def test_collect_compile_event_payload_shape(mocker: MockerFixture): + """The payload exposes every documented field with the expected types.""" + fake_plugin = MagicMock() + fake_plugin.__class__.__name__ = "FakePlugin" + mocker.patch( + "reflex.utils.telemetry_accounting.get_config", + return_value=mocker.Mock( + plugins=[fake_plugin], disable_plugins=[SitemapPlugin] + ), + ) + + app = SimpleNamespace( + _state=_TelAcctRoot, + _pages={"/": rx.box(rx.text("hello"))}, + ) + ctx = TelemetryContext(trigger="cli_compile") + ctx.features_used["radix"] = True + + payload = telemetry_accounting._collect_compile_event_payload( + app, # pyright: ignore[reportArgumentType] + ctx, + ) + + assert payload["plugins_enabled"] == ["FakePlugin"] + assert payload["plugins_disabled"] == ["SitemapPlugin"] + assert payload["pages_count"] == 1 + assert payload["component_counts"] + assert any(s["depth_from_root"] == 0 for s in payload["states"]) + assert payload["features_used"] == {"radix": True} + assert payload["duration_ms"] >= 0 + assert payload["trigger"] == "cli_compile" + assert payload["exception"] is None + + +def test_collect_compile_event_payload_with_exception(mocker: MockerFixture): + """An attached exception lands in the payload as a sanitized type-only dict.""" + mocker.patch( + "reflex.utils.telemetry_accounting.get_config", + return_value=mocker.Mock(plugins=[], disable_plugins=[]), + ) + + app = SimpleNamespace(_state=None, _pages={}) + ctx = TelemetryContext() + ctx.set_exception(RuntimeError("oops")) + + payload = telemetry_accounting._collect_compile_event_payload( + app, # pyright: ignore[reportArgumentType] + ctx, + ) + assert payload["exception"] == {"type": "RuntimeError"} + assert payload["pages_count"] == 0 + assert payload["states"] == [] + + +def test_collect_compile_event_payload_snapshots_features_used(mocker: MockerFixture): + """features_used in the payload is a snapshot, immune to later mutation.""" + mocker.patch( + "reflex.utils.telemetry_accounting.get_config", + return_value=mocker.Mock(plugins=[], disable_plugins=[]), + ) + app = SimpleNamespace(_state=None, _pages={}) + ctx = TelemetryContext() + ctx.features_used["x"] = 1 + + payload = telemetry_accounting._collect_compile_event_payload( + app, # pyright: ignore[reportArgumentType] + ctx, + ) + ctx.features_used["x"] = 999 + ctx.features_used["y"] = 2 + assert payload["features_used"] == {"x": 1} + + +def test_walk_states_skips_framework_internal_substates(): + """Framework-internal substates are excluded; user states still appear.""" + + class _UserWalkState(rx.State): + x: int = 0 + + walked = list(telemetry_accounting._walk_states(State)) + walked_names = {cls.__name__ for cls in walked} + + assert UpdateVarsInternalState not in walked + assert OnLoadInternalState not in walked + assert FrontendEventExceptionState not in walked + assert "SharedStateBaseInternal" not in walked_names + assert State not in walked + assert _UserWalkState in walked + + +def test_memo_wrapper_class_records_wrapped_component_type(): + """The dynamic memo subclass exposes the user-authored component class.""" + import importlib + + from reflex_components_radix.themes.components.button import Button + + memo_module = importlib.import_module("reflex.experimental.memo") + + wrapper_cls = memo_module._get_experimental_memo_component_class( + "Button_button_deadbeefcafebabe", + Button, + ) + assert wrapper_cls._wrapped_component_type is Button + + +def test_count_components_buckets_memo_wrapper_by_wrapped_type(): + """Memo wrappers count under their wrapped component class name.""" + from reflex_components_radix.themes.components.button import Button + + class _StubMemoWrapper: + _wrapped_component_type = Button + children = () + + counts = telemetry_accounting._count_components( + [_StubMemoWrapper()], # pyright: ignore[reportArgumentType] + ) + + assert counts == {"Button": 1} diff --git a/tests/units/utils/test_telemetry_context.py b/tests/units/utils/test_telemetry_context.py new file mode 100644 index 00000000000..3e1f21ccfe5 --- /dev/null +++ b/tests/units/utils/test_telemetry_context.py @@ -0,0 +1,108 @@ +"""Tests for ``reflex.utils.telemetry_context``.""" + +from pytest_mock import MockerFixture + +from reflex.utils.telemetry_context import TelemetryContext + + +def test_get_returns_none_when_no_context_set(): + """``get()`` returns ``None`` instead of raising ``LookupError``.""" + assert TelemetryContext.get() is None + + +def test_start_returns_none_when_telemetry_disabled(mocker: MockerFixture): + """``start()`` short-circuits when the config has telemetry disabled.""" + mocker.patch( + "reflex.utils.telemetry_context.get_config", + return_value=mocker.Mock(telemetry_enabled=False), + ) + assert TelemetryContext.start() is None + + +def test_start_returns_context_when_telemetry_enabled(mocker: MockerFixture): + """``start()`` returns a fresh context when telemetry is enabled.""" + mocker.patch( + "reflex.utils.telemetry_context.get_config", + return_value=mocker.Mock(telemetry_enabled=True), + ) + assert isinstance(TelemetryContext.start(), TelemetryContext) + + +def test_start_explicit_telemetry_enabled_overrides_config(): + """An explicit ``telemetry_enabled`` argument bypasses the config lookup.""" + assert TelemetryContext.start(telemetry_enabled=False) is None + assert isinstance(TelemetryContext.start(telemetry_enabled=True), TelemetryContext) + + +def test_context_manager_attaches_and_detaches(): + """Entering the context binds it to ``get()``; exiting clears it.""" + ctx = TelemetryContext() + with ctx: + assert TelemetryContext.get() is ctx + assert TelemetryContext.get() is None + + +def test_elapsed_ms_is_non_negative(): + """``elapsed_ms()`` returns a non-negative integer immediately after creation.""" + ctx = TelemetryContext() + elapsed = ctx.elapsed_ms() + assert isinstance(elapsed, int) + assert elapsed >= 0 + + +def test_set_exception_records_value_on_frozen_dataclass(): + """``set_exception`` mutates the otherwise-frozen ``exception`` field.""" + ctx = TelemetryContext() + exc = ValueError("boom") + ctx.set_exception(exc) + assert ctx.exception is exc + + +def test_features_used_writable_via_get(): + """Writes through ``get()`` are visible on the original context instance.""" + ctx = TelemetryContext() + with ctx: + active = TelemetryContext.get() + assert active is ctx + assert active is not None + active.features_used["foo"] = True + assert ctx.features_used == {"foo": True} + + +def test_trigger_stored_on_context(): + """``start(trigger=...)`` round-trips the trigger onto the context.""" + ctx = TelemetryContext.start(telemetry_enabled=True, trigger="backend_startup") + assert ctx is not None + assert ctx.trigger == "backend_startup" + + +def test_distinct_contexts_use_identity_equality(): + """Two ``TelemetryContext`` instances must not compare equal or share a hash. + + ``BaseContext`` uses a class-level dict keyed by ``self`` to track attached + contexts, so identity-based equality is required for nested use to work. + """ + a = TelemetryContext() + b = TelemetryContext() + assert a != b + assert hash(a) != hash(b) + assert a == a + + +def test_nested_contexts_can_be_entered(): + """Nested ``with`` blocks attach and detach without colliding.""" + outer = TelemetryContext() + inner = TelemetryContext() + with outer: + assert TelemetryContext.get() is outer + with inner: + assert TelemetryContext.get() is inner + assert TelemetryContext.get() is outer + assert TelemetryContext.get() is None + + +def test_hot_reload_trigger_accepted(): + """The ``hot_reload`` value is a valid ``CompileTrigger`` and round-trips.""" + ctx = TelemetryContext.start(telemetry_enabled=True, trigger="hot_reload") + assert ctx is not None + assert ctx.trigger == "hot_reload"