Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
378 changes: 267 additions & 111 deletions packages/reflex-base/src/reflex_base/compiler/templates.py

Large diffs are not rendered by default.

150 changes: 150 additions & 0 deletions packages/reflex-base/src/reflex_base/compiler/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Common utility functions used in the compiler."""

from typing import TypedDict

from reflex_base.utils import format, imports


def validate_imports(import_dict: imports.ParsedImportDict):
"""Verify that the same Tag is not used in multiple import.

Args:
import_dict: The dict of imports to validate

Raises:
ValueError: if a conflict on "tag/alias" is detected for an import.
"""
used_tags = {}
for lib, imported_items in import_dict.items():
for imported_item in imported_items:
import_name = (
f"{imported_item.tag}/{imported_item.alias}"
if imported_item.alias
else imported_item.tag
)
if import_name in used_tags:
already_imported = used_tags[import_name]
if (already_imported[0] == "$" and already_imported[1:] == lib) or (
lib[0] == "$" and lib[1:] == already_imported
):
used_tags[import_name] = lib if lib[0] == "$" else already_imported
continue
msg = f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
raise ValueError(msg)
if import_name is not None:
used_tags[import_name] = lib


def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.

Args:
fields: The set of fields to import from the library.

Returns:
The libraries for default and rest.
default: default library. When install "import def from library".
rest: rest of libraries. When install "import {rest1, rest2} from library"

Raises:
ValueError: If there is more than one default import.
"""
# ignore the ImportVar fields with render=False during compilation
fields_set = {field for field in fields if field.render}

# Check for default imports.
defaults = {field for field in fields_set if field.is_default}
if len(defaults) >= 2:
msg = "Only one default import is allowed."
raise ValueError(msg)

# Get the default import, and the specific imports.
default = next(iter({field.name for field in defaults}), "")
rest = {field.name for field in fields_set - defaults}

return default, sorted(rest)


class ImportDict(TypedDict):
"""TypedDict for compiled import information.

Attributes:
lib: The library name.
default: The default import name.
rest: List of non-default import names.
"""

lib: str
default: str
rest: list[str]


def compile_imports(import_dict: imports.ParsedImportDict) -> list[ImportDict]:
"""Compile an import dict.

Args:
import_dict: The import dict to compile.

Returns:
The list of import dict.

Raises:
ValueError: If an import in the dict is invalid.
"""
collapsed_import_dict: imports.ParsedImportDict = imports.collapse_imports(
import_dict
)
validate_imports(collapsed_import_dict)
import_dicts: list[ImportDict] = []
for lib, fields in collapsed_import_dict.items():
# prevent lib from being rendered on the page if all imports are non rendered kind
if not any(f.render for f in fields):
continue

lib_paths: dict[str, list[imports.ImportVar]] = {}

for field in fields:
lib_paths.setdefault(field.package_path, []).append(field)

compiled = {
path: compile_import_statement(fields) for path, fields in lib_paths.items()
}

for path, (default, rest) in compiled.items():
if not lib:
if default:
msg = "No default field allowed for empty library."
raise ValueError(msg)
if rest is None or len(rest) == 0:
msg = "No fields to import."
raise ValueError(msg)
import_dicts.extend(get_import_dict(module) for module in sorted(rest))
continue

# remove the version before rendering the package imports
formatted_lib = format.format_library_name(lib) + (
path if path != "/" else ""
)

import_dicts.append(get_import_dict(formatted_lib, default, rest))
return import_dicts


def get_import_dict(
lib: str, default: str = "", rest: list[str] | None = None
) -> ImportDict:
"""Get dictionary for import template.

Args:
lib: The importing react library.
default: The default module to import.
rest: The rest module to import.

Returns:
A dictionary for import template.
"""
return ImportDict(
lib=lib,
default=default,
rest=rest or [],
)
5 changes: 3 additions & 2 deletions packages/reflex-base/src/reflex_base/components/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Union

from reflex_base import constants
from reflex_base.compiler.utils import compile_imports
from reflex_base.utils import imports
from reflex_base.utils.exceptions import DynamicComponentMissingLibraryError
from reflex_base.utils.format import format_library_name
Expand Down Expand Up @@ -78,7 +79,7 @@ def make_component(component: Component) -> str:
# Causes a circular import, so we import here.
from reflex_components_core.base.bare import Bare

from reflex.compiler import compiler, templates, utils
from reflex.compiler import compiler, templates

component = Bare.create(Var.create(component))

Expand Down Expand Up @@ -116,7 +117,7 @@ def make_component(component: Component) -> str:
imports[lib] = names

module_code_lines = templates.dynamic_components_module_template(
imports=utils.compile_imports(imports),
imports=compile_imports(imports),
memoized_code="\n".join(rendered_components),
).splitlines()

Expand Down
13 changes: 12 additions & 1 deletion packages/reflex-base/src/reflex_base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from reflex_base import constants
from reflex_base.constants.base import LogLevel
from reflex_base.constants.vite import ViteConfigDict
from reflex_base.environment import EnvironmentVariables as EnvironmentVariables
from reflex_base.environment import EnvVar as EnvVar
from reflex_base.environment import (
Expand Down Expand Up @@ -166,7 +167,7 @@ class BaseConfig:
bun_path: The bun path.
static_page_generation_timeout: Timeout to do a production build of a frontend page.
cors_allowed_origins: Comma separated list of origins that are allowed to connect to the backend API.
vite_allowed_hosts: Allowed hosts for the Vite dev server. Set to True to allow all hosts, or provide a list of hostnames (e.g. ["myservice.local"]) to allow specific ones. Prevents 403 errors in Docker, Codespaces, reverse proxies, etc.
vite_config: A user-defined Vite config that will get deeply merged with Reflex's Vite config, allowing for customization and overriding of Reflex defaults.
react_strict_mode: Whether to use React strict mode.
frontend_packages: Additional frontend packages to install.
state_manager_mode: Indicate which type of state manager to use.
Expand Down Expand Up @@ -222,6 +223,8 @@ class BaseConfig:

vite_allowed_hosts: bool | list[str] = False

vite_config: ViteConfigDict | None = None

react_strict_mode: bool = True

frontend_packages: list[str] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -373,6 +376,14 @@ def _post_init(self, **kwargs):
removal_version="1.0",
)

if "vite_allowed_hosts" in kwargs and kwargs["vite_allowed_hosts"] is not False:
console.deprecate(
feature_name="vite_allowed_hosts",
reason="Use vite_config={'server': {'allowedHosts': ...}} instead.",
deprecation_version="0.9.3",
removal_version="1.0",
)

# Update default URLs if ports were set
kwargs.update(env_kwargs)
self._non_default_attributes = set(kwargs.keys())
Expand Down
Loading
Loading