diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 9e937ed62cd..867ca2d924f 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -726,6 +726,53 @@ export const ReflexEvent = ( return { name, payload, handler, event_actions }; }; +/** + * Apply event actions before invoking a target function. + * @param {Function} target The function to invoke after applying event actions. + * @param {Object.} event_actions The actions to apply. + * @param {Array|any} args The event args. + * @param {string|null} action_key A stable key for debounce/throttle tracking. + * @param {Function|null} temporal_handler Returns whether temporal actions may run. + * @returns The target result, if it runs immediately. + */ +export const applyEventActions = ( + target, + event_actions = {}, + args = [], + action_key = null, + temporal_handler = null, +) => { + if (!(args instanceof Array)) { + args = [args]; + } + + const _e = args.find((o) => o?.preventDefault !== undefined); + + if (event_actions?.preventDefault && _e?.preventDefault) { + _e.preventDefault(); + } + if (event_actions?.stopPropagation && _e?.stopPropagation) { + _e.stopPropagation(); + } + if (event_actions?.temporal && temporal_handler && !temporal_handler()) { + return; + } + + const invokeTarget = () => target(...args); + const resolved_action_key = action_key ?? target.toString(); + + if (event_actions?.throttle) { + if (!throttle(resolved_action_key, event_actions.throttle)) { + return; + } + } + if (event_actions?.debounce) { + debounce(resolved_action_key, invokeTarget, event_actions.debounce); + return; + } + return invokeTarget(); +}; + /** * Package client-side storage values as payload to send to the * backend with the hydrate event @@ -898,51 +945,24 @@ export const useEventLoop = ( // Function to add new events to the event queue. const addEvents = useCallback((events, args, event_actions) => { const _events = events.filter((e) => e !== undefined && e !== null); - if (!event_actions?.temporal) { - // Reconnect socket if needed for non-temporal events. - ensureSocketConnected(); - } - - if (!(args instanceof Array)) { - args = [args]; - } event_actions = _events.reduce( (acc, e) => ({ ...acc, ...e.event_actions }), event_actions ?? {}, ); - const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; - - if (event_actions?.preventDefault && _e?.preventDefault) { - _e.preventDefault(); - } - if (event_actions?.stopPropagation && _e?.stopPropagation) { - _e.stopPropagation(); - } - const combined_name = _events.map((e) => e.name).join("+++"); - if (event_actions?.temporal) { - if (!socket.current || !socket.current.connected) { - return; // don't queue when the backend is not connected - } - } - if (event_actions?.throttle) { - // If throttle returns false, the events are not added to the queue. - if (!throttle(combined_name, event_actions.throttle)) { - return; - } - } - if (event_actions?.debounce) { - // If debounce is used, queue the events after some delay - debounce( - combined_name, - () => - queueEvents(_events, socket, false, navigate, () => params.current), - event_actions.debounce, - ); - } else { - queueEvents(_events, socket, false, navigate, () => params.current); + if (!event_actions?.temporal) { + // Reconnect socket if needed for non-temporal events. + ensureSocketConnected(); } + + return applyEventActions( + () => queueEvents(_events, socket, false, navigate, () => params.current), + event_actions, + args, + _events.map((e) => e.name).join("+++"), + () => !!socket.current?.connected, + ); }, []); const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode diff --git a/reflex/app.py b/reflex/app.py index 9f4ef9c6229..af26d68af10 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1934,11 +1934,6 @@ async def upload_file(request: Request): Returns: StreamingResponse yielding newline-delimited JSON of StateUpdate emitted by the upload handler. - - Raises: - UploadValueError: if there are no args with supported annotation. - UploadTypeError: if a background task is used as the handler. - HTTPException: when the request does not include token / handler headers. """ from reflex.utils.exceptions import UploadTypeError, UploadValueError @@ -1963,6 +1958,11 @@ async def _create_upload_event() -> Event: Returns: The upload event backed by the original temp files. + + Raises: + UploadValueError: If there are no uploaded files or supported annotations. + UploadTypeError: If a background task is used as the handler. + HTTPException: If the request is missing token or handler headers. """ files = form_data.getlist("files") if not files: diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 873cce69a14..0926e57f320 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -61,6 +61,8 @@ class CompileVars(SimpleNamespace): IS_HYDRATED = "is_hydrated" # The name of the function to add events to the queue. ADD_EVENTS = "addEvents" + # The name of the function to apply event actions before invoking a target. + APPLY_EVENT_ACTIONS = "applyEventActions" # The name of the var storing any connection error. CONNECT_ERROR = "connectErrors" # The name of the function for converting a dict to an event. @@ -128,7 +130,10 @@ class Imports(SimpleNamespace): EVENTS = { "react": [ImportVar(tag="useContext")], f"$/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")], - f"$/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)], + f"$/{Dirs.STATE_PATH}": [ + ImportVar(tag=CompileVars.TO_EVENT), + ImportVar(tag=CompileVars.APPLY_EVENT_ACTIONS), + ], } diff --git a/reflex/event.py b/reflex/event.py index a1592a29d9d..4a4a78a5fb3 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1,9 +1,12 @@ """Define event classes to connect the frontend and backend.""" +from __future__ import annotations + import dataclasses import inspect import sys import types +import warnings from base64 import b64encode from collections.abc import Callable, Mapping, Sequence from functools import lru_cache, partial @@ -87,6 +90,8 @@ def substate_token(self) -> str: _EVENT_FIELDS: set[str] = {f.name for f in dataclasses.fields(Event)} +_EMPTY_EVENTS = LiteralVar.create([]) +_EMPTY_EVENT_ACTIONS = LiteralVar.create({}) BACKGROUND_TASK_MARKER = "_reflex_background_task" EVENT_ACTIONS_MARKER = "_rx_event_actions" @@ -234,7 +239,7 @@ def is_background(self) -> bool: """ return getattr(self.fn, BACKGROUND_TASK_MARKER, False) - def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": + def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: """Pass arguments to the handler to get an event spec. This method configures event handlers that take in arguments. @@ -342,7 +347,7 @@ def __init__( object.__setattr__(self, "client_handler_name", client_handler_name) object.__setattr__(self, "args", args or ()) - def with_args(self, args: tuple[tuple[Var, Var], ...]) -> "EventSpec": + def with_args(self, args: tuple[tuple[Var, Var], ...]) -> EventSpec: """Copy the event spec, with updated args. Args: @@ -358,7 +363,7 @@ def with_args(self, args: tuple[tuple[Var, Var], ...]) -> "EventSpec": event_actions=self.event_actions.copy(), ) - def add_args(self, *args: Var) -> "EventSpec": + def add_args(self, *args: Var) -> EventSpec: """Add arguments to the event spec. Args: @@ -449,8 +454,8 @@ def __call__(self, *args, **kwargs) -> EventSpec: class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: Sequence["EventSpec | EventVar | EventCallback"] = dataclasses.field( - default_factory=list + events: Sequence[EventSpec | EventVar | FunctionVar | EventCallback] = ( + dataclasses.field(default_factory=list) ) args_spec: Callable | Sequence[Callable] | None = dataclasses.field(default=None) @@ -460,11 +465,11 @@ class EventChain(EventActionsMixin): @classmethod def create( cls, - value: "EventType", + value: EventType, args_spec: ArgsSpec | Sequence[ArgsSpec], key: str | None = None, **event_chain_kwargs, - ) -> "EventChain | Var": + ) -> EventChain | Var: """Create an event chain from a variety of input types. Args: @@ -481,9 +486,18 @@ def create( """ # If it's an event chain var, return it. if isinstance(value, Var): - if isinstance(value, EventChainVar): + # Only pass through literal/prebuilt chains. Other EventChainVar values may be + # FunctionVars cast with `.to(EventChain)` and still need wrapping so + # event_chain_kwargs can compose onto the resulting chain. + if isinstance(value, LiteralEventChainVar): + if event_chain_kwargs: + warnings.warn( + f"event_chain_kwargs {event_chain_kwargs!r} are ignored for " + "EventChainVar values.", + stacklevel=2, + ) return value - if isinstance(value, EventVar): + if isinstance(value, (EventVar, FunctionVar)): value = [value] elif safe_issubclass(value._var_type, (EventChain, EventSpec)): return cls.create( @@ -503,38 +517,29 @@ def create( if isinstance(value, (EventHandler, EventSpec)): value = [value] + events: list[EventSpec | EventVar | FunctionVar] = [] + # If the input is a list of event handlers, create an event chain. if isinstance(value, list): - events: list[EventSpec | EventVar] = [] for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. events.append(call_event_handler(v, args_spec, key=key)) + elif isinstance(v, (EventVar, EventChainVar)): + events.append(v) + elif isinstance(v, FunctionVar): + # Apply the args_spec transformations as partial arguments to the function. + events.append(v.partial(*parse_args_spec(args_spec)[0])) elif isinstance(v, Callable): # Call the lambda to get the event chain. - result = call_event_fn(v, args_spec, key=key) - if isinstance(result, Var): - msg = ( - f"Invalid event chain: {v}. Cannot use a Var-returning " - "lambda inside an EventChain list." - ) - raise ValueError(msg) - events.extend(result) - elif isinstance(v, EventVar): - events.append(v) + events.extend(call_event_fn(v, args_spec, key=key)) else: msg = f"Invalid event: {v}" raise ValueError(msg) # If the input is a callable, create an event chain. elif isinstance(value, Callable): - result = call_event_fn(value, args_spec, key=key) - if isinstance(result, Var): - # Recursively call this function if the lambda returned an EventChain Var. - return cls.create( - value=result, args_spec=args_spec, key=key, **event_chain_kwargs - ) - events = [*result] + events.extend(call_event_fn(value, args_spec, key=key)) # Otherwise, raise an error. else: @@ -1309,7 +1314,7 @@ def download( def call_script( javascript_code: str | Var[str], - callback: "EventType[Any] | None" = None, + callback: EventType[Any] | None = None, ) -> EventSpec: """Create an event handler that executes arbitrary javascript code. @@ -1350,7 +1355,7 @@ def call_script( def call_function( javascript_code: str | Var, - callback: "EventType[Any] | None" = None, + callback: EventType[Any] | None = None, ) -> EventSpec: """Create an event handler that executes arbitrary javascript code. @@ -1386,7 +1391,7 @@ def call_function( def run_script( javascript_code: str | Var, - callback: "EventType[Any] | None" = None, + callback: EventType[Any] | None = None, ) -> EventSpec: """Create an event handler that executes arbitrary javascript code. @@ -1404,7 +1409,7 @@ def run_script( return call_function(ArgsFunctionOperation.create((), javascript_code), callback) -def get_event(state: "BaseState", event: str): +def get_event(state: BaseState, event: str): """Get the event from the given state. Args: @@ -1417,7 +1422,7 @@ def get_event(state: "BaseState", event: str): return f"{state.get_name()}.{event}" -def get_hydrate_event(state: "BaseState") -> str: +def get_hydrate_event(state: BaseState) -> str: """Get the name of the hydrate event for the state. Args: @@ -1770,11 +1775,11 @@ def call_event_fn( fn: Callable, arg_spec: ArgsSpec | Sequence[ArgsSpec], key: str | None = None, -) -> list[EventSpec] | Var: +) -> list[EventSpec | FunctionVar | EventVar]: """Call a function to a list of event specs. - The function should return a single EventSpec, a list of EventSpecs, or a - single Var. + The function should return a single event-like value or a heterogeneous + sequence of event-like values. Args: fn: The function to call. @@ -1782,7 +1787,7 @@ def call_event_fn( key: The key to pass to the event handler. Returns: - The event specs from calling the function or a Var. + The event-like values from calling the function. Raises: EventHandlerValueError: If the lambda returns an unusable value. @@ -1803,13 +1808,9 @@ def call_event_fn( # Call the function with the parsed args. out = fn(*[*parsed_args][:number_of_fn_args]) - # If the function returns a Var, assume it's an EventChain and render it directly. - if isinstance(out, Var): - return out - - # Convert the output to a list. - if not isinstance(out, list): - out = [out] + # Normalize common heterogeneous event collections into individual events + # while keeping other scalar values for validation below. + out = list(out) if isinstance(out, (list, tuple)) else [out] # Convert any event specs to event specs. events = [] @@ -1818,9 +1819,21 @@ def call_event_fn( # An un-called EventHandler gets all of the args of the event trigger. e = call_event_handler(e, arg_spec, key=key) + if isinstance(e, EventChain): + # Nested EventChain is treated like a FunctionVar. + e = Var.create(e) + # Make sure the event spec is valid. - if not isinstance(e, EventSpec): - msg = f"Lambda {fn} returned an invalid event spec: {e}." + if not isinstance(e, (EventSpec, FunctionVar, EventVar)): + hint = "" + if isinstance(e, VarOperationCall): + hint = " Hint: use `fn.partial(...)` instead of calling the FunctionVar directly." + msg = ( + f"Invalid event chain for {key}: {fn} -> {e}: A lambda inside an EventChain " + "list must return `EventSpec | EventHandler | EventChain | EventVar | FunctionVar` " + "or a heterogeneous sequence of these types. " + f"Got: {type(e)}.{hint}" + ) raise EventHandlerValueError(msg) # Add the event spec to the chain. @@ -1969,7 +1982,7 @@ def create( cls, value: EventSpec | EventHandler, _var_data: VarData | None = None, - ) -> "LiteralEventVar": + ) -> LiteralEventVar: """Create a new LiteralEventVar instance. Args: @@ -2056,7 +2069,7 @@ def create( cls, value: EventChain, _var_data: VarData | None = None, - ) -> "LiteralEventChainVar": + ) -> LiteralEventChainVar: """Create a new LiteralEventChainVar instance. Args: @@ -2075,9 +2088,11 @@ def create( else value.args_spec ) sig = inspect.signature(arg_spec) # pyright: ignore [reportArgumentType] + arg_vars = () if sig.parameters: arg_def = tuple(f"_{p}" for p in sig.parameters) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + arg_vars = tuple(Var(_js_expr=arg) for arg in arg_def) + arg_def_expr = LiteralVar.create(list(arg_vars)) else: # add a default argument for addEvents if none were specified in value.args_spec # used to trigger the preventDefault() on the event. @@ -2098,17 +2113,59 @@ def create( if invocation is not None and not isinstance(invocation, FunctionVar): msg = f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}." raise ValueError(msg) + assert invocation is not None + + call_args = arg_vars if sig.parameters else (Var(_js_expr="...args"),) + statements = [ + ( + event.call(*call_args) + if isinstance(event, FunctionVar) + else invocation.call( + LiteralVar.create([LiteralVar.create(event)]), + arg_def_expr, + _EMPTY_EVENT_ACTIONS, + ) + ) + for event in value.events + ] + + if not statements: + statements.append( + invocation.call( + _EMPTY_EVENTS, + arg_def_expr, + _EMPTY_EVENT_ACTIONS, + ) + ) + + if len(statements) == 1 and not value.event_actions: + return_expr = statements[0] + else: + statement_block = Var( + _js_expr=f"{{{''.join(f'{statement};' for statement in statements)}}}", + ) + if value.event_actions: + apply_event_actions = FunctionStringVar.create( + CompileVars.APPLY_EVENT_ACTIONS, + _var_data=VarData( + imports=Imports.EVENTS, + hooks={Hooks.EVENTS: None}, + ), + ) + return_expr = apply_event_actions.call( + ArgsFunctionOperation.create((), statement_block), + value.event_actions, + *call_args, + ) + else: + return_expr = statement_block return cls( _js_expr="", _var_type=EventChain, _var_data=_var_data, _args=FunctionArgs(arg_def), - _return_expr=invocation.call( - LiteralVar.create([LiteralVar.create(event) for event in value.events]), - arg_def_expr, - value.event_actions, - ), + _return_expr=return_expr, _var_value=value, ) @@ -2135,39 +2192,39 @@ def __init__(self, func: Callable[[Any, Unpack[P]], Any]): @overload def __call__( - self: "EventCallback[Unpack[Q]]", - ) -> "EventCallback[Unpack[Q]]": ... + self: EventCallback[Unpack[Q]], + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: "EventCallback[V, Unpack[Q]]", value: V | Var[V] - ) -> "EventCallback[Unpack[Q]]": ... + self: EventCallback[V, Unpack[Q]], value: V | Var[V] + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: "EventCallback[V, V2, Unpack[Q]]", + self: EventCallback[V, V2, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], - ) -> "EventCallback[Unpack[Q]]": ... + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: "EventCallback[V, V2, V3, Unpack[Q]]", + self: EventCallback[V, V2, V3, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], value3: V3 | Var[V3], - ) -> "EventCallback[Unpack[Q]]": ... + ) -> EventCallback[Unpack[Q]]: ... @overload def __call__( - self: "EventCallback[V, V2, V3, V4, Unpack[Q]]", + self: EventCallback[V, V2, V3, V4, Unpack[Q]], value: V | Var[V], value2: V2 | Var[V2], value3: V3 | Var[V3], value4: V4 | Var[V4], - ) -> "EventCallback[Unpack[Q]]": ... + ) -> EventCallback[Unpack[Q]]: ... - def __call__(self, *values) -> "EventCallback": # pyright: ignore [reportInconsistentOverload] + def __call__(self, *values) -> EventCallback: # pyright: ignore [reportInconsistentOverload] """Call the function with the values. Args: @@ -2180,11 +2237,11 @@ def __call__(self, *values) -> "EventCallback": # pyright: ignore [reportIncons @overload def __get__( - self: "EventCallback[Unpack[P]]", instance: None, owner: Any - ) -> "EventCallback[Unpack[P]]": ... + self: EventCallback[Unpack[P]], instance: None, owner: Any + ) -> EventCallback[Unpack[P]]: ... @overload - def __get__(self, instance: Any, owner: Any) -> "Callable[[Unpack[P]]]": ... + def __get__(self, instance: Any, owner: Any) -> Callable[[Unpack[P]]]: ... def __get__(self, instance: Any, owner: Any) -> Callable: """Get the function with the instance bound to it. @@ -2208,19 +2265,19 @@ class LambdaEventCallback(Protocol[Unpack[P]]): __code__: types.CodeType @overload - def __call__(self: "LambdaEventCallback[()]") -> Any: ... + def __call__(self: LambdaEventCallback[()]) -> Any: ... @overload - def __call__(self: "LambdaEventCallback[V]", value: "Var[V]", /) -> Any: ... + def __call__(self: LambdaEventCallback[V], value: Var[V], /) -> Any: ... @overload def __call__( - self: "LambdaEventCallback[V, V2]", value: Var[V], value2: Var[V2], / + self: LambdaEventCallback[V, V2], value: Var[V], value2: Var[V2], / ) -> Any: ... @overload def __call__( - self: "LambdaEventCallback[V, V2, V3]", + self: LambdaEventCallback[V, V2, V3], value: Var[V], value2: Var[V2], value3: Var[V3], diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 6836e9d3b49..709b9b03b05 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -33,6 +33,84 @@ class ReflexCallable(Protocol[P, R]): ) +def _is_js_identifier_start(char: str) -> bool: + """Check whether a character can start a JavaScript identifier. + + Returns: + True if the character is valid as the first character of a JS identifier. + """ + return char == "$" or char == "_" or char.isalpha() + + +def _is_js_identifier_char(char: str) -> bool: + """Check whether a character can continue a JavaScript identifier. + + Returns: + True if the character is valid within a JS identifier. + """ + return _is_js_identifier_start(char) or char.isdigit() + + +def _starts_with_arrow_function(expr: str) -> bool: + """Check whether an expression starts with an inline arrow function. + + Returns: + True if the expression begins with an arrow function. + """ + if "=>" not in expr: + return False + + expr = expr.lstrip() + if not expr: + return False + + if expr.startswith("async"): + async_remainder = expr[len("async") :] + if async_remainder[:1].isspace(): + expr = async_remainder.lstrip() + + if not expr: + return False + + if _is_js_identifier_start(expr[0]): + end_index = 1 + while end_index < len(expr) and _is_js_identifier_char(expr[end_index]): + end_index += 1 + return expr[end_index:].lstrip().startswith("=>") + + if not expr.startswith("("): + return False + + depth = 0 + string_delimiter: str | None = None + escaped = False + + for index, char in enumerate(expr): + if string_delimiter is not None: + if escaped: + escaped = False + elif char == "\\": + escaped = True + elif char == string_delimiter: + string_delimiter = None + continue + + if char in {"'", '"', "`"}: + string_delimiter = char + continue + + if char == "(": + depth += 1 + continue + + if char == ")": + depth -= 1 + if depth == 0: + return expr[index + 1 :].lstrip().startswith("=>") + + return False + + class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): """Base class for immutable function vars.""" @@ -108,7 +186,7 @@ def partial(self, *args: Var | Any) -> FunctionVar: # pyright: ignore [reportIn The partially applied function. """ if not args: - return ArgsFunctionOperation.create((), self) + return self return ArgsFunctionOperation.create( ("...args",), VarOperationCall.create(self, *args, Var(_js_expr="...args")), @@ -239,7 +317,13 @@ def _cached_var_name(self) -> str: Returns: The name of the var. """ - return f"({self._func!s}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" + func_expr = str(self._func) + if _starts_with_arrow_function(func_expr) and not format.is_wrapped( + func_expr, "(" + ): + func_expr = format.wrap(func_expr, "(") + + return f"({func_expr}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" @cached_property_no_lock def _cached_get_all_var_data(self) -> VarData | None: diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 959c81c34bb..925873703a8 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -931,10 +931,6 @@ def test_invalid_event_handler_args(component2, test_state: type[TestState]): component2.create(on_blur=lambda: 1) with pytest.raises(ValueError): component2.create(on_blur=lambda: [1]) - with pytest.raises(ValueError): - component2.create( - on_blur=lambda: (test_state.do_something_arg(1), test_state.do_something) - ) # lambda signature must match event trigger. with pytest.raises(EventFnArgMismatchError): @@ -1007,6 +1003,9 @@ def test_valid_event_handler_args(component2, test_state: type[TestState]): component2.create( on_blur=lambda: [test_state.do_something_arg(1), test_state.do_something] ) + component2.create( + on_blur=lambda: (test_state.do_something_arg(1), test_state.do_something) + ) component2.create( on_blur=lambda: [test_state.do_something_arg(1), test_state.do_something()] ) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 25c71c0d17e..d01fe5de076 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1719,6 +1719,7 @@ def test_app_wrap_compile_theme( "export function Layout" ) ].strip() + expected = ( "function AppWrap({children}) {\n" "const [addEvents, connectErrors] = useContext(EventLoopContext);\n\n\n\n" @@ -1727,15 +1728,15 @@ def test_app_wrap_compile_theme( + "jsx(ErrorBoundary,{" """fallbackRender:((event_args) => (jsx("div", ({css:({ ["height"] : "100%", ["width"] : "100%", ["position"] : "absolute", ["backgroundColor"] : "#fff", ["color"] : "#000", ["display"] : "flex", ["alignItems"] : "center", ["justifyContent"] : "center" })}), (jsx("div", ({css:({ ["display"] : "flex", ["flexDirection"] : "column", ["gap"] : "0.5rem", ["maxWidth"] : "min(80ch, 90vw)", ["borderRadius"] : "0.25rem", ["padding"] : "1rem" })}), (jsx("div", ({css:({ ["opacity"] : "0.5", ["display"] : "flex", ["gap"] : "4vmin", ["alignItems"] : "center" })}), (jsx("svg", ({className:"lucide lucide-frown-icon lucide-frown",fill:"none",stroke:"currentColor","stroke-linecap":"round","stroke-linejoin":"round","stroke-width":"2",viewBox:"0 0 24 24",width:"25vmin",xmlns:"http://www.w3.org/2000/svg"}), (jsx("circle", ({cx:"12",cy:"12",r:"10"}))), (jsx("path", ({d:"M16 16s-1.5-2-4-2-4 2-4 2"}))), (jsx("line", ({x1:"9",x2:"9.01",y1:"9",y2:"9"}))), (jsx("line", ({x1:"15",x2:"15.01",y1:"9",y2:"9"}))))), (jsx("h2", ({css:({ ["fontSize"] : "5vmin", ["fontWeight"] : "bold" })}), "An error occurred while rendering this page.")))), (jsx("p", ({css:({ ["opacity"] : "0.75", ["marginBlock"] : "1rem" })}), "This is an error with the application itself. Refreshing the page might help.")), (jsx("div", ({css:({ ["width"] : "100%", ["background"] : "color-mix(in srgb, currentColor 5%, transparent)", ["maxHeight"] : "15rem", ["overflow"] : "auto", ["borderRadius"] : "0.4rem" })}), (jsx("div", ({css:({ ["padding"] : "0.5rem" })}), (jsx("pre", ({css:({ ["wordBreak"] : "break-word", ["whiteSpace"] : "pre-wrap" })}), event_args.error.name + \': \' + event_args.error.message + \'\\n\' + event_args.error.stack)))))), (jsx("button", ({css:({ ["padding"] : "0.35rem 1.35rem", ["marginBlock"] : "0.5rem", ["marginInlineStart"] : "auto", ["background"] : "color-mix(in srgb, currentColor 15%, transparent)", ["borderRadius"] : "0.4rem", ["width"] : "fit-content", ["&:hover"] : ({ ["background"] : "color-mix(in srgb, currentColor 25%, transparent)" }), ["&:active"] : ({ ["background"] : "color-mix(in srgb, currentColor 35%, transparent)" }) }),onClick:((_e) => (addEvents([(ReflexEvent("_call_function", ({ ["function"] : (() => (navigator?.["clipboard"]?.["writeText"](event_args.error.name + \': \' + event_args.error.message + \'\\n\' + event_args.error.stack))), ["callback"] : null }), ({ })))], [_e], ({ }))))}), "Copy")), (jsx("hr", ({css:({ ["borderColor"] : "currentColor", ["opacity"] : "0.25" })}))), (jsx(ReactRouterLink, ({to:"https://reflex.dev"}), (jsx("div", ({css:({ ["display"] : "flex", ["alignItems"] : "baseline", ["justifyContent"] : "center", ["fontFamily"] : "monospace", ["--default-font-family"] : "monospace", ["gap"] : "0.5rem" })}), "Built with ", (jsx("svg", ({"aria-label":"Reflex",css:({ ["fill"] : "currentColor" }),height:"12",role:"img",width:"56",xmlns:"http://www.w3.org/2000/svg"}), (jsx("path", ({d:"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z"}))), (jsx("path", ({d:"M11.2 11.5999V0.399902H17.92V2.6399H13.44V4.8799H17.92V7.1199H13.44V9.3599H17.92V11.5999H11.2Z"}))), (jsx("path", ({d:"M20.16 11.5999V0.399902H26.88V2.6399H22.4V4.8799H26.88V7.1199H22.4V11.5999H20.16Z"}))), (jsx("path", ({d:"M29.12 11.5999V0.399902H31.36V9.3599H35.84V11.5999H29.12Z"}))), (jsx("path", ({d:"M38.08 11.5999V0.399902H44.8V2.6399H40.32V4.8799H44.8V7.1199H40.32V9.3599H44.8V11.5999H38.08Z"}))), (jsx("path", ({d:"M47.04 4.8799V0.399902H49.28V4.8799H47.04ZM53.76 4.8799V0.399902H56V4.8799H53.76ZM49.28 7.1199V4.8799H53.76V7.1199H49.28ZM47.04 11.5999V7.1199H49.28V11.5999H47.04ZM53.76 11.5999V7.1199H56V11.5999H53.76Z"}))), (jsx("title", ({}), "Reflex"))))))))))))),""" """onError:((_error, _info) => (addEvents([(ReflexEvent("reflex___state____state.reflex___state____frontend_event_exception_state.handle_frontend_exception", ({ ["info"] : ((((_error?.["name"]+": ")+_error?.["message"])+"\\n")+_error?.["stack"]), ["component_stack"] : _info?.["componentStack"] }), ({ })))], [_error, _info], ({ }))))""" - "}," - "jsx(RadixThemesColorModeProvider,{}," - "jsx(Fragment,{}," - "jsx(MemoizedToastProvider,{},)," - "jsx(RadixThemesTheme,{accentColor:\"plum\",css:{...theme.styles.global[':root'], ...theme.styles.global.body}}," - "jsx(Fragment,{}," - "jsx(DefaultOverlayComponents,{},)," - "jsx(Fragment,{}," - "children" + + "}," + + "jsx(RadixThemesColorModeProvider,{}," + + "jsx(Fragment,{}," + + "jsx(MemoizedToastProvider,{},)," + + "jsx(RadixThemesTheme,{accentColor:\"plum\",css:{...theme.styles.global[':root'], ...theme.styles.global.body}}," + + "jsx(Fragment,{}," + + "jsx(DefaultOverlayComponents,{},)," + + "jsx(Fragment,{}," + + "children" "))))))" + (")" if react_strict_mode else "") + ")" "\n}" ) @@ -1792,6 +1793,7 @@ def page(): "export function Layout" ) ].strip() + expected = ( "function AppWrap({children}) {\n" "const [addEvents, connectErrors] = useContext(EventLoopContext);\n\n\n\n" @@ -1801,16 +1803,16 @@ def page(): "jsx(ErrorBoundary,{" """fallbackRender:((event_args) => (jsx("div", ({css:({ ["height"] : "100%", ["width"] : "100%", ["position"] : "absolute", ["backgroundColor"] : "#fff", ["color"] : "#000", ["display"] : "flex", ["alignItems"] : "center", ["justifyContent"] : "center" })}), (jsx("div", ({css:({ ["display"] : "flex", ["flexDirection"] : "column", ["gap"] : "0.5rem", ["maxWidth"] : "min(80ch, 90vw)", ["borderRadius"] : "0.25rem", ["padding"] : "1rem" })}), (jsx("div", ({css:({ ["opacity"] : "0.5", ["display"] : "flex", ["gap"] : "4vmin", ["alignItems"] : "center" })}), (jsx("svg", ({className:"lucide lucide-frown-icon lucide-frown",fill:"none",stroke:"currentColor","stroke-linecap":"round","stroke-linejoin":"round","stroke-width":"2",viewBox:"0 0 24 24",width:"25vmin",xmlns:"http://www.w3.org/2000/svg"}), (jsx("circle", ({cx:"12",cy:"12",r:"10"}))), (jsx("path", ({d:"M16 16s-1.5-2-4-2-4 2-4 2"}))), (jsx("line", ({x1:"9",x2:"9.01",y1:"9",y2:"9"}))), (jsx("line", ({x1:"15",x2:"15.01",y1:"9",y2:"9"}))))), (jsx("h2", ({css:({ ["fontSize"] : "5vmin", ["fontWeight"] : "bold" })}), "An error occurred while rendering this page.")))), (jsx("p", ({css:({ ["opacity"] : "0.75", ["marginBlock"] : "1rem" })}), "This is an error with the application itself. Refreshing the page might help.")), (jsx("div", ({css:({ ["width"] : "100%", ["background"] : "color-mix(in srgb, currentColor 5%, transparent)", ["maxHeight"] : "15rem", ["overflow"] : "auto", ["borderRadius"] : "0.4rem" })}), (jsx("div", ({css:({ ["padding"] : "0.5rem" })}), (jsx("pre", ({css:({ ["wordBreak"] : "break-word", ["whiteSpace"] : "pre-wrap" })}), event_args.error.name + \': \' + event_args.error.message + \'\\n\' + event_args.error.stack)))))), (jsx("button", ({css:({ ["padding"] : "0.35rem 1.35rem", ["marginBlock"] : "0.5rem", ["marginInlineStart"] : "auto", ["background"] : "color-mix(in srgb, currentColor 15%, transparent)", ["borderRadius"] : "0.4rem", ["width"] : "fit-content", ["&:hover"] : ({ ["background"] : "color-mix(in srgb, currentColor 25%, transparent)" }), ["&:active"] : ({ ["background"] : "color-mix(in srgb, currentColor 35%, transparent)" }) }),onClick:((_e) => (addEvents([(ReflexEvent("_call_function", ({ ["function"] : (() => (navigator?.["clipboard"]?.["writeText"](event_args.error.name + \': \' + event_args.error.message + \'\\n\' + event_args.error.stack))), ["callback"] : null }), ({ })))], [_e], ({ }))))}), "Copy")), (jsx("hr", ({css:({ ["borderColor"] : "currentColor", ["opacity"] : "0.25" })}))), (jsx(ReactRouterLink, ({to:"https://reflex.dev"}), (jsx("div", ({css:({ ["display"] : "flex", ["alignItems"] : "baseline", ["justifyContent"] : "center", ["fontFamily"] : "monospace", ["--default-font-family"] : "monospace", ["gap"] : "0.5rem" })}), "Built with ", (jsx("svg", ({"aria-label":"Reflex",css:({ ["fill"] : "currentColor" }),height:"12",role:"img",width:"56",xmlns:"http://www.w3.org/2000/svg"}), (jsx("path", ({d:"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z"}))), (jsx("path", ({d:"M11.2 11.5999V0.399902H17.92V2.6399H13.44V4.8799H17.92V7.1199H13.44V9.3599H17.92V11.5999H11.2Z"}))), (jsx("path", ({d:"M20.16 11.5999V0.399902H26.88V2.6399H22.4V4.8799H26.88V7.1199H22.4V11.5999H20.16Z"}))), (jsx("path", ({d:"M29.12 11.5999V0.399902H31.36V9.3599H35.84V11.5999H29.12Z"}))), (jsx("path", ({d:"M38.08 11.5999V0.399902H44.8V2.6399H40.32V4.8799H44.8V7.1199H40.32V9.3599H44.8V11.5999H38.08Z"}))), (jsx("path", ({d:"M47.04 4.8799V0.399902H49.28V4.8799H47.04ZM53.76 4.8799V0.399902H56V4.8799H53.76ZM49.28 7.1199V4.8799H53.76V7.1199H49.28ZM47.04 11.5999V7.1199H49.28V11.5999H47.04ZM53.76 11.5999V7.1199H56V11.5999H53.76Z"}))), (jsx("title", ({}), "Reflex"))))))))))))),""" """onError:((_error, _info) => (addEvents([(ReflexEvent("reflex___state____state.reflex___state____frontend_event_exception_state.handle_frontend_exception", ({ ["info"] : ((((_error?.["name"]+": ")+_error?.["message"])+"\\n")+_error?.["stack"]), ["component_stack"] : _info?.["componentStack"] }), ({ })))], [_error, _info], ({ }))))""" - "}," - 'jsx(RadixThemesText,{as:"p"},' - "jsx(RadixThemesColorModeProvider,{}," - "jsx(Fragment,{}," - "jsx(MemoizedToastProvider,{},)," - "jsx(Fragment2,{}," - "jsx(Fragment,{}," - "jsx(DefaultOverlayComponents,{},)," - "jsx(Fragment,{}," - "children" + + "}," + + 'jsx(RadixThemesText,{as:"p"},' + + "jsx(RadixThemesColorModeProvider,{}," + + "jsx(Fragment,{}," + + "jsx(MemoizedToastProvider,{},)," + + "jsx(Fragment2,{}," + + "jsx(Fragment,{}," + + "jsx(DefaultOverlayComponents,{},)," + + "jsx(Fragment,{}," + + "children" ")))))))" + (")" if react_strict_mode else "") + "))\n}" ) assert expected.split(",") == function_app_definition.split(",") diff --git a/tests/units/test_event.py b/tests/units/test_event.py index c413a1f225e..50f75cabcf7 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -1,5 +1,6 @@ import json from collections.abc import Callable +from typing import Any, cast import pytest @@ -9,8 +10,10 @@ BACKGROUND_TASK_MARKER, Event, EventChain, + EventChainVar, EventHandler, EventSpec, + LambdaEventCallback, call_event_handler, event, fix_events, @@ -32,6 +35,12 @@ def make_var(value) -> Var: return Var(_js_expr=value) +def make_timeout_logger() -> EventChainVar: + return rx.vars.FunctionStringVar.create( + "(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }" + ).to(EventChain) + + def test_create_event(): """Test creating an event.""" event = Event(token="token", name="state.do_thing", payload={"arg": "value"}) @@ -473,6 +482,32 @@ def _args_spec(value: Var[int]) -> tuple[Var[int]]: ) +def test_event_chain_statement_block_preserves_nested_var_data(): + class S(BaseState): + x: Field[int] = field(0) + + @event + def s(self, value: int): + pass + + chain_var_data = Var.create( + EventChain( + events=[S.s(S.x), make_timeout_logger()], + args_spec=lambda: (), + ) + )._get_all_var_data() + + assert chain_var_data is not None + + x_var_data = S.x._get_all_var_data() + assert x_var_data is not None + + assert chain_var_data.state == x_var_data.state + assert chain_var_data.field_name == x_var_data.field_name + assert x_var_data.hooks[0] in chain_var_data.hooks + assert Hooks.EVENTS in chain_var_data.hooks + + def test_event_bound_method() -> None: class S(BaseState): @event @@ -668,6 +703,270 @@ def _args_spec() -> tuple: assert "to bool" in str(err.value) +def test_event_chain_create_allows_plain_function_var(): + """Plain FunctionVars should be usable as frontend event handlers.""" + frontend_handler = rx.vars.FunctionStringVar.create( + "(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }" + ) + + chain = EventChain.create(frontend_handler, args_spec=lambda: ()) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert frontend_handler.equals(chain_event) + + +def test_event_chain_create_partials_function_var_with_non_empty_args_spec(): + """FunctionVars should receive trigger args as partial arguments.""" + frontend_handler = rx.vars.FunctionStringVar.create("(event) => console.log(event)") + + chain = EventChain.create(frontend_handler, args_spec=lambda e: [e]) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert not frontend_handler.equals(chain_event) + assert "(_e, ...args)" in str(chain_event) + + +def test_event_chain_create_lambda_returned_function_var_keeps_original_signature(): + """FunctionVars returned from lambdas should not be partially applied.""" + frontend_handler = rx.vars.FunctionStringVar.create("(event) => console.log(event)") + + def return_function_var(e: Var[Any]) -> Any: + return frontend_handler + + chain = EventChain.create( + cast(LambdaEventCallback[Any], return_function_var), + args_spec=lambda e: [e], + ) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert frontend_handler.equals(chain_event) + assert "(_e, ...args)" not in str(LiteralVar.create(chain)) + + +def test_event_chain_create_lambda_allows_mixed_event_sequences(): + """Lambdas should be able to return mixed event sequences.""" + + class MixedState(BaseState): + @event + def do_a_thing(self): + pass + + log_after_timeout = make_timeout_logger() + + def return_mixed_events(e: Var[Any]) -> Any: + return (MixedState.do_a_thing, log_after_timeout) + + chain = EventChain.create( + cast(LambdaEventCallback[Any], return_mixed_events), + args_spec=lambda e: [e], + ) + rendered = str(LiteralVar.create(chain)) + + assert isinstance(chain, EventChain) + assert "addEvents(" in rendered + assert "Timeout reached!" in rendered + assert rendered.index("addEvents(") < rendered.index("Timeout reached!") + + +def test_event_chain_create_lambda_preserves_explicit_event_chain(): + """Explicit EventChains returned from lambdas should be preserved.""" + inner = EventChain.create( + make_timeout_logger(), + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + + def return_explicit_chain(e: Var[Any]) -> Any: + return inner + + chain = EventChain.create( + cast(LambdaEventCallback[Any], return_explicit_chain), + args_spec=lambda e: [e], + ) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert chain_event.equals(Var.create(inner)) + + +def test_event_chain_create_wraps_plain_function_var_kwargs(): + """FunctionVars should compose with chain-level kwargs instead of bypassing wrapping.""" + frontend_handler = rx.vars.FunctionStringVar.create( + "(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }" + ) + + chain = EventChain.create( + frontend_handler, + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert frontend_handler.equals(chain_event) + assert chain.event_actions == {"preventDefault": True} + + +def test_event_chain_create_wraps_event_chain_typed_function_var_kwargs(): + """FunctionVars cast to EventChain should still compose with chain-level kwargs.""" + frontend_handler = make_timeout_logger() + + chain = EventChain.create( + frontend_handler, + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert frontend_handler.equals(chain_event) + assert chain.event_actions == {"preventDefault": True} + + +def test_event_chain_create_warns_for_event_chain_var_kwargs(): + """Prebuilt EventChainVars should also warn when extra kwargs are ignored.""" + prebuilt_chain = Var.create(EventChain(events=[], args_spec=lambda: ())) + + with pytest.warns(UserWarning, match="ignored for EventChainVar values"): + result = EventChain.create( + prebuilt_chain, + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + + assert result is prebuilt_chain + + +def test_event_chain_create_allows_function_var_in_list(): + """FunctionVars should be allowed inside EventChain lists.""" + frontend_handler = make_timeout_logger() + + chain = EventChain.create([frontend_handler], args_spec=lambda: ()) + + assert isinstance(chain, EventChain) + assert len(chain.events) == 1 + chain_event = chain.events[0] + assert isinstance(chain_event, Var) + assert frontend_handler.equals(chain_event) + + +def test_button_accepts_mixed_event_handler_and_function_var(): + """Components should accept mixed backend/frontend event chains.""" + + class MixedState(BaseState): + @event + def do_a_thing(self): + pass + + log_after_timeout = make_timeout_logger() + + button = rx.button( + "Do both", + on_click=[MixedState.do_a_thing, log_after_timeout], + ) + + assert isinstance(button.event_triggers["on_click"], EventChain) + + +def test_event_chain_codegen_preserves_backend_event_actions_per_spec(): + """Backend-only chains should keep per-spec event actions separate.""" + + class FastPathState(BaseState): + @event + def do_a_thing(self, value: str): + pass + + chain = EventChain.create( + [ + FastPathState.do_a_thing("first x 1000").debounce(1000), + FastPathState.do_a_thing("second x 200").debounce(200), + ], + args_spec=lambda: (), + ) + rendered = str(LiteralVar.create(chain)) + + assert "applyEventActions(" not in rendered + assert rendered.count("addEvents(") == 2 + assert rendered.count('["debounce"] : 1000') == 1 + assert rendered.count('["debounce"] : 200') == 1 + assert rendered.index("first x 1000") < rendered.index("second x 200") + + +def test_event_chain_codegen_keeps_chain_event_actions_for_backend_only_events(): + """Chain-level actions should still wrap backend-only event chains.""" + + class FastPathState(BaseState): + @event + def do_a_thing(self, value: str): + pass + + chain = EventChain.create( + [ + FastPathState.do_a_thing("first x 1000").debounce(1000), + FastPathState.do_a_thing("second x 200").debounce(200), + ], + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + rendered = str(LiteralVar.create(chain)) + + assert "applyEventActions(" in rendered + assert rendered.count("addEvents(") == 2 + assert rendered.count('["debounce"] : 1000') == 1 + assert rendered.count('["debounce"] : 200') == 1 + assert rendered.count('["preventDefault"] : true') == 1 + + +def test_event_chain_codegen_keeps_apply_event_actions_for_function_vars(): + """Frontend handlers should keep the applyEventActions wrapper.""" + log_after_timeout = make_timeout_logger() + + chain = EventChain.create( + log_after_timeout, + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + rendered = str(LiteralVar.create(chain)) + + assert "applyEventActions(" in rendered + assert "Timeout reached!" in rendered + + +def test_event_chain_codegen_preserves_mixed_chain_order(): + """Mixed chains should keep backend and frontend work in the original order.""" + + class MixedState(BaseState): + @event + def do_a_thing(self): + pass + + log_after_timeout = make_timeout_logger() + chain = EventChain.create( + [MixedState.do_a_thing, log_after_timeout], + args_spec=lambda: (), + event_actions={"preventDefault": True}, + ) + rendered = str(LiteralVar.create(chain)) + + assert "applyEventActions(" in rendered + assert rendered.index("addEvents(") < rendered.index("Timeout reached!") + + def test_decentralized_event_with_args(): """Test the decentralized event.""" diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 402505c0543..4cc9c7a159d 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -951,6 +951,20 @@ def test_function_var(): ) assert str(explicit_return_func.call(1, 2)) == "(((a, b) => {return a + b})(1, 2))" + unwrapped_arrow_func = FunctionStringVar.create( + "(...args) => { const f = x => x + 1; return f(args); }" + ) + assert ( + str(unwrapped_arrow_func.call(1)) + == "(((...args) => { const f = x => x + 1; return f(args); })(1))" + ) + + nested_arrow_expr = FunctionStringVar.create("factory(() => 1)") + assert str(nested_arrow_expr.call()) == "(factory(() => 1)())" + + string_arrow_expr = FunctionStringVar.create('factory("=>")') + assert str(string_arrow_expr.call()) == '(factory("=>")())' + def test_var_operation(): @var_operation diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index c8f44f40236..48db442223c 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -20,6 +20,7 @@ from reflex.utils import format from reflex.utils.serializers import serialize_figure from reflex.vars.base import LiteralVar, Var +from reflex.vars.function import FunctionStringVar from reflex.vars.object import ObjectVar pytest.importorskip("pydantic") @@ -41,6 +42,107 @@ def mock_event(arg): pass +def mock_event_two(arg): + pass + + +def make_timeout_logger(): + return FunctionStringVar.create( + "(...args) => { setTimeout(() => console.log('Timeout reached!', args), 1000); }" + ).to(EventChain) + + +def test_format_prop_event_chain_pure_eventspec_grouped(): + """Pure EventSpec chains should preserve order with separate addEvents calls.""" + chain = EventChain( + events=[ + EventSpec(handler=EventHandler(fn=mock_event)), + EventSpec(handler=EventHandler(fn=mock_event_two)), + ], + args_spec=lambda e: [e], + ) + + assert format.format_prop(LiteralVar.create(chain)) == ( + '((_e) => {(addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], ' + '[_e], ({ })));(addEvents([(ReflexEvent("mock_event_two", ({ }), ' + "({ })))], [_e], ({ })));})" + ) + + +def test_format_prop_event_chain_pure_function_var(): + """Pure FunctionVar chains should render as direct frontend calls.""" + log_after_timeout = make_timeout_logger() + chain = EventChain( + events=[log_after_timeout], + args_spec=lambda e: [e], + ) + + assert format.format_prop(LiteralVar.create(chain)) == ( + "((_e) => (((...args) => { setTimeout(() => console.log('Timeout reached!', " + "args), 1000); })(_e)))" + ) + + +def test_format_prop_event_chain_mixed_queue_and_function(): + """Mixed chains should alternate addEvents and direct calls in order.""" + log_after_timeout = make_timeout_logger() + chain = EventChain( + events=[ + EventSpec(handler=EventHandler(fn=mock_event)), + log_after_timeout, + EventSpec(handler=EventHandler(fn=mock_event_two)), + ], + args_spec=lambda e: [e], + ) + + assert format.format_prop(LiteralVar.create(chain)) == ( + '((_e) => {(addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], ' + "[_e], ({ })));(((...args) => { setTimeout(() => console.log('Timeout reached!', " + 'args), 1000); })(_e));(addEvents([(ReflexEvent("mock_event_two", ' + "({ }), ({ })))], [_e], ({ })));})" + ) + + +def test_format_prop_event_chain_mixed_with_event_actions(): + """Mixed chains should preserve DOM event actions on the wrapper callback.""" + log_after_timeout = make_timeout_logger() + chain = EventChain( + events=[ + EventSpec(handler=EventHandler(fn=mock_event)), + log_after_timeout, + ], + args_spec=lambda e: [e], + event_actions={"preventDefault": True, "stopPropagation": True}, + ) + + assert format.format_prop(LiteralVar.create(chain)) == ( + '((_e) => (applyEventActions((() => {(addEvents([(ReflexEvent("mock_event", ' + "({ }), ({ })))], [_e], ({ })));(((...args) => { setTimeout(() => " + "console.log('Timeout reached!', args), 1000); })(_e));}), ({ " + '["preventDefault"] : true, ["stopPropagation"] : true }), _e)))' + ) + + +def test_format_prop_event_chain_mixed_with_queueable_event_actions(): + """Mixed chains should forward non-DOM event actions to queued backend groups.""" + log_after_timeout = make_timeout_logger() + chain = EventChain( + events=[ + EventSpec(handler=EventHandler(fn=mock_event)), + log_after_timeout, + ], + args_spec=lambda e: [e], + event_actions={"preventDefault": True, "throttle": 250}, + ) + + assert format.format_prop(LiteralVar.create(chain)) == ( + '((_e) => (applyEventActions((() => {(addEvents([(ReflexEvent("mock_event", ' + "({ }), ({ })))], [_e], ({ })));(((...args) => { setTimeout(() => " + "console.log('Timeout reached!', args), 1000); })(_e));}), ({ " + '["preventDefault"] : true, ["throttle"] : 250 }), _e)))' + ) + + @pytest.mark.parametrize( ("input", "output"), [ @@ -409,7 +511,7 @@ def test_format_match( args_spec=no_args_event_spec, event_actions={"stopPropagation": True}, ), - '((...args) => (addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true }))))', + '((...args) => (applyEventActions((() => {(addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], args, ({ })));}), ({ ["stopPropagation"] : true }), ...args)))', ), ( EventChain( @@ -429,7 +531,7 @@ def test_format_match( args_spec=no_args_event_spec, event_actions={"preventDefault": True}, ), - '((...args) => (addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true }))))', + '((...args) => (applyEventActions((() => {(addEvents([(ReflexEvent("mock_event", ({ }), ({ })))], args, ({ })));}), ({ ["preventDefault"] : true }), ...args)))', ), ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"),