diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2d5029..ec5b0d5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,3 +27,8 @@ repos: - id: flake8 additional_dependencies: - flake8-debugger == 4.1.2 + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v2.1.0 + hooks: + - id: mypy + pass_filenames: false diff --git a/CHANGES.rst b/CHANGES.rst index 537c5a8..c3df630 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,7 +4,8 @@ Changes 8.3 (unreleased) ---------------- -- Nothing changed yet. +- Add type annotations to the package code. + For clarification, restricted Python code does not support type annotations. 8.3a1.dev0 (2026-05-29) diff --git a/pyproject.toml b/pyproject.toml index f6b7394..502dd97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Security", + "Typing :: Typed", ] dynamic = ["readme"] requires-python = ">=3.10, <3.16" @@ -50,6 +51,10 @@ docs = [ "Sphinx", "furo", ] +typecheck = [ + "mypy", + "typeshed", +] [project.urls] Documentation = "https://restrictedpython.readthedocs.io/" @@ -83,3 +88,23 @@ directory = "parts/htmlcov" [tool.setuptools.dynamic] readme = {file = ["README.rst", "CHANGES.rst"]} +[tool.mypy] +mypy_path = "src" +packages = ["RestrictedPython"] +python_version = "3.10" +warn_unreachable = true +implicit_reexport = false +strict = true + +[[tool.mypy.overrides]] +module = ["DateTime"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["RestrictedPython.Guards"] +check_untyped_defs = false +disallow_untyped_defs = false + +[[tool.mypy.overrides]] +module = ["RestrictedPython.transformer"] +warn_no_return = false diff --git a/src/RestrictedPython/Eval.py b/src/RestrictedPython/Eval.py index 408b25a..a917964 100644 --- a/src/RestrictedPython/Eval.py +++ b/src/RestrictedPython/Eval.py @@ -13,8 +13,12 @@ """Restricted Python Expressions.""" import ast +import collections +import types +import typing -from .compile import compile_restricted_eval +from RestrictedPython._types import _cast_not_none +from RestrictedPython.compile import compile_restricted_eval nltosp = str.maketrans('\r\n', ' ') @@ -22,13 +26,21 @@ # No restrictions. default_guarded_getattr = getattr +_T = typing.TypeVar('_T') +_TK = typing.TypeVar('_TK', contravariant=True) +_TV = typing.TypeVar('_TV', covariant=True) -def default_guarded_getitem(ob, index): + +class _GetItem(typing.Protocol[_TK, _TV]): + def __getitem__(self, key: _TK) -> _TV: ... + + +def default_guarded_getitem(ob: _GetItem[_TK, _TV], index: _TK) -> _TV: # No restrictions. return ob[index] -def default_guarded_getiter(ob): +def default_guarded_getiter(ob: _T) -> _T: # No restrictions. return ob @@ -36,17 +48,18 @@ def default_guarded_getiter(ob): class RestrictionCapableEval: """A base class for restricted code.""" - globals = {'__builtins__': None} + globals: dict[str, typing.Any] = {'__builtins__': None} + # restricted - rcode = None + rcode: types.CodeType | None = None # unrestricted - ucode = None + ucode: types.CodeType | None = None # Names used by the expression - used = None + used: tuple[str, ...] | None = None - def __init__(self, expr): + def __init__(self, expr: str): """Create a restricted expression where: @@ -60,7 +73,7 @@ def __init__(self, expr): # Catch syntax errors. self.prepUnrestrictedCode() - def prepRestrictedCode(self): + def prepRestrictedCode(self) -> None: if self.rcode is None: result = compile_restricted_eval(self.expr, '') if result.errors: @@ -68,13 +81,12 @@ def prepRestrictedCode(self): self.used = tuple(result.used_names) self.rcode = result.code - def prepUnrestrictedCode(self): + def prepUnrestrictedCode(self) -> None: if self.ucode is None: - exp_node = compile( + exp_node = ast.parse( self.expr, '', - 'eval', - ast.PyCF_ONLY_AST) + 'eval') co = compile(exp_node, '', 'eval') @@ -90,7 +102,9 @@ def prepUnrestrictedCode(self): self.ucode = co - def eval(self, mapping): + def eval(self, + mapping: collections.abc.Mapping[str, + typing.Any]) -> typing.Any: # This default implementation is probably not very useful. :-( # This is meant to be overridden. self.prepRestrictedCode() @@ -103,11 +117,11 @@ def eval(self, mapping): global_scope.update(self.globals) - for name in self.used: + for name in _cast_not_none(self.used): if (name not in global_scope) and (name in mapping): global_scope[name] = mapping[name] - return eval(self.rcode, global_scope) + return eval(_cast_not_none(self.rcode), global_scope) - def __call__(self, **kw): + def __call__(self, **kw: typing.Any) -> typing.Any: return self.eval(kw) diff --git a/src/RestrictedPython/Guards.py b/src/RestrictedPython/Guards.py index d7c1b9c..eb9cc0f 100644 --- a/src/RestrictedPython/Guards.py +++ b/src/RestrictedPython/Guards.py @@ -220,7 +220,7 @@ def guard(ob): return guard -full_write_guard = _full_write_guard() +full_write_guard = _full_write_guard() # type: ignore[no-untyped-call] def guarded_setattr(object, name, value): diff --git a/src/RestrictedPython/Limits.py b/src/RestrictedPython/Limits.py index e133ec7..4c933ad 100644 --- a/src/RestrictedPython/Limits.py +++ b/src/RestrictedPython/Limits.py @@ -10,11 +10,28 @@ # FOR A PARTICULAR PURPOSE # ############################################################################## +import collections.abc +import typing -limited_builtins = {} +limited_builtins: dict[str, typing.Any] = {} -def limited_range(iFirst, *args): + +@typing.overload +def limited_range(iFirst: int) -> collections.abc.Sequence[int]: ... + + +@typing.overload +def limited_range(iStart: int, iEnd: int, / + ) -> collections.abc.Sequence[int]: ... + + +@typing.overload +def limited_range(iStart: int, iEnd: int, iStep: int, / + ) -> collections.abc.Sequence[int]: ... + + +def limited_range(iFirst: int, *args: int) -> collections.abc.Sequence[int]: # limited range function from Martijn Pieters RANGELIMIT = 1000 if not len(args): @@ -41,8 +58,10 @@ def limited_range(iFirst, *args): limited_builtins['range'] = limited_range +_T = typing.TypeVar('_T') + -def limited_list(seq): +def limited_list(seq: collections.abc.Iterable[_T]) -> list[_T]: if isinstance(seq, str): raise TypeError('cannot convert string to list') return list(seq) @@ -51,7 +70,7 @@ def limited_list(seq): limited_builtins['list'] = limited_list -def limited_tuple(seq): +def limited_tuple(seq: collections.abc.Iterable[_T]) -> tuple[_T, ...]: if isinstance(seq, str): raise TypeError('cannot convert string to tuple') return tuple(seq) diff --git a/src/RestrictedPython/PrintCollector.py b/src/RestrictedPython/PrintCollector.py index d28a7ab..0528e38 100644 --- a/src/RestrictedPython/PrintCollector.py +++ b/src/RestrictedPython/PrintCollector.py @@ -15,17 +15,17 @@ class PrintCollector: """Collect written text, and return it when called.""" - def __init__(self, _getattr_=None): + def __init__(self, _getattr_=None): # type: ignore[no-untyped-def] self.txt = [] self._getattr_ = _getattr_ - def write(self, text): + def write(self, text: str) -> None: self.txt.append(text) - def __call__(self): + def __call__(self) -> str: return ''.join(self.txt) - def _call_print(self, *objects, **kwargs): + def _call_print(self, *objects, **kwargs): # type: ignore[no-untyped-def] if kwargs.get('file', None) is None: kwargs['file'] = self else: diff --git a/src/RestrictedPython/Utilities.py b/src/RestrictedPython/Utilities.py index 26d73d1..6a26959 100644 --- a/src/RestrictedPython/Utilities.py +++ b/src/RestrictedPython/Utilities.py @@ -11,21 +11,24 @@ # ############################################################################## +import collections.abc import math import random import string +import types +import typing -utility_builtins = {} +utility_builtins: dict[str, typing.Any] = {} class _AttributeDelegator: - def __init__(self, mod, *excludes): + def __init__(self, mod: types.ModuleType, *excludes: str): """delegate attribute lookups outside *excludes* to module *mod*.""" self.__mod = mod self.__excludes = excludes - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> typing.Any: if attr in self.__excludes: raise NotImplementedError( f"{self.__mod.__name__}.{attr} is not safe") @@ -50,7 +53,7 @@ def __getattr__(self, attr): pass -def same_type(arg1, *args): +def same_type(arg1: object, *args: object) -> bool: """Compares the class or type of two or more objects.""" t = getattr(arg1, '__class__', type(arg1)) for arg in args: @@ -61,8 +64,10 @@ def same_type(arg1, *args): utility_builtins['same_type'] = same_type +_T = typing.TypeVar('_T') -def test(*args): + +def test(*args: _T) -> _T | None: length = len(args) for i in range(1, length, 2): if args[i - 1]: @@ -70,12 +75,22 @@ def test(*args): if length % 2: return args[-1] + return None utility_builtins['test'] = test +_TK = typing.TypeVar('_TK') +_TV = typing.TypeVar('_TV') +_T_in: typing.TypeAlias = collections.abc.Iterable[_TK | tuple[_TK, _TV]] +_T_out: typing.TypeAlias = list[tuple[_TK, _TK | _TV]] + -def reorder(s, with_=None, without=()): +def reorder( + s: _T_in[_TK, _TV], + with_: collections.abc.Iterable[typing.Any] | None = None, + without: collections.abc.Iterable[typing.Any] = () +) -> _T_out[_TK, _TV]: # s, with_, and without are sequences treated as sets. # The result is subtract(intersect(s, with_), without), # unless with_ is None, in which case it is subtract(s, without). diff --git a/src/RestrictedPython/_types.py b/src/RestrictedPython/_types.py new file mode 100644 index 0000000..9cc3c0f --- /dev/null +++ b/src/RestrictedPython/_types.py @@ -0,0 +1,8 @@ +import typing + + +_T = typing.TypeVar('_T') + + +def _cast_not_none(var: _T | None) -> _T: + return var # type: ignore[return-value] diff --git a/src/RestrictedPython/compile.py b/src/RestrictedPython/compile.py index 3253b8c..8b2d104 100644 --- a/src/RestrictedPython/compile.py +++ b/src/RestrictedPython/compile.py @@ -1,55 +1,77 @@ +from __future__ import annotations + import ast +import collections.abc +import os +import types +import typing import warnings -from collections import namedtuple from RestrictedPython._compat import IS_CPYTHON +from RestrictedPython._types import _cast_not_none from RestrictedPython.transformer import RestrictingNodeTransformer -CompileResult = namedtuple( - 'CompileResult', 'code, errors, warnings, used_names') +# Temporary workaround for missing _typeshed +ReadableBuffer: typing.TypeAlias = bytes | bytearray + + +class CompileResult(typing.NamedTuple): + code: types.CodeType | None + errors: collections.abc.Sequence[str] + warnings: collections.abc.Sequence[str] + used_names: collections.abc.Mapping[str, bool] + + syntax_error_template = ( - 'Line {lineno}: {type}: {msg} at statement: {statement!r}') + 'Line {lineno}: {type}: {msg} at statement: {statement!r}' +) NOT_CPYTHON_WARNING = ( 'RestrictedPython is only supported on CPython: use on other Python ' 'implementations may create security issues.' ) +_T_ast_compilable: typing.TypeAlias = ( + ast.Module | ast.Expression | ast.Interactive) +_T_source: typing.TypeAlias = str | ReadableBuffer | _T_ast_compilable + def _compile_restricted_mode( - source, - filename='', - mode="exec", - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + source: _T_source, + filename: str | bytes | os.PathLike[typing.Any] = '', + mode: typing.Literal["exec", "eval", "single"] = "exec", + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> CompileResult: if not IS_CPYTHON: warnings.warn_explicit( NOT_CPYTHON_WARNING, RuntimeWarning, 'RestrictedPython', 0) byte_code = None - collected_errors = [] - collected_warnings = [] - used_names = {} + collected_errors: list[str] = [] + collected_warnings: list[str] = [] + used_names: dict[str, bool] = {} if policy is None: # Unrestricted Source Checks byte_code = compile(source, filename, mode=mode, flags=flags, dont_inherit=dont_inherit) elif issubclass(policy, RestrictingNodeTransformer): - c_ast = None allowed_source_types = [str, ast.Module] if not issubclass(type(source), tuple(allowed_source_types)): raise TypeError('Not allowed source type: ' '"{0.__class__.__name__}".'.format(source)) - c_ast = None + c_ast: _T_ast_compilable | None = None # workaround for pypy issue https://bitbucket.org/pypy/pypy/issues/2552 if isinstance(source, ast.Module): c_ast = source else: try: - c_ast = ast.parse(source, filename, mode) + c_ast = typing.cast( + _T_ast_compilable, ast.parse( + source, filename, mode)) except (TypeError, ValueError) as e: collected_errors.append(str(e)) except SyntaxError as v: @@ -78,11 +100,12 @@ def _compile_restricted_mode( def compile_restricted_exec( - source, - filename='', - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + source: _T_source, + filename: str | bytes | os.PathLike[typing.Any] = '', + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> CompileResult: """Compile restricted for the mode `exec`.""" return _compile_restricted_mode( source, @@ -94,11 +117,12 @@ def compile_restricted_exec( def compile_restricted_eval( - source, - filename='', - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + source: _T_source, + filename: str | bytes | os.PathLike[typing.Any] = '', + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> CompileResult: """Compile restricted for the mode `eval`.""" return _compile_restricted_mode( source, @@ -110,11 +134,12 @@ def compile_restricted_eval( def compile_restricted_single( - source, - filename='', - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + source: _T_source, + filename: str | bytes | os.PathLike[typing.Any] = '', + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> CompileResult: """Compile restricted for the mode `single`.""" return _compile_restricted_mode( source, @@ -126,14 +151,16 @@ def compile_restricted_single( def compile_restricted_function( - p, # parameters - body, - name, - filename='', - globalize=None, # List of globals (e.g. ['here', 'context', ...]) - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + p: str, # parameters + body: str | ReadableBuffer | ast.Module | ast.Interactive, + name: str, + filename: str | bytes | os.PathLike[typing.Any] = '', + # List of globals (e.g. ['here', 'context', ...]) + globalize: list[str] | None = None, + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> CompileResult: """Compile a restricted code object for a function. Documentation see: @@ -149,7 +176,7 @@ def compile_restricted_function( msg=v.msg, statement=v.text.strip() if v.text else None) return CompileResult( - code=None, errors=(error,), warnings=(), used_names=()) + code=None, errors=(error,), warnings=(), used_names={}) # The compiled code is actually executed inside a function # (that is called when the code is called) so reading and assigning to a @@ -166,7 +193,7 @@ def compile_restricted_function( assert isinstance(function_ast, ast.FunctionDef) function_ast.name = name - wrapper_ast.body[0].body = body_ast.body + function_ast.body = body_ast.body wrapper_ast = ast.fix_missing_locations(wrapper_ast) result = _compile_restricted_mode( @@ -181,12 +208,13 @@ def compile_restricted_function( def compile_restricted( - source, - filename='', - mode='exec', - flags=0, - dont_inherit=False, - policy=RestrictingNodeTransformer): + source: _T_source, + filename: str | bytes | os.PathLike[typing.Any] = '', + mode: str = 'exec', + flags: int = 0, + dont_inherit: bool = False, + policy: type[ast.NodeTransformer] | None = RestrictingNodeTransformer, +) -> types.CodeType: """Replacement for the built-in compile() function. policy ... `ast.NodeTransformer` class defining the restrictions. @@ -196,7 +224,8 @@ def compile_restricted( result = _compile_restricted_mode( source, filename=filename, - mode=mode, + mode=mode, # type: ignore[arg-type] + # https://github.com/zopefoundation/RestrictedPython/issues/318 flags=flags, dont_inherit=dont_inherit, policy=policy) @@ -209,4 +238,4 @@ def compile_restricted( ) if result.errors: raise SyntaxError(result.errors) - return result.code + return _cast_not_none(result.code) diff --git a/src/RestrictedPython/py.typed b/src/RestrictedPython/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/RestrictedPython/transformer.py b/src/RestrictedPython/transformer.py index b1ec1f6..f02f6ea 100644 --- a/src/RestrictedPython/transformer.py +++ b/src/RestrictedPython/transformer.py @@ -19,8 +19,11 @@ import ast +import collections import contextlib +import sys import textwrap +import typing # For AugAssign the operator must be converted to a string. @@ -111,11 +114,20 @@ "cr_origin", ]) +_T_visit_return: typing.TypeAlias = ast.AST | typing.Iterable[ast.AST] | None +_T_pos_ast: typing.TypeAlias = ( + ast.stmt | ast.expr | ast.excepthandler | ast.arg | ast.keyword | ast.alias + | ast.pattern) +if sys.version_info >= (3, 12): + _T_pos_ast: typing.TypeAlias = _T_pos_ast | ast.type_param +_T = typing.TypeVar('_T', bound=ast.AST) # When new ast nodes are generated they have no 'lineno', 'end_lineno', # 'col_offset' and 'end_col_offset'. This function copies these fields from the # incoming node: -def copy_locations(new_node, old_node): + + +def copy_locations(new_node: _T_pos_ast, old_node: _T_pos_ast) -> None: assert 'lineno' in new_node._attributes new_node.lineno = old_node.lineno @@ -132,12 +144,12 @@ def copy_locations(new_node, old_node): class PrintInfo: - def __init__(self): + def __init__(self) -> None: self.print_used = False self.printed_used = False @contextlib.contextmanager - def new_print_scope(self): + def new_print_scope(self) -> collections.abc.Iterator[None]: old_print_used = self.print_used old_printed_used = self.printed_used @@ -152,8 +164,14 @@ def new_print_scope(self): class RestrictingNodeTransformer(ast.NodeTransformer): - - def __init__(self, errors=None, warnings=None, used_names=None): + errors: list[str] + warnings: list[str] + used_names: dict[str, bool] + + def __init__(self, + errors: list[str] | None = None, + warnings: list[str] | None = None, + used_names: dict[str, bool] | None = None): super().__init__() self.errors = [] if errors is None else errors self.warnings = [] if warnings is None else warnings @@ -170,26 +188,26 @@ def __init__(self, errors=None, warnings=None, used_names=None): self.print_info = PrintInfo() - def gen_tmp_name(self): + def gen_tmp_name(self) -> str: # 'check_name' ensures that no variable is prefixed with '_'. # => Its safe to use '_tmp..' as a temporary variable. name = '_tmp%i' % self._tmp_idx self._tmp_idx += 1 return name - def error(self, node, info): + def error(self, node: ast.AST, info: str) -> None: """Record a security error discovered during transformation.""" lineno = getattr(node, 'lineno', None) self.errors.append( f'Line {lineno}: {info}') - def warn(self, node, info): - """Record a security error discovered during transformation.""" + def warn(self, node: ast.AST, info: str) -> None: + """Record a security warning discovered during transformation.""" lineno = getattr(node, 'lineno', None) self.warnings.append( f'Line {lineno}: {info}') - def guard_iter(self, node): + def guard_iter(self, node: ast.For | ast.comprehension) -> _T_visit_return: """ Converts: for x in expr @@ -220,10 +238,12 @@ def guard_iter(self, node): node.iter = new_iter return node - def is_starred(self, ob): + def is_starred(self, ob: ast.AST) -> typing.TypeGuard[ast.Starred]: + # TODO: Change Type Annotation to typing.TypeIs[ast.Starred] when + # Support for Python 3.12 is dropped. return isinstance(ob, ast.Starred) - def gen_unpack_spec(self, tpl): + def gen_unpack_spec(self, tpl: ast.Tuple) -> ast.Dict: """Generate a specification for 'guarded_unpack_sequence'. This spec is used to protect sequence unpacking. @@ -271,7 +291,8 @@ def gen_unpack_spec(self, tpl): spec = ast.Dict(keys=[], values=[]) spec.keys.append(ast.Constant('childs')) - spec.values.append(ast.Tuple([], ast.Load())) + val0 = ast.Tuple([], ast.Load()) + spec.values.append(val0) # starred elements in a sequence do not contribute into the min_len. # For example a, b, *c = g @@ -292,21 +313,26 @@ def gen_unpack_spec(self, tpl): el = ast.Tuple([], ast.Load()) el.elts.append(ast.Constant(idx - offset)) el.elts.append(self.gen_unpack_spec(val)) - spec.values[0].elts.append(el) + val0.elts.append(el) spec.keys.append(ast.Constant('min_len')) spec.values.append(ast.Constant(min_len)) return spec - def protect_unpack_sequence(self, target, value): + def protect_unpack_sequence( + self, + target: ast.Tuple, + value: ast.expr) -> ast.Call: spec = self.gen_unpack_spec(target) return ast.Call( func=ast.Name('_unpack_sequence_', ast.Load()), args=[value, spec, ast.Name('_getiter_', ast.Load())], keywords=[]) - def gen_unpack_wrapper(self, node, target): + def gen_unpack_wrapper(self, + node: ast.stmt, + target: ast.Tuple) -> tuple[ast.Name, ast.Try]: """Helper function to protect tuple unpacks. node: used to copy the locations for the new nodes. @@ -342,8 +368,9 @@ def gen_unpack_wrapper(self, node, target): # arg = converter # finally: # del tmp_arg - try_body = [ast.Assign(targets=[target], value=converter)] - finalbody = [self.gen_del_stmt(tmp_name)] + try_body: list[ast.stmt] = [ast.Assign( + targets=[target], value=converter)] + finalbody: list[ast.stmt] = [self.gen_del_stmt(tmp_name)] cleanup = ast.Try( body=try_body, finalbody=finalbody, handlers=[], orelse=[]) @@ -355,13 +382,17 @@ def gen_unpack_wrapper(self, node, target): return (tmp_target, cleanup) - def gen_none_node(self): + def gen_none_node(self) -> ast.Constant: return ast.Constant(None) - def gen_del_stmt(self, name_to_del): + def gen_del_stmt(self, name_to_del: str) -> ast.Delete: return ast.Delete(targets=[ast.Name(name_to_del, ast.Del())]) - def check_name(self, node, name, allow_magic_methods=False): + def check_name( + self, + node: _T_pos_ast, + name: str | None, + allow_magic_methods: bool = False) -> None: """Check names if they are allowed. If ``allow_magic_methods is True`` names in `ALLOWED_FUNC_NAMES` @@ -386,7 +417,9 @@ def check_name(self, node, name, allow_magic_methods=False): elif name in FORBIDDEN_FUNC_NAMES: self.error(node, f'"{name}" is a reserved name.') - def check_function_argument_names(self, node): + def check_function_argument_names( + self, + node: ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda) -> None: for arg in node.args.args: self.check_name(node, arg.arg) @@ -399,7 +432,7 @@ def check_function_argument_names(self, node): for arg in node.args.kwonlyargs: self.check_name(node, arg.arg) - def check_import_names(self, node): + def check_import_names(self, node: ast.ImportFrom | ast.Import) -> ast.AST: """Check the names being imported. This is a protection against rebinding dunder names like @@ -416,7 +449,10 @@ def check_import_names(self, node): return self.node_contents_visit(node) - def inject_print_collector(self, node, position=0): + def inject_print_collector( + self, + node: ast.Module | ast.FunctionDef, + position: int = 0) -> None: print_used = self.print_info.print_used printed_used = self.print_info.printed_used @@ -449,7 +485,8 @@ def inject_print_collector(self, node, position=0): # Special Functions for an ast.NodeTransformer - def generic_visit(self, node): + def generic_visit(self, # type: ignore[override] + node: ast.AST) -> _T_visit_return: """Reject ast nodes which do not have a corresponding `visit_` method. This is needed to prevent new ast nodes from new Python versions to be @@ -464,18 +501,18 @@ def generic_visit(self, node): ) self.not_allowed(node) - def not_allowed(self, node): + def not_allowed(self, node: ast.AST) -> None: self.error( node, f'{node.__class__.__name__} statements are not allowed.') - def node_contents_visit(self, node): + def node_contents_visit(self, node: _T) -> _T: """Visit the contents of a node.""" - return super().generic_visit(node) + return super().generic_visit(node) # type: ignore[return-value] # ast for Literals - def visit_Constant(self, node): + def visit_Constant(self, node: ast.Constant) -> _T_visit_return: """Allow constant literals. Constant replaces Num, Str, Bytes, NameConstant and Ellipsis in @@ -484,41 +521,46 @@ def visit_Constant(self, node): """ return self.node_contents_visit(node) - def visit_Interactive(self, node): + def visit_Interactive(self, node: ast.Interactive) -> _T_visit_return: """Allow single mode without restrictions.""" return self.node_contents_visit(node) - def visit_List(self, node): + def visit_List(self, node: ast.List) -> _T_visit_return: """Allow list literals without restrictions.""" return self.node_contents_visit(node) - def visit_Tuple(self, node): + def visit_Tuple(self, node: ast.Tuple) -> _T_visit_return: """Allow tuple literals without restrictions.""" return self.node_contents_visit(node) - def visit_Set(self, node): + def visit_Set(self, node: ast.Set) -> _T_visit_return: """Allow set literals without restrictions.""" return self.node_contents_visit(node) - def visit_Dict(self, node): + def visit_Dict(self, node: ast.Dict) -> _T_visit_return: """Allow dict literals without restrictions.""" return self.node_contents_visit(node) - def visit_FormattedValue(self, node): + def visit_FormattedValue( + self, + node: ast.FormattedValue) -> _T_visit_return: """Allow f-strings without restrictions.""" return self.node_contents_visit(node) - def visit_TemplateStr(self, node): + def visit_TemplateStr(self, node: ast.AST) -> _T_visit_return: """Template strings are allowed by default. As Template strings are a very basic template mechanism, that needs additional rendering logic to be useful, they are not blocked by default. Those rendering logic would be affected by RestrictedPython as well. + + TODO: Change Type Annotation to ast.TemplateStr when + Support for Python 3.13 is dropped. """ return self.node_contents_visit(node) - def visit_Interpolation(self, node): + def visit_Interpolation(self, node: ast.AST) -> _T_visit_return: """Interpolations are allowed by default. As Interpolations are part of Template Strings, they are needed @@ -526,16 +568,19 @@ def visit_Interpolation(self, node): are allowed. As a user has to provide additional rendering logic to make use of Template Strings, the security implications of Interpolations are limited in the context of RestrictedPython. + + TODO: Change Type Annotation to ast.Interpolation when + Support for Python 3.13 is dropped. """ return self.node_contents_visit(node) - def visit_JoinedStr(self, node): + def visit_JoinedStr(self, node: ast.JoinedStr) -> _T_visit_return: """Allow joined string without restrictions.""" return self.node_contents_visit(node) # ast for Variables - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> _T_visit_return: """Prevents access to protected names. Converts use of the name 'printed' to this expression: '_print()' @@ -544,6 +589,7 @@ def visit_Name(self, node): node = self.node_contents_visit(node) if isinstance(node.ctx, ast.Load): + new_node: _T_pos_ast if node.id == 'printed': self.print_info.printed_used = True new_node = ast.Call( @@ -569,25 +615,25 @@ def visit_Name(self, node): self.check_name(node, node.id) return node - def visit_Load(self, node): + def visit_Load(self, node: ast.Load) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_Store(self, node): + def visit_Store(self, node: ast.Store) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_Del(self, node): + def visit_Del(self, node: ast.Del) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_Starred(self, node): + def visit_Starred(self, node: ast.Starred) -> _T_visit_return: """ """ @@ -595,18 +641,18 @@ def visit_Starred(self, node): # Expressions - def visit_Expression(self, node): + def visit_Expression(self, node: ast.Expression) -> _T_visit_return: """Allow Expression statements without restrictions. They are in the AST when using the `eval` compile mode. """ return self.node_contents_visit(node) - def visit_Expr(self, node): + def visit_Expr(self, node: ast.Expr) -> _T_visit_return: """Allow Expr statements (any expression) without restrictions.""" return self.node_contents_visit(node) - def visit_UnaryOp(self, node): + def visit_UnaryOp(self, node: ast.UnaryOp) -> _T_visit_return: """ UnaryOp (Unary Operations) is the overall element for: * Not --> which should be allowed @@ -615,135 +661,135 @@ def visit_UnaryOp(self, node): """ return self.node_contents_visit(node) - def visit_UAdd(self, node): + def visit_UAdd(self, node: ast.UAdd) -> _T_visit_return: """Allow positive notation of variables. (e.g. +var)""" return self.node_contents_visit(node) - def visit_USub(self, node): + def visit_USub(self, node: ast.USub) -> _T_visit_return: """Allow negative notation of variables. (e.g. -var)""" return self.node_contents_visit(node) - def visit_Not(self, node): + def visit_Not(self, node: ast.Not) -> _T_visit_return: """Allow the `not` operator.""" return self.node_contents_visit(node) - def visit_Invert(self, node): + def visit_Invert(self, node: ast.Invert) -> _T_visit_return: """Allow `~` expressions.""" return self.node_contents_visit(node) - def visit_BinOp(self, node): + def visit_BinOp(self, node: ast.BinOp) -> _T_visit_return: """Allow binary operations.""" return self.node_contents_visit(node) - def visit_Add(self, node): + def visit_Add(self, node: ast.Add) -> _T_visit_return: """Allow `+` expressions.""" return self.node_contents_visit(node) - def visit_Sub(self, node): + def visit_Sub(self, node: ast.Sub) -> _T_visit_return: """Allow `-` expressions.""" return self.node_contents_visit(node) - def visit_Mult(self, node): + def visit_Mult(self, node: ast.Mult) -> _T_visit_return: """Allow `*` expressions.""" return self.node_contents_visit(node) - def visit_Div(self, node): + def visit_Div(self, node: ast.Div) -> _T_visit_return: """Allow `/` expressions.""" return self.node_contents_visit(node) - def visit_FloorDiv(self, node): + def visit_FloorDiv(self, node: ast.FloorDiv) -> _T_visit_return: """Allow `//` expressions.""" return self.node_contents_visit(node) - def visit_Mod(self, node): + def visit_Mod(self, node: ast.Mod) -> _T_visit_return: """Allow `%` expressions.""" return self.node_contents_visit(node) - def visit_Pow(self, node): + def visit_Pow(self, node: ast.Pow) -> _T_visit_return: """Allow `**` expressions.""" return self.node_contents_visit(node) - def visit_LShift(self, node): + def visit_LShift(self, node: ast.LShift) -> _T_visit_return: """Allow `<<` expressions.""" return self.node_contents_visit(node) - def visit_RShift(self, node): + def visit_RShift(self, node: ast.RShift) -> _T_visit_return: """Allow `>>` expressions.""" return self.node_contents_visit(node) - def visit_BitOr(self, node): + def visit_BitOr(self, node: ast.BitOr) -> _T_visit_return: """Allow `|` expressions.""" return self.node_contents_visit(node) - def visit_BitXor(self, node): + def visit_BitXor(self, node: ast.BitXor) -> _T_visit_return: """Allow `^` expressions.""" return self.node_contents_visit(node) - def visit_BitAnd(self, node): + def visit_BitAnd(self, node: ast.BitAnd) -> _T_visit_return: """Allow `&` expressions.""" return self.node_contents_visit(node) - def visit_MatMult(self, node): + def visit_MatMult(self, node: ast.MatMult) -> _T_visit_return: """Allow multiplication (`@`).""" return self.node_contents_visit(node) - def visit_BoolOp(self, node): + def visit_BoolOp(self, node: ast.BoolOp) -> _T_visit_return: """Allow bool operator without restrictions.""" return self.node_contents_visit(node) - def visit_And(self, node): + def visit_And(self, node: ast.And) -> _T_visit_return: """Allow bool operator `and` without restrictions.""" return self.node_contents_visit(node) - def visit_Or(self, node): + def visit_Or(self, node: ast.Or) -> _T_visit_return: """Allow bool operator `or` without restrictions.""" return self.node_contents_visit(node) - def visit_Compare(self, node): + def visit_Compare(self, node: ast.Compare) -> _T_visit_return: """Allow comparison expressions without restrictions.""" return self.node_contents_visit(node) - def visit_Eq(self, node): + def visit_Eq(self, node: ast.Eq) -> _T_visit_return: """Allow == expressions.""" return self.node_contents_visit(node) - def visit_NotEq(self, node): + def visit_NotEq(self, node: ast.NotEq) -> _T_visit_return: """Allow != expressions.""" return self.node_contents_visit(node) - def visit_Lt(self, node): + def visit_Lt(self, node: ast.Lt) -> _T_visit_return: """Allow < expressions.""" return self.node_contents_visit(node) - def visit_LtE(self, node): + def visit_LtE(self, node: ast.LtE) -> _T_visit_return: """Allow <= expressions.""" return self.node_contents_visit(node) - def visit_Gt(self, node): + def visit_Gt(self, node: ast.Gt) -> _T_visit_return: """Allow > expressions.""" return self.node_contents_visit(node) - def visit_GtE(self, node): + def visit_GtE(self, node: ast.GtE) -> _T_visit_return: """Allow >= expressions.""" return self.node_contents_visit(node) - def visit_Is(self, node): + def visit_Is(self, node: ast.Is) -> _T_visit_return: """Allow `is` expressions.""" return self.node_contents_visit(node) - def visit_IsNot(self, node): + def visit_IsNot(self, node: ast.IsNot) -> _T_visit_return: """Allow `is not` expressions.""" return self.node_contents_visit(node) - def visit_In(self, node): + def visit_In(self, node: ast.In) -> _T_visit_return: """Allow `in` expressions.""" return self.node_contents_visit(node) - def visit_NotIn(self, node): + def visit_NotIn(self, node: ast.NotIn) -> _T_visit_return: """Allow `not in` expressions.""" return self.node_contents_visit(node) - def visit_Call(self, node): + def visit_Call(self, node: ast.Call) -> _T_visit_return: """Checks calls with '*args' and '**kwargs'. Note: The following happens only if '*args' or '**kwargs' is used. @@ -785,17 +831,17 @@ def visit_Call(self, node): copy_locations(node.func, node.args[0]) return node - def visit_keyword(self, node): + def visit_keyword(self, node: ast.keyword) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_IfExp(self, node): + def visit_IfExp(self, node: ast.IfExp) -> _T_visit_return: """Allow `if` expressions without restrictions.""" return self.node_contents_visit(node) - def visit_Attribute(self, node): + def visit_Attribute(self, node: ast.Attribute) -> _T_visit_return: """Checks and mutates attribute access/assignment. 'a.b' becomes '_getattr_(a, "b")' @@ -851,7 +897,7 @@ def visit_Attribute(self, node): # Subscripting - def visit_Subscript(self, node): + def visit_Subscript(self, node: ast.Subscript) -> _T_visit_return: """Transforms all kinds of subscripts. 'foo[bar]' becomes '_getitem_(foo, bar)' @@ -896,7 +942,7 @@ def visit_Subscript(self, node): raise NotImplementedError( f"Unknown ctx type: {type(node.ctx)}") - def visit_Slice(self, node): + def visit_Slice(self, node: ast.Slice) -> _T_visit_return: """ """ @@ -904,31 +950,31 @@ def visit_Slice(self, node): # Comprehensions - def visit_ListComp(self, node): + def visit_ListComp(self, node: ast.ListComp) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_SetComp(self, node): + def visit_SetComp(self, node: ast.SetComp) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_GeneratorExp(self, node): + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_DictComp(self, node): + def visit_DictComp(self, node: ast.DictComp) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_comprehension(self, node): + def visit_comprehension(self, node: ast.comprehension) -> _T_visit_return: """ """ @@ -936,7 +982,7 @@ def visit_comprehension(self, node): # Statements - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign) -> _T_visit_return: """ """ @@ -985,7 +1031,7 @@ def visit_Assign(self, node): return new_nodes - def visit_AugAssign(self, node): + def visit_AugAssign(self, node: ast.AugAssign) -> _T_visit_return: """Forbid certain kinds of AugAssign According to the language reference (and ast.c) the following nodes @@ -1036,75 +1082,79 @@ def visit_AugAssign(self, node): raise NotImplementedError( f"Unknown target type: {type(node.target)}") - def visit_Raise(self, node): + def visit_Raise(self, node: ast.Raise) -> _T_visit_return: """Allow `raise` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Assert(self, node): + def visit_Assert(self, node: ast.Assert) -> _T_visit_return: """Allow assert statements without restrictions.""" return self.node_contents_visit(node) - def visit_Delete(self, node): + def visit_Delete(self, node: ast.Delete) -> _T_visit_return: """Allow `del` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Pass(self, node): + def visit_Pass(self, node: ast.Pass) -> _T_visit_return: """Allow `pass` statements without restrictions.""" return self.node_contents_visit(node) # Imports - def visit_Import(self, node): + def visit_Import(self, node: ast.Import) -> _T_visit_return: """Allow `import` statements with restrictions. See check_import_names.""" return self.check_import_names(node) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node: ast.ImportFrom) -> _T_visit_return: """Allow `import from` statements with restrictions. See check_import_names.""" return self.check_import_names(node) - def visit_alias(self, node): + def visit_alias(self, node: ast.alias) -> _T_visit_return: """Allow `as` statements in import and import from statements.""" return self.node_contents_visit(node) # Control flow - def visit_If(self, node): + def visit_If(self, node: ast.If) -> _T_visit_return: """Allow `if` statements without restrictions.""" return self.node_contents_visit(node) - def visit_For(self, node): + def visit_For(self, node: ast.For) -> _T_visit_return: """Allow `for` statements with some restrictions.""" return self.guard_iter(node) - def visit_While(self, node): + def visit_While(self, node: ast.While) -> _T_visit_return: """Allow `while` statements.""" return self.node_contents_visit(node) - def visit_Break(self, node): + def visit_Break(self, node: ast.Break) -> _T_visit_return: """Allow `break` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Continue(self, node): + def visit_Continue(self, node: ast.Continue) -> _T_visit_return: """Allow `continue` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Try(self, node): + def visit_Try(self, node: ast.Try) -> _T_visit_return: """Allow `try` without restrictions.""" return self.node_contents_visit(node) - def visit_TryStar(self, node): - """Disallow `ExceptionGroup` due to a potential sandbox escape.""" + def visit_TryStar(self, node: ast.AST) -> _T_visit_return: + """Disallow `ExceptionGroup` due to a potential sandbox escape. + + TODO: Type Annotation for node when dropping support + for Python < 3.11 should be ast.TryStar. + """ self.not_allowed(node) - def visit_ExceptHandler(self, node): + def visit_ExceptHandler(self, node: ast.ExceptHandler) -> _T_visit_return: """Protect exception handlers.""" node = self.node_contents_visit(node) self.check_name(node, node.name) return node - def visit_With(self, node): + def visit_With(self, node: ast.With) -> _T_visit_return: """Protect tuple unpacking on with statements.""" node = self.node_contents_visit(node) @@ -1119,13 +1169,13 @@ def visit_With(self, node): return node - def visit_withitem(self, node): + def visit_withitem(self, node: ast.withitem) -> _T_visit_return: """Allow `with` statements (context managers) without restrictions.""" return self.node_contents_visit(node) # Function and class definitions - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node: ast.FunctionDef) -> _T_visit_return: """Allow function definitions (`def`) with some restrictions.""" self.check_name(node, node.name, allow_magic_methods=True) self.check_function_argument_names(node) @@ -1135,44 +1185,44 @@ def visit_FunctionDef(self, node): self.inject_print_collector(node) return node - def visit_Lambda(self, node): + def visit_Lambda(self, node: ast.Lambda) -> _T_visit_return: """Allow lambda with some restrictions.""" self.check_function_argument_names(node) return self.node_contents_visit(node) - def visit_arguments(self, node): + def visit_arguments(self, node: ast.arguments) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_arg(self, node): + def visit_arg(self, node: ast.arg) -> _T_visit_return: """ """ return self.node_contents_visit(node) - def visit_Return(self, node): + def visit_Return(self, node: ast.Return) -> _T_visit_return: """Allow `return` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Yield(self, node): + def visit_Yield(self, node: ast.Yield) -> _T_visit_return: """Allow `yield`statements without restrictions.""" return self.node_contents_visit(node) - def visit_YieldFrom(self, node): + def visit_YieldFrom(self, node: ast.YieldFrom) -> _T_visit_return: """Allow `yield`statements without restrictions.""" return self.node_contents_visit(node) - def visit_Global(self, node): + def visit_Global(self, node: ast.Global) -> _T_visit_return: """Allow `global` statements without restrictions.""" return self.node_contents_visit(node) - def visit_Nonlocal(self, node): + def visit_Nonlocal(self, node: ast.Nonlocal) -> _T_visit_return: """Deny `nonlocal` statements.""" self.not_allowed(node) - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef) -> _T_visit_return: """Check the name of a class definition.""" self.check_name(node, node.name) node = self.node_contents_visit(node) @@ -1183,13 +1233,14 @@ def visit_ClassDef(self, node): class {0.name}(metaclass=__metaclass__): pass '''.format(node)) - new_class_node = ast.parse(CLASS_DEF).body[0] + new_class_node = typing.cast( + ast.ClassDef, ast.parse(CLASS_DEF).body[0]) new_class_node.body = node.body new_class_node.bases = node.bases new_class_node.decorator_list = node.decorator_list return new_class_node - def visit_Module(self, node): + def visit_Module(self, node: ast.Module) -> _T_visit_return: """Add the print_collector (only if print is used) at the top.""" node = self.node_contents_visit(node) @@ -1207,25 +1258,26 @@ def visit_Module(self, node): # Async und await - def visit_AsyncFunctionDef(self, node): + def visit_AsyncFunctionDef( + self, node: ast.AsyncFunctionDef) -> _T_visit_return: """Deny async functions.""" self.not_allowed(node) - def visit_Await(self, node): + def visit_Await(self, node: ast.Await) -> _T_visit_return: """Deny async functionality.""" self.not_allowed(node) - def visit_AsyncFor(self, node): + def visit_AsyncFor(self, node: ast.AsyncFor) -> _T_visit_return: """Deny async functionality.""" self.not_allowed(node) - def visit_AsyncWith(self, node): + def visit_AsyncWith(self, node: ast.AsyncWith) -> _T_visit_return: """Deny async functionality.""" self.not_allowed(node) # Assignment expressions (walrus operator ``:=``) # New in 3.8 - def visit_NamedExpr(self, node): + def visit_NamedExpr(self, node: ast.NamedExpr) -> _T_visit_return: """Allow assignment expressions under some circumstances.""" # while the grammar requires ``node.target`` to be a ``Name`` # the abstract syntax is more permissive and allows an ``expr``. @@ -1237,7 +1289,7 @@ def visit_NamedExpr(self, node): node = self.node_contents_visit(node) # this checks ``node.target`` target = node.target if not isinstance(target, ast.Name): - self.error( + self.error( # type: ignore[unreachable] node, "Assignment expressions are only allowed for simple targets") return node