Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/02_concepts/code/03_nested_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from apify_client import ApifyClientAsync
from apify_client._models_generated import ActorJobStatus
from apify_client._literals_generated import ActorJobStatus

TOKEN = 'MY-APIFY-TOKEN'

Expand Down
2 changes: 1 addition & 1 deletion docs/02_concepts/code/03_nested_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from apify_client import ApifyClient
from apify_client._models_generated import ActorJobStatus
from apify_client._literals_generated import ActorJobStatus

TOKEN = 'MY-APIFY-TOKEN'

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ indent-style = "space"
"**/__init__.py" = [
"F401", # Unused imports
]
"**/{_models,_models_generated}.py" = [
"TC001", # Pydantic needs the literal aliases importable at runtime to resolve forward references
]
"**/{scripts}/*" = [
"D", # Everything from the pydocstyle
"INP001", # File {filename} is part of an implicit namespace package, add an __init__.py
Expand Down
283 changes: 253 additions & 30 deletions scripts/postprocess_generated_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
- Fix discriminator field names that use camelCase instead of snake_case (known issue with
discriminators on schemas referenced from array items).
- Deduplicate the inlined `Type(StrEnum)` that comes from ErrorResponse.yaml; rewire to `ErrorType`.
- Rewrite every `class X(StrEnum)` as `X = Literal[...]` so downstream code can pass plain strings
(and reuse the named alias in resource-client signatures) instead of enum members.
- Convert camelCase string values in each literal alias to snake_case (Pythonic), and emit a
`_<NAME>_WIRE_VALUES` mapping the Python value back to the original camelCase form so the
resource clients can still produce the exact string the API expects on the wire.
- Move the resulting `X = Literal[...]` definitions into `_literals_generated.py`, leaving
`_models_generated.py` importing them — so consumers can depend on a dedicated literals module
without pulling in every Pydantic model.
- Add `@docs_group('Models')` to every model class (plus the required import).

Applied to `_typeddicts_generated.py`:
Expand All @@ -28,6 +36,7 @@
REPO_ROOT = Path(__file__).resolve().parent.parent
PACKAGE_DIR = REPO_ROOT / 'src' / 'apify_client'
MODELS_PATH = PACKAGE_DIR / '_models_generated.py'
LITERALS_PATH = PACKAGE_DIR / '_literals_generated.py'
TYPEDDICTS_PATH = PACKAGE_DIR / '_typeddicts_generated.py'

# Map of camelCase discriminator values to their snake_case equivalents.
Expand All @@ -51,6 +60,27 @@
)


def _collapse_blank_lines(content: str) -> str:
"""Collapse runs of 3+ blank lines down to exactly 3, leaving at most 2 blank lines between symbols."""
return re.sub(r'\n{3,}', '\n\n\n', content)


def _ensure_typing_import(content: str, name: str) -> str:
"""Append `name` to the `from typing import ...` line if not already imported.

Assumes the single-line import form datamodel-codegen emits; ruff re-wraps afterwards.
"""
typing_import = re.search(r'from typing import[^\n]+', content)
if typing_import is None or name in typing_import.group(0):
return content
return re.sub(
r'(from typing import )([^\n]+)',
lambda m: f'{m.group(1)}{m.group(2)}, {name}',
content,
count=1,
)


def fix_discriminators(content: str) -> str:
"""Replace camelCase discriminator values with their snake_case equivalents."""
for camel, snake in DISCRIMINATOR_FIXES.items():
Expand All @@ -73,8 +103,198 @@ def deduplicate_error_type_enum(content: str) -> str:
)
# Replace standalone `Type` references in annotation contexts (`: Type`, `| Type`, `[Type`).
content = re.sub(r'(?<=: )Type\b|(?<=\| )Type\b|(?<=\[)Type\b', 'ErrorType', content)
# Collapse triple+ blank lines left by the removal.
return re.sub(r'\n{3,}', '\n\n\n', content)
return _collapse_blank_lines(content)


def convert_enums_to_literals(content: str) -> str:
"""Rewrite every `class X(StrEnum): ...` into an `X = Literal[...]` alias.

Each member assignment (`NAME = 'value'`) contributes its string value to the literal in
declaration order. The class docstring, if present, is preserved as a trailing bare-string
docstring after the alias — matching the field-doc convention datamodel-codegen already uses
elsewhere in the generated file.

Runs before `add_docs_group_decorators`, so the enum classes have no `@docs_group` decorator
to strip. The `from enum import StrEnum` import is left alone and removed by ruff's F401 fix.
"""
tree = ast.parse(content)
lines = content.split('\n')
replacements: list[tuple[int, int, list[str]]] = []

for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue
base_names = {b.id for b in node.bases if isinstance(b, ast.Name)}
if 'StrEnum' not in base_names:
continue

values: list[str] = [
stmt.value.value
for stmt in node.body
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
]
docstring = ast.get_docstring(node)

new_lines: list[str] = [f'{node.name} = Literal[']
new_lines.extend(f' {v!r},' for v in values)
new_lines.append(']')
if docstring is not None:
if '\n' in docstring:
new_lines.append('"""')
new_lines.extend(docstring.splitlines())
new_lines.append('"""')
else:
new_lines.append(f'"""{docstring}"""')

assert node.end_lineno is not None # noqa: S101
replacements.append((node.lineno - 1, node.end_lineno, new_lines))

if not replacements:
return content

# Replace in reverse order so earlier slice indices stay valid after each splice.
for start, end, new in sorted(replacements, key=lambda r: r[0], reverse=True):
lines[start:end] = new

return _collapse_blank_lines('\n'.join(lines))


LITERALS_FILE_HEADER = """\
# generated by postprocess_generated_models

from __future__ import annotations

from typing import Literal


"""

_CAMEL_CASE_VALUE = re.compile(r"^'([a-z][a-zA-Z0-9]*[A-Z][a-zA-Z0-9]*)',?$")


def _camel_to_snake(value: str) -> str:
"""Convert a camelCase identifier to snake_case."""
return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', value).lower()


def _is_literal_alias(node: ast.stmt) -> bool:
"""Return True if `node` is a top-level `Name = Literal[...]` statement."""
return (
isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and isinstance(node.value, ast.Subscript)
and isinstance(node.value.value, ast.Name)
and node.value.value.id == 'Literal'
)


def snake_case_camelcase_literal_values(content: str) -> str:
"""Rewrite camelCase string values in `Literal[...]` aliases into snake_case.

Scans each `Name = Literal[...]` block and, for any value matching the camelCase pattern
(lowercase-first followed by an uppercase letter), converts it to snake_case. For each alias
that had at least one conversion, emits a `_<NAME>_WIRE_VALUES: dict[<NAME>, str] = ...`
mapping right after the alias so consumers can translate back to the API wire format.

SCREAMING_SNAKE_CASE, dotted, hyphenated, and HTTP-method values pass through unchanged.
"""
tree = ast.parse(content)
lines = content.split('\n')
insertions: list[tuple[int, list[str]]] = [] # (insert-after-line-exclusive, lines to insert)

for alias_name, node, end_line in _extract_top_level_symbols(tree):
if not _is_literal_alias(node):
continue

assert node.end_lineno is not None # noqa: S101
wire_mapping: dict[str, str] = {}
for line_idx in range(node.lineno - 1, node.end_lineno):
match = _CAMEL_CASE_VALUE.match(lines[line_idx].strip())
if match is None:
continue
original = match.group(1)
snake = _camel_to_snake(original)
wire_mapping[snake] = original
lines[line_idx] = lines[line_idx].replace(f"'{original}'", f"'{snake}'", 1)

if not wire_mapping:
continue

constant_name = '_' + _camel_to_snake(alias_name).upper() + '_WIRE_VALUES'
docstring = f'"""Maps snake_case `{alias_name}` values to the camelCase form expected on the API wire."""'
mapping_lines = [
'',
f'{constant_name}: dict[{alias_name}, str] = {{',
*(f" '{snake}': '{original}'," for snake, original in wire_mapping.items()),
'}',
docstring,
]
# Insert after the alias's trailing docstring (absorbed into end_line) so the docstring
# stays attached to the alias rather than to the mapping dict.
insertions.append((end_line, mapping_lines))

if not insertions:
return content

for insert_at, new_lines in sorted(insertions, key=lambda r: r[0], reverse=True):
lines[insert_at:insert_at] = new_lines

return '\n'.join(lines)


def split_literals_to_file(content: str) -> tuple[str, str]:
"""Move every top-level `Name = Literal[...]` block into a separate literals module.

Walks the top-level AST, collects each literal alias plus its trailing bare-string docstring,
deletes them from `_models_generated.py`, and rebuilds `_literals_generated.py` from the blocks
in original order. The models content gains a `from apify_client._literals_generated import ...`
line so Pydantic can still resolve the forward references in field annotations.

Returns `(new_models_content, literals_file_content)`. If no literal aliases are found, the
models content is returned unchanged and the literals content is empty.
"""
tree = ast.parse(content)
lines = content.split('\n')

blocks: list[tuple[int, int, str]] = [
(node.lineno - 1, end_line, name)
for name, node, end_line in _extract_top_level_symbols(tree)
if _is_literal_alias(node)
]

if not blocks:
return content, ''

literal_lines: list[str] = []
for start, end, _ in blocks:
literal_lines.extend(lines[start:end])
literal_lines.append('')
literal_lines.append('')

new_lines = lines[:]
for start, end, _ in sorted(blocks, key=lambda b: b[0], reverse=True):
del new_lines[start:end]

# Inject the import right after the last existing `from apify_client.` import so ruff/isort
# keep the final ordering stable.
names = sorted(name for _, _, name in blocks)
import_line = f'from apify_client._literals_generated import {", ".join(names)}'
insert_at = next(
(idx + 1 for idx in range(len(new_lines) - 1, -1, -1) if new_lines[idx].startswith('from apify_client.')),
None,
)
if insert_at is None:
raise RuntimeError('No `from apify_client.` import found in generated models to anchor literals import')
new_lines.insert(insert_at, import_line)

models_content = _collapse_blank_lines('\n'.join(new_lines))
literals_content = _collapse_blank_lines(LITERALS_FILE_HEADER + '\n'.join(literal_lines))
return models_content, literals_content


def add_docs_group_decorators(content: str, group_name: GroupName) -> str:
Expand Down Expand Up @@ -136,17 +356,7 @@ def flatten_empty_typeddicts(content: str) -> str:
if not replaced:
return content

output = re.sub(r'\n{3,}', '\n\n\n', '\n'.join(lines))
# Flattening introduces new `TypeAlias` uses; make sure it's imported from typing.
typing_import = re.search(r'from typing import[^\n]+', output)
if typing_import is not None and 'TypeAlias' not in typing_import.group(0):
output = re.sub(
r'(from typing import )([^\n]+)',
lambda m: f'{m.group(1)}{m.group(2)}, TypeAlias',
output,
count=1,
)
return output
return _ensure_typing_import(_collapse_blank_lines('\n'.join(lines)), 'TypeAlias')


def _is_string_expr(node: ast.stmt) -> bool:
Expand All @@ -159,7 +369,7 @@ def _extract_top_level_symbols(tree: ast.Module) -> list[tuple[str, ast.stmt, in

If a top-level string expression immediately follows a symbol, it is absorbed into that
symbol's `end_line` so they get pruned together (datamodel-codegen emits the schema description
for TypeAlias statements as a bare string right after the alias).
for type-alias statements as a bare string right after the alias).
"""
symbols: list[tuple[str, ast.stmt, int]] = []
body = tree.body
Expand All @@ -171,6 +381,8 @@ def _extract_top_level_symbols(tree: ast.Module) -> list[tuple[str, ast.stmt, in
name = node.name
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
name = node.target.id
elif isinstance(node, ast.Assign) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
name = node.targets[0].id

if name is not None:
assert node.end_lineno is not None # noqa: S101
Expand Down Expand Up @@ -237,10 +449,7 @@ def prune_typeddicts(content: str, seeds: frozenset[str]) -> tuple[str, set[str]
drop_line_indices.add(line_no)

pruned = [line for i, line in enumerate(lines) if i not in drop_line_indices]
output = '\n'.join(pruned)
# Collapse runs of blank lines left behind by deletions.
output = re.sub(r'\n{3,}', '\n\n\n', output)
return output, kept
return _collapse_blank_lines('\n'.join(pruned)), kept


def rename_with_dict_suffix(content: str, names: set[str]) -> str:
Expand All @@ -253,16 +462,30 @@ def rename_with_dict_suffix(content: str, names: set[str]) -> str:
return content


def postprocess_models(path: Path) -> bool:
"""Apply `_models_generated.py`-specific fixes. Returns True if the file changed."""
original = path.read_text()
def postprocess_models(models_path: Path, literals_path: Path) -> list[Path]:
"""Apply `_models_generated.py`-specific fixes and emit `_literals_generated.py`.

Returns the list of paths that were (re)written.
"""
original = models_path.read_text()
fixed = fix_discriminators(original)
fixed = deduplicate_error_type_enum(fixed)
fixed = convert_enums_to_literals(fixed)
fixed = add_docs_group_decorators(fixed, 'Models')
if fixed == original:
return False
path.write_text(fixed)
return True
models_content, literals_content = split_literals_to_file(fixed)
if literals_content:
literals_content = snake_case_camelcase_literal_values(literals_content)

changed: list[Path] = []
if models_content != original:
models_path.write_text(models_content)
changed.append(models_path)
if literals_content:
previous = literals_path.read_text() if literals_path.exists() else ''
if literals_content != previous:
literals_path.write_text(literals_content)
changed.append(literals_path)
return changed


def postprocess_typeddicts(path: Path) -> bool:
Expand All @@ -286,12 +509,12 @@ def run_ruff(paths: list[Path]) -> None:


def main() -> None:
changed: list[Path] = []
if postprocess_models(MODELS_PATH):
changed.append(MODELS_PATH)
print(f'Fixed generated models in {MODELS_PATH}')
changed = postprocess_models(MODELS_PATH, LITERALS_PATH)
if changed:
for path in changed:
print(f'Wrote {path}')
else:
print('No fixes needed for _models_generated.py')
print('No fixes needed for _models_generated.py / _literals_generated.py')

if postprocess_typeddicts(TYPEDDICTS_PATH):
changed.append(TYPEDDICTS_PATH)
Expand Down
2 changes: 1 addition & 1 deletion src/apify_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ImpitHttpClient,
ImpitHttpClientAsync,
)
from ._types import Timeout
from ._literals import Timeout

__version__ = metadata.version('apify-client')

Expand Down
Loading